|
import torch |
|
|
|
import utils |
|
from utils.hparams import hparams |
|
from network.diff.net import DiffNet |
|
from network.diff.diffusion import GaussianDiffusion, OfflineGaussianDiffusion |
|
from training.task.fs2 import FastSpeech2Task |
|
from network.vocoders.base_vocoder import get_vocoder_cls, BaseVocoder |
|
from modules.fastspeech.tts_modules import mel2ph_to_dur |
|
|
|
from network.diff.candidate_decoder import FFT |
|
from utils.pitch_utils import denorm_f0 |
|
from training.dataset.fs2_utils import FastSpeechDataset |
|
|
|
import numpy as np |
|
import os |
|
import torch.nn.functional as F |
|
|
|
def _make_wavenet_decoder(hp):
    """Build the 'wavenet' denoiser: a ``DiffNet`` sized by ``hp['audio_num_mel_bins']``."""
    return DiffNet(hp['audio_num_mel_bins'])


def _make_fft_decoder(hp):
    """Build the 'fft' denoiser from the transformer-style decoder hyperparameters."""
    return FFT(
        hp['hidden_size'],
        hp['dec_layers'],
        hp['dec_ffn_kernel_size'],
        hp['num_heads'],
    )


# Denoiser builders, selected by ``hparams['diff_decoder_type']``.
# Each value takes the hparams mapping and returns a freshly constructed network.
DIFF_DECODERS = {
    'wavenet': _make_wavenet_decoder,
    'fft': _make_fft_decoder,
}
|
|
|
|
|
class SVCDataset(FastSpeechDataset):
    """FastSpeech-style dataset whose batching is handled by the preprocessing pipeline."""

    def collater(self, samples):
        """Merge a list of per-item samples into a single batch.

        The work is delegated to ``File2Batch.processed_input2batch``. The import runs
        at call time rather than at module import (NOTE(review): presumably to avoid an
        import cycle or a heavy import at load time; confirm).
        """
        from preprocessing.process_pipeline import File2Batch as _file2batch
        batch = _file2batch.processed_input2batch(samples)
        return batch
|
|
|
|
|
class SVCTask(FastSpeech2Task):
    """Training task for the SVC diffusion model.

    Builds a ``GaussianDiffusion`` model that takes ``sample['hubert']`` features plus
    pitch/energy/speaker conditioning. It trains the model on the diffusion loss the model
    reports, and during validation it logs vocoded audio and mel/pitch plots.
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = SVCDataset
        # The vocoder class is chosen from hparams; plot_wav uses it to turn mels into audio.
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()

    def build_tts_model(self):
        """Create ``self.model``, a ``GaussianDiffusion`` model.

        The denoiser network comes from ``DIFF_DECODERS`` and is selected by
        ``hparams['diff_decoder_type']``.
        """
        mel_bins = hparams['audio_num_mel_bins']
        self.model = GaussianDiffusion(
            phone_encoder=self.phone_encoder,
            out_dims=mel_bins,
            denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )

    def build_optimizer(self, model):
        """Create an AdamW optimizer over the trainable parameters of ``model``.

        The optimizer is stored on ``self.optimizer`` and also returned.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        return optimizer

    def run_model(self, model, sample, return_output=False, infer=False):
        """Run ``model`` on one batch and collect its losses.

        :param model: the diffusion model to call.
        :param sample: batch dict. Reads 'hubert', 'mels', 'mel2ph', 'f0', 'uv' and 'energy',
            plus 'spk_ids' when ``hparams['use_spk_id']`` is set, otherwise 'spk_embed'.
        :param return_output: when True, also return the raw model output.
        :param infer: forwarded unchanged to the model.
        :return: ``losses`` or, when ``return_output`` is True, ``(losses, output)``.
            ``losses`` holds key 'mel' (the model's 'diff_loss') when the model reports it.
        """
        hubert = sample['hubert']
        target = sample['mels']
        mel2ph = sample['mel2ph']
        f0 = sample['f0']
        uv = sample['uv']
        energy = sample['energy']
        # Speaker conditioning is either a speaker id or a precomputed speaker embedding.
        spk_embed = sample.get('spk_ids') if hparams['use_spk_id'] else sample.get('spk_embed')

        output = model(hubert, mel2ph=mel2ph, spk_embed=spk_embed,
                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)

        losses = {}
        if 'diff_loss' in output:
            losses['mel'] = output['diff_loss']

        if return_output:
            return losses, output
        return losses

    def _training_step(self, sample, batch_idx, _):
        """Compute the training loss for one batch.

        :return: ``(total_loss, log_outputs)``. ``log_outputs`` holds the individual losses
            plus 'batch_size' and the current learning rate 'lr'.
        """
        log_outputs = self.run_model(self.model, sample)
        # Only differentiable tensor entries contribute to the optimized objective.
        total_loss = sum(v for v in log_outputs.values()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        log_outputs['batch_size'] = sample['hubert'].size()[0]
        # get_last_lr() reports the rate actually in use. get_lr() is an internal hook:
        # it warns when called outside step(), and for StepLR it returns lr * gamma at
        # decay boundaries, which would mislog the learning rate.
        log_outputs['lr'] = self.scheduler.get_last_lr()[0]
        return total_loss, log_outputs

    def build_scheduler(self, optimizer):
        """Halve the learning rate every ``hparams['decay_steps']`` scheduler steps."""
        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Apply one optimizer update, clear gradients, and advance the LR schedule.

        Does nothing when ``optimizer`` is None.
        """
        if optimizer is None:
            return
        optimizer.step()
        optimizer.zero_grad()
        if self.scheduler is not None:
            # The schedule position is global_step // accumulate_grad_batches (presumably
            # the number of optimizer updates). Passing an explicit epoch to step() is
            # deprecated in PyTorch, but StepLR still honours it via its closed form.
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and optionally log samples.

        For the first ``hparams['num_valid_plots']`` batches, also run full inference and
        log vocoded audio, a mel plot, and (when ``use_pitch_embed``) a pitch plot.

        :return: dict with 'losses', 'total_loss' and 'nsamples', converted by
            ``utils.tensors_to_scalars``.
        """
        outputs = {}
        losses = self.run_model(self.model, sample, infer=False)
        outputs['losses'] = losses
        outputs['total_loss'] = sum(losses.values())
        outputs['nsamples'] = sample['nsamples']
        outputs = utils.tensors_to_scalars(outputs)

        if batch_idx < hparams['num_valid_plots']:
            spk_embed = sample.get('spk_ids') if hparams['use_spk_id'] else sample.get('spk_embed')
            # Full inference pass (no reference mels) to produce audio/plots.
            model_out = self.model(
                sample['hubert'], spk_embed=spk_embed, mel2ph=sample['mel2ph'],
                f0=sample['f0'], uv=sample['uv'], energy=sample['energy'],
                ref_mels=None, infer=True,
            )

            if hparams.get('pe_enable'):
                # self.pe (presumably a pitch extractor) estimates f0 from both mels.
                gt_f0 = self.pe(sample['mels'])['f0_denorm_pred']
                pred_f0 = self.pe(model_out['mel_out'])['f0_denorm_pred']
            else:
                gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
                # May be None if the model output has no 'f0_denorm'; plot_wav handles that.
                pred_f0 = model_out.get('f0_denorm')
            self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'],
                          is_mel=True, gt_f0=gt_f0, f0=pred_f0)
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'],
                          name=f'diffmel_{batch_idx}')

            if hparams['use_pitch_embed']:
                self.plot_pitch(batch_idx, sample, model_out)
        return outputs

    def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, wdb, losses=None):
        """Add duration losses to ``losses`` and return it.

        Components written into ``losses``:
          * 'pdur': token-level MSE between ``dur_pred`` and log(1 + ground-truth duration),
            averaged over non-padding tokens and scaled by ``hparams['lambda_ph_dur']``.
          * 'wdur': word-level MSE in log(1 + d) space, when ``hparams['lambda_word_dur'] > 0``.
          * 'sdur': sentence-level MSE in log(1 + d) space, when
            ``hparams['lambda_sent_dur'] > 0``.

        :param dur_pred: [B, T] float predictions in log(1 + d) scale.
        :param mel2ph: alignment converted to ground-truth durations via ``mel2ph_to_dur``.
        :param txt_tokens: [B, T] token ids; id 0 is treated as padding.
        :param wdb: [B, T] markers whose cumulative sum along dim 1 gives each token's word
            index (presumably word-boundary flags; confirm with the caller).
        :param losses: dict to fill in place; a new dict is created when None.
        :return: the ``losses`` dict.
        :raises NotImplementedError: when ``hparams['dur_loss']`` is not 'mse'.
        """
        if losses is None:
            losses = {}
        B, T = txt_tokens.shape
        nonpadding = (txt_tokens != 0).float()
        dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding

        if hparams['dur_loss'] != 'mse':
            raise NotImplementedError
        pdur = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none')
        losses['pdur'] = (pdur * nonpadding).sum() / nonpadding.sum() * hparams['lambda_ph_dur']
        # Back to linear durations (clamped at 0) for the word/sentence-level sums below.
        dur_pred = (dur_pred.exp() - 1).clamp(min=0)

        if hparams['lambda_word_dur'] > 0:
            # Word index of each token; durations are summed per word with scatter_add.
            idx = wdb.cumsum(dim=1)
            n_words = idx.max() + 1
            word_dur_p = dur_pred.new_zeros([B, n_words]).scatter_add(1, idx, dur_pred)
            word_dur_g = dur_gt.new_zeros([B, n_words]).scatter_add(1, idx, dur_gt)
            wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none')
            # Words with zero ground-truth duration are excluded from the average.
            word_nonpadding = (word_dur_g > 0).float()
            wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum()
            losses['wdur'] = wdur_loss * hparams['lambda_word_dur']

        if hparams['lambda_sent_dur'] > 0:
            sent_dur_p = dur_pred.sum(-1)
            sent_dur_g = dur_gt.sum(-1)
            sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean')
            losses['sdur'] = sdur_loss * hparams['lambda_sent_dur']
        return losses

    def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None):
        """Log the first batch item of a ground-truth/predicted pair as audio.

        When ``is_mel`` is True, the inputs are mels and are vocoded with
        ``self.vocoder.spec2wav`` using the matching f0 curve. ``name`` is accepted for
        interface compatibility but is unused.
        """
        def _first(x):
            # First batch item as a numpy array; None is passed through unchanged.
            return None if x is None else x[0].cpu().numpy()

        gt_wav = _first(gt_wav)
        wav_out = _first(wav_out)
        gt_f0 = _first(gt_f0)
        # f0 can be None (validation_step uses model_out.get('f0_denorm')).
        # NOTE(review): confirm the configured vocoder's spec2wav accepts f0=None.
        f0 = _first(f0)
        if is_mel:
            gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0)
            wav_out = self.vocoder.spec2wav(wav_out, f0=f0)
        self.logger.experiment.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
        self.logger.experiment.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
|
|
|
|
|
|