# Spaces:
# Runtime error
# Runtime error
# Star import kept first so the explicit imports below take precedence over
# anything it happens to export.
from src.music.utilities.representation_learning_utilities.constants import *

import pickle

import numpy as np
import torch
from torch import nn
# from transformers import AutoModel, AutoTokenizer

from src.music.config import REP_MODEL_NAME
from src.music.representation_learning.sentence_transfo.sentence_transformers import SentenceTransformer
from src.music.utils import get_out_path
class Argument:
    """Simple namespace object: every key of ``adict`` becomes an attribute."""

    def __init__(self, adict):
        # Copy the mapping (or iterable of key/value pairs) straight into the
        # instance namespace.
        vars(self).update(adict)
class RepModel(nn.Module):
    """Wrap a pretrained model so that calling it yields a single pooled embedding.

    When ``model_name`` contains the substring 't5', only the encoder returned by
    ``model.get_encoder()`` is kept. The wrapped model is put in eval mode at
    construction time.
    """

    def __init__(self, model, model_name):
        super().__init__()
        # Case-sensitive substring match: any name containing 't5' is reduced to its encoder.
        self.model = model.get_encoder() if 't5' in model_name else model
        self.model.eval()

    def forward(self, inputs):
        """Return the mean over dim 0 of the last hidden state of the first batch item.

        The wrapped model is called with ``output_hidden_states=True`` and its
        ``hidden_states[-1]`` is pooled. Gradients are disabled.

        NOTE(review): only ``embeddings[0]`` (the first entry of the leading
        dimension) is pooled; any further batch entries are ignored. Presumably
        ``inputs`` holds a single sequence of shape (1, seq_len). Confirm with callers.
        """
        # Requires a module-level `import torch`; `from torch import nn` alone
        # does not bind the name `torch`.
        with torch.no_grad():
            out = self.model(inputs, output_hidden_states=True)
            last_hidden = out.hidden_states[-1]
            return torch.mean(last_hidden[0], dim=0)
# def get_trained_music_LM(model_name): | |
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True) | |
# model = RepModel(AutoModel.from_pretrained(model_name, use_auth_token=True), model_name) | |
# | |
# return model, tokenizer | |
def get_trained_sentence_embedder(model_name):
    """Build and return the ``SentenceTransformer`` identified by ``model_name``."""
    return SentenceTransformer(model_name)
# Sentence embedder loaded once, at import time, and used by encoded2rep below.
MODEL = get_trained_sentence_embedder(model_name=REP_MODEL_NAME)
def encoded2rep(encoded_path, rep_path=None, return_rep=False, verbose=False, level=0):
    """Map an encoded performance pickle to an embedding and save it as text.

    The pickle at ``encoded_path`` is loaded. Every token of ``data['main']``
    not equal to 1 is stringified, and the tokens are joined with spaces. The
    string is embedded with the module-level ``MODEL`` and written to
    ``rep_path`` with ``np.savetxt``.

    Args:
        encoded_path: path of the pickled encoded performance.
        rep_path: output text file. When falsy it is derived from
            ``encoded_path`` via ``get_out_path`` ('encoded' -> 'represented',
            extension '.txt').
        return_rep: when true, also return the computed representation.
        verbose: print progress messages.
        level: indentation (in spaces) of the verbose messages.

    Returns:
        On success, ``(rep_path, representation, '')`` if ``return_rep`` else
        ``(rep_path, '')``. On failure, ``(None, None, error_msg)`` if
        ``return_rep`` else ``(None, error_msg)``. ``error_msg`` accumulates one
        hint per stage reached, so its last hint names the failing stage.
        Errors are reported through the return value, not raised.
    """
    if not rep_path:
        rep_path, _, _ = get_out_path(in_path=encoded_path, in_word='encoded',
                                      out_word='represented', out_extension='.txt')
    error_msg = 'Error in music transformer mapping.'
    if verbose:
        print(' ' * level + 'Mapping to final music representations')
    try:
        error_msg += ' Error in encoded file loading?'
        # NOTE(review): pickle.load can execute arbitrary code; only call this
        # on trusted, locally produced files.
        with open(encoded_path, 'rb') as f:
            data = pickle.load(f)
        # Tokens equal to 1 are dropped (presumably a padding/special token; confirm).
        performance = [str(w) for w in data['main'] if w != 1]
        # Explicit raises rather than `assert`, so validation survives `python -O`.
        if len(performance) % 5 != 0:
            raise ValueError(f'performance length {len(performance)} is not a multiple of 5')
        if not performance:
            error_msg += " Error: No midi messages in primer file"
            raise ValueError('empty performance')
        error_msg += ' Nope, error in tokenization?'
        perf = ' '.join(performance)
        error_msg += ' Nope. Maybe in performance encoding?'
        representation = MODEL.encode(perf)
        error_msg += ' Nope. Saving performance?'
        np.savetxt(rep_path, representation)
        error_msg += ' Nope.'
    except Exception as exc:  # best-effort boundary: report failure via return value
        if verbose:
            print(' ' * (level + 2) + f'Failed with error: {error_msg} ({exc!r})')
        if return_rep:
            return None, None, error_msg
        return None, error_msg
    if verbose:
        print(' ' * (level + 2) + 'Success.')
    if return_rep:
        return rep_path, representation, ''
    return rep_path, ''
if __name__ == "__main__":
    import sys

    # Encoded pickle to convert: first CLI argument, falling back to the
    # original developer's sample file.
    DEFAULT_ENCODED_PATH = "/home/cedric/Documents/pianocktail/data/music/encoded/single_videos_midi_processed_encoded/chris_dawson_all_of_me_.pickle"
    encoded_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_ENCODED_PATH
    # With return_rep left False, encoded2rep returns (rep_path, error_msg).
    rep_path, error_msg = encoded2rep(encoded_path, verbose=True)