Spaces:
Runtime error
Runtime error
import os.path | |
# from src.music.data_collection.is_audio_solo_piano import calculate_piano_solo_prob | |
from src.music.utils import load_audio | |
from src.music.config import FPS | |
import pretty_midi as pm | |
import numpy as np | |
from src.music.config import MUSIC_REP_PATH, MUSIC_NN_PATH | |
from sklearn.neighbors import NearestNeighbors | |
from src.cocktails.config import FULL_COCKTAIL_REP_PATH, COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA | |
# from src.cocktails.pipeline.get_affect2affective_cluster import get_affective_cluster_centers | |
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe | |
from src.music.utils import get_all_subfiles_with_extension | |
import os | |
import pickle | |
import pandas as pd | |
import time | |
# Directory-name marker; the path component right after it is treated as the playlist name.
keyword = 'b256_r128_represented'


def load_reps(rep_path, sample_size=None):
    """Load pickled music representations and derive a playlist name per song.

    Args:
        rep_path: directory prefix holding the pickles. It is joined by plain string
            concatenation, so it must already end with a path separator.
        sample_size: when truthy, read ``all_reps_unnormalized_sample{sample_size}.pickle``;
            otherwise read ``music_reps_unnormalized.pickle``.

    Returns:
        Tuple ``(reps, paths, playlists, n_data, dim_data)``:
        - ``reps`` is the pickled ``'reps'`` entry, unpacked below as a 2-D array;
        - ``paths`` is the pickled ``'paths'`` entry, returned as-is;
        - ``playlists`` holds, for each path, the component right after ``keyword``;
        - ``n_data`` and ``dim_data`` come from ``reps.shape``.

    Raises:
        IndexError: if a path does not contain ``keyword`` followed by a ``'/'``.
    """
    if sample_size:
        pickle_name = f'all_reps_unnormalized_sample{sample_size}.pickle'
    else:
        pickle_name = 'music_reps_unnormalized.pickle'
    with open(rep_path + pickle_name, 'rb') as handle:
        loaded = pickle.load(handle)

    reps = loaded['reps']
    song_paths = loaded['paths']
    # '<...>/<keyword>/<playlist>/<file>' -> '/<playlist>/<file>' -> '<playlist>'
    playlists = []
    for song_path in song_paths:
        after_keyword = song_path.split(keyword)[1]
        playlists.append(after_keyword.split('/')[1])

    n_data, dim_data = reps.shape
    return reps, song_paths, playlists, n_data, dim_data
class Debugger:
    """Collect and print diagnostic information about one music-to-cocktail run.

    Construction obtains two nearest-neighbour (NN) models: one over the music
    representation dataset and one over the cocktail representation dataset.
    Each model is loaded from its pickle cache (``MUSIC_NN_PATH`` /
    ``COCKTAIL_NN_PATH``) when that file exists. Otherwise it is fitted from the
    raw representations and the cache is written.
    """

    # Length every per-song music representation .txt file must have.
    MUSIC_REP_DIM = 128
    # Number of neighbours each freshly fitted NN model returns per query.
    N_NEIGHBORS = 6

    def __init__(self, verbose=True):
        """Load or build both NN models and read the cocktail CSV.

        Args:
            verbose: print setup progress and dataset sizes.
        """
        if verbose:
            print('Setting up debugger.')
        self._setup_music_nn()
        if verbose:
            print(f' {len(self.rep_paths)} songs, representation dim: {self.dim_rep_music}')
        self._setup_cocktail_nn()
        if verbose:
            print(f' {self.n_cocktails} cocktails, representation dim: {self.dim_rep_cocktail}')
        self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)
        # Affective-cluster reporting is currently disabled, so 'affect',
        # 'affective_cluster_id' and 'affective_cluster_center' are not listed here.
        self.keys_to_print = ['mse_reconstruction', 'nearest_cocktail_recipes', 'nearest_cocktail_urls',
                              'nn_music_dists', 'nn_music', 'dim_rep', 'nb_notes', 'audio_len',
                              'piano_solo_prob', 'recipe_score', 'cocktail_rep']

    def _setup_music_nn(self):
        """Set nn_model_music, rep_paths (as np.array) and dim_rep_music, building the cache if missing."""
        if os.path.exists(MUSIC_NN_PATH):
            with open(MUSIC_NN_PATH, 'rb') as f:
                data = pickle.load(f)
            self.nn_model_music = data['nn_model']
            self.rep_paths = data['rep_paths']
            self.dim_rep_music = data['dim_rep_music']
        else:
            reps_path = MUSIC_REP_PATH + 'music_reps_unnormalized.pickle'
            if not os.path.exists(reps_path):
                self._build_music_reps_pickle(reps_path)
            # load_reps also returns playlist names and dataset size; they are not needed here.
            reps, self.rep_paths, _, _, self.dim_rep_music = load_reps(MUSIC_REP_PATH)
            self.nn_model_music = NearestNeighbors(n_neighbors=self.N_NEIGHBORS, metric='cosine')
            self.nn_model_music.fit(reps)
            with open(MUSIC_NN_PATH, 'wb') as f:
                pickle.dump(dict(nn_model=self.nn_model_music,
                                 rep_paths=self.rep_paths,
                                 dim_rep_music=self.dim_rep_music), f)
        # Array form allows fancy indexing by neighbour indexes in get_nearest_songs.
        self.rep_paths = np.array(self.rep_paths)

    def _build_music_reps_pickle(self, reps_path):
        """Gather every representation .txt under MUSIC_REP_PATH into one pickle at reps_path.

        Raises:
            ValueError: if a representation file does not hold MUSIC_REP_DIM values.
        """
        all_rep_path = get_all_subfiles_with_extension(MUSIC_REP_PATH, max_depth=3, extension='.txt',
                                                       current_depth=0)
        reps, paths = [], []
        for path in all_rep_path:
            # Files with 'mean_std' in their path are excluded from the song dataset.
            if 'mean_std' in path:
                continue
            rep = np.loadtxt(path)
            if len(rep) != self.MUSIC_REP_DIM:
                raise ValueError(f'Expected {self.MUSIC_REP_DIM} values in {path}, got {len(rep)}.')
            reps.append(rep)
            paths.append(path)
        with open(reps_path, 'wb') as f:
            pickle.dump(dict(reps=np.array(reps), paths=paths), f)

    def _setup_cocktail_nn(self):
        """Set nn_model_cocktail, dim_rep_cocktail and n_cocktails, building the cache if missing."""
        if os.path.exists(COCKTAIL_NN_PATH):
            with open(COCKTAIL_NN_PATH, 'rb') as f:
                data = pickle.load(f)
            self.nn_model_cocktail = data['nn_model']
            self.dim_rep_cocktail = data['dim_rep_cocktail']
            self.n_cocktails = data['n_cocktails']
            return
        cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
        # Representations are used unnormalized, with sklearn's default metric
        # (the music model uses cosine instead).
        self.nn_model_cocktail = NearestNeighbors(n_neighbors=self.N_NEIGHBORS)
        self.nn_model_cocktail.fit(cocktail_reps)
        self.n_cocktails, self.dim_rep_cocktail = cocktail_reps.shape[0], cocktail_reps.shape[1]
        with open(COCKTAIL_NN_PATH, 'wb') as f:
            pickle.dump(dict(nn_model=self.nn_model_cocktail,
                             dim_rep_cocktail=self.dim_rep_cocktail,
                             n_cocktails=self.n_cocktails), f)

    def get_nearest_songs(self, music_rep, n_songs=5):
        """Return file names of the closest dataset songs and their distances.

        Args:
            music_rep: a single representation; it is reshaped to one query row.
            n_songs: how many neighbours to keep. The model returns at most the
                number it was fitted with, so larger values are effectively capped.

        Returns:
            ``(names, dists)``: two lists of equal length, nearest first. ``names``
            holds the last path component of each neighbour's representation file.
        """
        dists, indexes = self.nn_model_music.kneighbors(music_rep.reshape(1, -1))
        # Slice both arrays identically so names and distances stay aligned.
        indexes = indexes.flatten()[:n_songs]
        dists = dists.flatten()[:n_songs]
        names = [path.split('/')[-1] for path in self.rep_paths[indexes]]
        return names, dists.tolist()

    def get_nearest_cocktails(self, cocktail_rep):
        """Look up the closest dataset cocktails to ``cocktail_rep``.

        Returns:
            ``(indexes, names, urls, recipes, ingredient_strs)``. ``indexes`` is the
            flattened neighbour index array. The other four are lists aligned with it:
            values from the CSV's 'names', 'urls' and 'ingredients_str' columns, and
            the recipe text built by ``print_recipe`` from each ingredient string.
        """
        _, indexes = self.nn_model_cocktail.kneighbors(cocktail_rep.reshape(1, -1))
        indexes = indexes.flatten()
        nn_names = np.array(self.cocktail_data['names'])[indexes].tolist()
        nn_urls = np.array(self.cocktail_data['urls'])[indexes].tolist()
        nn_ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes].tolist()
        nn_recipes = [print_recipe(ingredient_str=ing_str, to_print=False) for ing_str in nn_ing_strs]
        return indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs

    def extract_info(self, all_paths, affective_cluster_id, affect, cocktail_rep, music_reconstruction,
                     recipe_score, verbose=False, level=0):
        """Gather debug information about one processed song and its generated cocktail.

        Args:
            all_paths: dict read at keys 'audio_path' (may be None), 'processed_path'
                (loaded with pretty_midi) and 'representation_path' (loaded with np.loadtxt).
            affective_cluster_id, affect: unused while affective reporting is disabled;
                kept so existing call sites keep working.
            cocktail_rep: cocktail representation, reshaped to one row for the NN query.
            music_reconstruction: array compared element-wise with the stored music
                representation to compute the reconstruction MSE.
            recipe_score: stored unchanged.
            verbose: print progress and timing.
            level: indentation (in spaces) for progress messages.

        Returns:
            The debug dict, which is also stored as ``self.debug_dict``.
        """
        if verbose:
            print(' ' * level + 'Extracting debug info..')
        init_time = time.time()
        debug_dict = dict()
        debug_dict['all_paths'] = all_paths
        debug_dict['recipe_score'] = recipe_score

        # Piano-solo detection (calculate_piano_solo_prob) is disabled, so this is always None.
        debug_dict['piano_solo_prob'] = None
        if all_paths['audio_path'] is not None:
            audio, _ = load_audio(all_paths['audio_path'], sr=FPS, mono=True)
            # Sample count divided by the sample rate, truncated to whole seconds.
            debug_dict['audio_len'] = int(len(audio) / FPS)
        else:
            debug_dict['audio_len'] = None

        # Only the first instrument's notes are counted; a MIDI with no instruments has 0 notes.
        midi = pm.PrettyMIDI(all_paths['processed_path'])
        debug_dict['nb_notes'] = len(midi.instruments[0].notes) if midi.instruments else 0

        representation = np.loadtxt(all_paths['representation_path'])
        debug_dict['dim_rep'] = representation.shape[0]
        debug_dict['nn_music'], debug_dict['nn_music_dists'] = self.get_nearest_songs(representation)

        indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs = self.get_nearest_cocktails(cocktail_rep)
        debug_dict['cocktail_rep'] = cocktail_rep.copy().tolist()
        debug_dict['nearest_cocktail_indexes'] = indexes.tolist()
        debug_dict['nn_ing_strs'] = nn_ing_strs
        debug_dict['nearest_cocktail_names'] = nn_names
        debug_dict['nearest_cocktail_urls'] = nn_urls
        debug_dict['nearest_cocktail_recipes'] = nn_recipes

        debug_dict['music_reconstruction'] = music_reconstruction.tolist()
        debug_dict['mse_reconstruction'] = ((music_reconstruction - representation) ** 2).mean()

        self.debug_dict = debug_dict
        if verbose:
            print(' ' * (level + 2) + f'Debug info extracted in {int(time.time() - init_time)} seconds.')
        return self.debug_dict

    def print_debug(self, level=0):
        """Print every entry of ``keys_to_print`` from the last ``extract_info`` result.

        Floats, and non-empty lists whose first element is a float, are shown with
        two decimals. Recipe strings are flattened to one line, and song file names
        are stripped of their encoding prefix and suffix. ``self.debug_dict`` itself
        is not modified.
        """
        print(' ' * level + '__DEBUGGING INFO__')
        for k in self.keys_to_print:
            to_print = self.debug_dict[k]
            if k == 'nearest_cocktail_recipes':
                to_print = [r.replace('\n', '').replace('\t', '').replace('()', '') for r in to_print]
            elif k == 'nn_music':
                to_print = [p.replace('encoded_new_structured_', '').replace('_represented.txt', '')
                            for p in to_print]

            if isinstance(to_print, float):
                to_print_str = f'{to_print:.2f}'
            elif isinstance(to_print, list) and to_print and isinstance(to_print[0], float):
                # The emptiness check avoids an IndexError on empty lists.
                to_print_str = '[' + ', '.join(f'{element:.2f}' for element in to_print) + ']'
            else:
                to_print_str = f'{to_print}'
            print(' ' * (level + 2) + f'{k} : ' + to_print_str)