from src.music.utilities.representation_learning_utilities.constants import *
from src.music.config import REP_MODEL_NAME
from src.music.utils import get_out_path
import pickle
import numpy as np
import torch
# from transformers import AutoModel, AutoTokenizer
from torch import nn
from src.music.representation_learning.sentence_transfo.sentence_transformers import SentenceTransformer


class Argument(object):
    """Thin wrapper turning a dict into an object with attribute access."""
    def __init__(self, adict):
        self.__dict__.update(adict)


class RepModel(nn.Module):
    """Wraps a pretrained language model and returns a mean-pooled last-hidden-state embedding."""
    def __init__(self, model, model_name):
        super().__init__()
        # For T5-style models, only the encoder is needed to embed a sequence.
        if 't5' in model_name:
            self.model = model.get_encoder()
        else:
            self.model = model
        self.model.eval()

    def forward(self, inputs):
        with torch.no_grad():
            out = self.model(inputs, output_hidden_states=True)
            embeddings = out.hidden_states[-1]
            # Mean-pool over the token dimension of the first (and only) sequence in the batch.
            return torch.mean(embeddings[0], dim=0)


# def get_trained_music_LM(model_name):
#     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
#     model = RepModel(AutoModel.from_pretrained(model_name, use_auth_token=True), model_name)
#
#     return model, tokenizer

def get_trained_sentence_embedder(model_name):
    model = SentenceTransformer(model_name)
    return model


MODEL = get_trained_sentence_embedder(REP_MODEL_NAME)


def encoded2rep(encoded_path, rep_path=None, return_rep=False, verbose=False, level=0):
    """Map an encoded MIDI performance (pickle file) to a single representation vector saved as text."""
    if not rep_path:
        rep_path, _, _ = get_out_path(in_path=encoded_path, in_word='encoded', out_word='represented', out_extension='.txt')
    error_msg = 'Error in music transformer mapping.'
    if verbose: print(' ' * level + 'Mapping to final music representations')
    try:
        error_msg += ' Error in encoded file loading?'
        with open(encoded_path, 'rb') as f:
            data = pickle.load(f)
        # Drop token id 1 (a special token in this encoding); the remaining length is expected
        # to be a multiple of 5.
        performance = [str(w) for w in data['main'] if w != 1]
        assert len(performance) % 5 == 0
        if len(performance) == 0:
            error_msg += " Error: no MIDI messages in primer file."
            assert False
        error_msg += ' Nope, error in tokenization?'
        perf = ' '.join(performance)
        # tokenized = torch.IntTensor(TOKENIZER.encode(perf)).unsqueeze(dim=0)
        error_msg += ' Nope. Maybe in performance encoding?'
        # reps = []
        # for i_chunk in range(min(tokenized.shape[1] // 510 - 1, 8)):
        #     chunk_tokenized = tokenized[:, i_chunk * 510: (i_chunk + 1) * 510 + 2]
        #     rep = MODEL(chunk_tokenized)
        #     reps.append(rep.detach().numpy())
        # representation = np.mean(reps, axis=0)
        representation = MODEL.encode(perf)
        error_msg += ' Nope. Saving representation?'
        np.savetxt(rep_path, representation)
        error_msg += ' Nope.'
        if verbose: print(' ' * (level + 2) + 'Success.')
        if return_rep:
            return rep_path, representation, ''
        else:
            return rep_path, ''
    except:
        if verbose: print(' ' * (level + 2) + f'Failed with error: {error_msg}')
        if return_rep:
            return None, None, error_msg
        else:
            return None, error_msg


if __name__ == "__main__":
    representation = encoded2rep("/home/cedric/Documents/pianocktail/data/music/encoded/single_videos_midi_processed_encoded/chris_dawson_all_of_me_.pickle")
    stop = 1