#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 20 16:56:19 2018

@author: harry

Utility functions for GE2E speaker-verification training:
centroid computation, cosine-similarity matrices, the GE2E softmax
loss, and audio feature extraction. The *_prior functions are legacy
loop-based implementations kept for reference; the unsuffixed versions
are their vectorized replacements.
"""
import librosa
import numpy as np
import torch
import torch.autograd as grad
import torch.nn.functional as F

from utils.hparam import hparam as hp


def get_centroids_prior(embeddings):
    """Legacy loop-based per-speaker centroid computation.

    embeddings: tensor of shape (speaker_ct, utterance_ct, embedding_size).
    Returns a (speaker_ct, embedding_size) tensor of per-speaker means.
    """
    centroids = []
    for speaker in embeddings:
        centroid = 0
        for utterance in speaker:
            centroid = centroid + utterance
        centroid = centroid / len(speaker)
        centroids.append(centroid)
    centroids = torch.stack(centroids)
    return centroids


def get_centroids(embeddings):
    """Vectorized per-speaker centroid: mean over the utterance axis (dim=1).

    embeddings: (speaker_ct, utterance_ct, embedding_size) ->
    returns (speaker_ct, embedding_size).
    """
    centroids = embeddings.mean(dim=1)
    return centroids


def get_centroid(embeddings, speaker_num, utterance_num):
    """Leave-one-out centroid for one speaker.

    Mean of speaker `speaker_num`'s utterance embeddings excluding the
    utterance at index `utterance_num` (used for same-speaker similarity
    so an utterance is not compared against a centroid containing itself).
    """
    centroid = 0
    for utterance_id, utterance in enumerate(embeddings[speaker_num]):
        if utterance_id == utterance_num:
            continue  # exclude the utterance itself from its own centroid
        centroid = centroid + utterance
    centroid = centroid / (len(embeddings[speaker_num]) - 1)
    return centroid


def get_utterance_centroids(embeddings):
    """
    Returns the centroids for each utterance of a speaker, where
    the utterance centroid is the speaker centroid without considering
    this utterance

    Shape of embeddings should be:
        (speaker_ct, utterance_per_speaker_ct, embedding_size)
    """
    sum_centroids = embeddings.sum(dim=1)
    # we want to subtract out each utterance, prior to calculating the
    # utterance centroid
    sum_centroids = sum_centroids.reshape(
        sum_centroids.shape[0], 1, sum_centroids.shape[-1]
    )
    # we want the mean but not including the utterance itself, so -1
    num_utterances = embeddings.shape[1] - 1
    centroids = (sum_centroids - embeddings) / num_utterances
    return centroids


def get_cossim_prior(embeddings, centroids):
    # Calculates cosine similarity matrix. Requires (N, M, feature) input
    # Legacy triple-loop implementation; see get_cossim for the vectorized one.
    cossim = torch.zeros(embeddings.size(0), embeddings.size(1), centroids.size(0))
    for speaker_num, speaker in enumerate(embeddings):
        for utterance_num, utterance in enumerate(speaker):
            for centroid_num, centroid in enumerate(centroids):
                if speaker_num == centroid_num:
                    # same-speaker entry: use the leave-one-out centroid
                    centroid = get_centroid(embeddings, speaker_num, utterance_num)
                # +1e-6 keeps the similarity strictly away from zero
                output = F.cosine_similarity(utterance, centroid, dim=0) + 1e-6
                cossim[speaker_num][utterance_num][centroid_num] = output
    return cossim


def get_cossim(embeddings, centroids):
    """Vectorized (N, M, N) cosine-similarity matrix for the GE2E loss.

    embeddings: (N speakers, M utterances, feature);
    centroids: (N, feature) per-speaker centroids.
    Entry [j, i, k] is the similarity of speaker j's utterance i to
    speaker k's centroid; diagonal entries (k == j) use the leave-one-out
    centroid so an utterance is never compared against itself.
    """
    # number of utterances per speaker
    num_utterances = embeddings.shape[1]
    utterance_centroids = get_utterance_centroids(embeddings)

    # flatten the embeddings and utterance centroids to just utterance,
    # so we can do cosine similarity
    utterance_centroids_flat = utterance_centroids.view(
        utterance_centroids.shape[0] * utterance_centroids.shape[1], -1
    )
    embeddings_flat = embeddings.view(
        embeddings.shape[0] * num_utterances, -1
    )
    # the cosine distance between utterance and the associated centroids
    # for that utterance
    # this is each speaker's utterances against his own centroid, but each
    # comparison centroid has the current utterance removed
    cos_same = F.cosine_similarity(embeddings_flat, utterance_centroids_flat)

    # now we get the cosine distance between each utterance and the other
    # speakers' centroids
    # to do so requires comparing each utterance to each centroid. To keep the
    # operation fast, we vectorize by using matrices L (embeddings) and
    # R (centroids) where L has each utterance repeated sequentially for all
    # comparisons and R has the entire centroids frame repeated for each utterance
    centroids_expand = centroids.repeat((num_utterances * embeddings.shape[0], 1))
    embeddings_expand = embeddings_flat.unsqueeze(1).repeat(1, embeddings.shape[0], 1)
    embeddings_expand = embeddings_expand.view(
        embeddings_expand.shape[0] * embeddings_expand.shape[1],
        embeddings_expand.shape[-1]
    )
    cos_diff = F.cosine_similarity(embeddings_expand, centroids_expand)
    cos_diff = cos_diff.view(
        embeddings.size(0), num_utterances, centroids.size(0)
    )
    # assign the cosine distance for same speakers to the proper idx
    same_idx = list(range(embeddings.size(0)))
    cos_diff[same_idx, :, same_idx] = cos_same.view(embeddings.shape[0], num_utterances)
    # +1e-6 keeps every entry strictly positive (matches get_cossim_prior)
    cos_diff = cos_diff + 1e-6
    return cos_diff


def calc_loss_prior(sim_matrix):
    # Calculates loss from (N, M, K) similarity matrix
    # Legacy loop-based GE2E softmax loss; see calc_loss for the vectorized one.
    per_embedding_loss = torch.zeros(sim_matrix.size(0), sim_matrix.size(1))
    for j in range(len(sim_matrix)):
        for i in range(sim_matrix.size(1)):
            per_embedding_loss[j][i] = -(
                sim_matrix[j][i][j]
                - ((torch.exp(sim_matrix[j][i]).sum() + 1e-6).log_())
            )
    loss = per_embedding_loss.sum()
    return loss, per_embedding_loss


def calc_loss(sim_matrix):
    """Vectorized GE2E softmax loss from an (N, M, N) similarity matrix.

    Per-embedding loss is -S[j, i, j] + log(sum_k exp(S[j, i, k]) + 1e-6);
    returns (total_loss, per_embedding_loss) where the latter is (N, M).
    """
    same_idx = list(range(sim_matrix.size(0)))
    # positive term: each utterance's similarity to its own speaker's centroid
    pos = sim_matrix[same_idx, :, same_idx]
    # log-sum-exp over all centroids; log_() is safe in-place on the fresh temp
    neg = (torch.exp(sim_matrix).sum(dim=2) + 1e-6).log_()
    per_embedding_loss = -1 * (pos - neg)
    loss = per_embedding_loss.sum()
    return loss, per_embedding_loss


def normalize_0_1(values, max_value, min_value):
    """Min-max normalize `values` to [0, 1], clipping anything outside the range."""
    normalized = np.clip((values - min_value) / (max_value - min_value), 0, 1)
    return normalized


def mfccs_and_spec(wav_file, wav_process=False, calc_mfccs=False, calc_mag_db=False):
    """Load a wav file and compute spectral features.

    Returns (mfccs, mel_db, mag_db): mfccs is None unless calc_mfccs is True;
    mel_db is the dB mel spectrogram transposed to (frames, mels); mag_db is
    the dB magnitude spectrogram. `wav_process` trims silence and fixes the
    clip length to hp.data.tisv_frame frames. `calc_mag_db` is accepted for
    interface compatibility but unused (mag_db is always computed).

    NOTE(review): uses the pre-0.10 librosa API (positional args to
    librosa.filters.mel / util.fix_length, and librosa.filters.dct) —
    confirm the pinned librosa version before upgrading.
    """
    sound_file, _ = librosa.core.load(wav_file, sr=hp.data.sr)
    window_length = int(hp.data.window * hp.data.sr)
    hop_length = int(hp.data.hop * hp.data.sr)
    duration = hp.data.tisv_frame * hp.data.hop + hp.data.window

    # Cut silence and fix length
    if wav_process:
        sound_file, index = librosa.effects.trim(
            sound_file, frame_length=window_length, hop_length=hop_length
        )
        length = int(hp.data.sr * duration)
        sound_file = librosa.util.fix_length(sound_file, length)

    spec = librosa.stft(
        sound_file, n_fft=hp.data.nfft, hop_length=hop_length,
        win_length=window_length
    )
    mag_spec = np.abs(spec)

    mel_basis = librosa.filters.mel(hp.data.sr, hp.data.nfft, n_mels=hp.data.nmels)
    mel_spec = np.dot(mel_basis, mag_spec)

    mag_db = librosa.amplitude_to_db(mag_spec)
    # db mel spectrogram
    mel_db = librosa.amplitude_to_db(mel_spec).T

    mfccs = None
    if calc_mfccs:
        mfccs = np.dot(librosa.filters.dct(40, mel_db.shape[0]), mel_db).T

    return mfccs, mel_db, mag_db


if __name__ == "__main__":
    # Smoke test of the vectorized similarity/loss path on a tiny fixture.
    # NOTE(review): grad.Variable is deprecated and equivalent to a plain
    # tensor here (requires_grad stays False); kept for fidelity.
    w = grad.Variable(torch.tensor(1.0))
    b = grad.Variable(torch.tensor(0.0))
    embeddings = torch.tensor(
        [[0, 1, 0], [0, 0, 1], [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0]]
    ).to(torch.float).reshape(3, 2, 3)
    centroids = get_centroids(embeddings)
    cossim = get_cossim(embeddings, centroids)
    sim_matrix = w * cossim + b
    loss, per_embedding_loss = calc_loss(sim_matrix)