Spaces:
Build error
Build error
import torch | |
from text_to_speech.modules.tts.portaspeech.portaspeech_flow import PortaSpeechFlow | |
from tasks.tts.fs import FastSpeechTask | |
from tasks.tts.ps import PortaSpeechTask | |
from text_to_speech.utils.audio.pitch.utils import denorm_f0 | |
from text_to_speech.utils.commons.hparams import hparams | |
class PortaSpeechFlowTask(PortaSpeechTask):
    """PortaSpeech task whose model (``PortaSpeechFlow``) can add a post-net flow ("post glow").

    When ``hparams['use_post_flow']`` is set, the post flow is either trained jointly with
    the rest of the model, or - with ``hparams['two_stage']`` - in a second stage that
    begins at ``hparams['post_glow_training_start']`` and is driven by its own optimizer
    (see ``build_optimizer`` / ``_training_step``).
    """

    def __init__(self):
        super().__init__()
        # Whether the post-flow stage is currently active; refreshed at every
        # training and validation step by ``_update_post_glow_stage``.
        self.training_post_glow = False

    def _update_post_glow_stage(self):
        """Recompute, store and return ``self.training_post_glow``.

        The value is falsy until ``global_step`` reaches
        ``hparams['post_glow_training_start']``, and afterwards equals
        ``hparams['use_post_flow']``.
        """
        self.training_post_glow = (
            self.global_step >= hparams['post_glow_training_start'] and hparams['use_post_flow'])
        return self.training_post_glow

    def build_tts_model(self):
        """Instantiate ``PortaSpeechFlow`` sized to the phoneme and word vocabularies."""
        ph_dict_size = len(self.token_encoder)
        word_dict_size = len(self.word_encoder)
        self.model = PortaSpeechFlow(ph_dict_size, word_dict_size, hparams)

    def _training_step(self, sample, batch_idx, opt_idx):
        """Compute the loss for one optimizer on one batch.

        Args:
            sample: batch dict; must hold the keys read by ``run_model``.
            batch_idx: index of the batch (not used here).
            opt_idx: index into the list returned by ``build_optimizer``. In two-stage
                mode, 0 is the main-model optimizer and 1 the post-flow optimizer.

        Returns:
            ``(total_loss, loss_output)``, or ``None`` to skip this optimizer step.
        """
        training_post_glow = self._update_post_glow_stage()
        if hparams['two_stage']:
            # Before the post-flow stage only optimizer 0 steps; once it has
            # started only optimizer 1 (the post flow) steps.
            if (opt_idx == 0 and training_post_glow) or (opt_idx == 1 and not training_post_glow):
                return None
        loss_output, _ = self.run_model(sample)
        if 'postflow' in loss_output and loss_output['postflow'] is None:
            # The model returned no post-flow loss for this batch: skip the step.
            return None
        # Only tensors that still carry gradient are optimized; detached entries
        # (warm-up KL, detached mel losses) remain in loss_output for logging.
        total_loss = sum(v for v in loss_output.values()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        return total_loss, loss_output

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run ``PortaSpeechFlow`` on a batch.

        Args:
            sample: batch dict. Always read: ``txt_tokens``, ``word_tokens``, ``ph2word``,
                ``word_lengths``. Training also reads ``mel2word``, ``mel2ph``, ``mels``;
                inference reads ``mel2ph`` / ``mel2word`` only when ground-truth durations
                are requested. ``pitch``, ``spk_embed``, ``spk_ids`` and ``bert_feats``
                are optional.
            infer: False for a teacher-forced training pass, True for inference.
            **kwargs: ``infer_use_gt_dur`` overrides ``hparams['use_gt_dur']`` at inference.

        Returns:
            ``(losses, output)`` when ``infer`` is False, otherwise the model ``output``.
        """
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')

        if infer:
            use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
            # NOTE(review): at inference the post flow is only enabled 1000 steps after
            # its training starts - presumably a warm-up margin; confirm intent.
            forward_post_glow = (
                self.global_step >= hparams['post_glow_training_start'] + 1000
                and hparams['use_post_flow'])
            return self.model(
                sample['txt_tokens'],
                sample['word_tokens'],
                ph2word=sample['ph2word'],
                word_len=sample['word_lengths'].max(),
                pitch=sample.get('pitch'),
                mel2ph=sample['mel2ph'] if use_gt_dur else None,
                # Gate on the same resolved ``use_gt_dur`` as mel2ph so that the
                # ``infer_use_gt_dur`` override applies to both alignments.
                mel2word=sample['mel2word'] if hparams['profile_infer'] or use_gt_dur else None,
                infer=True,
                forward_post_glow=forward_post_glow,
                spk_embed=spk_embed,
                spk_id=spk_id,
                two_stage=hparams['two_stage'],
                bert_feats=sample.get('bert_feats'))

        training_post_glow = self.training_post_glow
        output = self.model(sample['txt_tokens'],
                            sample['word_tokens'],
                            ph2word=sample['ph2word'],
                            mel2word=sample['mel2word'],
                            mel2ph=sample['mel2ph'],
                            word_len=sample['word_lengths'].max(),
                            tgt_mels=sample['mels'],
                            pitch=sample.get('pitch'),
                            spk_embed=spk_embed,
                            spk_id=spk_id,
                            infer=False,
                            forward_post_glow=training_post_glow,
                            two_stage=hparams['two_stage'],
                            global_step=self.global_step,
                            bert_feats=sample.get('bert_feats'))
        losses = {}
        self.add_mel_loss(output['mel_out'], sample['mels'], losses)
        if (training_post_glow or not hparams['two_stage']) and hparams['use_post_flow']:
            losses['postflow'] = output['postflow']
            # With the post-flow loss active, the L1/SSIM mel losses are kept for
            # logging only; they are skipped if the configured mel losses omit them.
            for mel_loss_name in ('l1', 'ssim'):
                if mel_loss_name in losses:
                    losses[mel_loss_name] = losses[mel_loss_name].detach()
        if not training_post_glow or not hparams['two_stage'] or not self.training:
            losses['kl'] = output['kl']
            if self.global_step < hparams['kl_start_steps']:
                # KL warm-up: report the term but do not back-propagate it yet.
                losses['kl'] = losses['kl'].detach()
            else:
                losses['kl'] = torch.clamp(losses['kl'], min=hparams['kl_min'])
            losses['kl'] = losses['kl'] * hparams['lambda_kl']
            if hparams['dur_level'] == 'word':
                self.add_dur_loss(
                    output['dur'], sample['mel2word'], sample['word_lengths'],
                    sample['txt_tokens'], losses)
                self.get_attn_stats(output['attn'], sample, losses)
            else:
                # This class does not override add_dur_loss, so bare super() would
                # resolve to the same word-level method called above (5 positional
                # args) and mis-bind these 4. Skip past PortaSpeechTask to reach the
                # phone-level add_dur_loss(dur_pred, mel2ph, txt_tokens, losses).
                # NOTE(review): assumes that signature lives on PortaSpeechTask's
                # parent (FastSpeechTask) - confirm.
                super(PortaSpeechTask, self).add_dur_loss(
                    output['dur'], sample['mel2ph'], sample['txt_tokens'], losses)
        return losses, output

    def validation_step(self, sample, batch_idx):
        """Refresh the post-flow stage flag, then defer to the parent validation step."""
        self._update_post_glow_stage()
        return super().validation_step(sample, batch_idx)

    def save_valid_result(self, sample, batch_idx, model_out):
        """Save the parent's validation artefacts, plus audio/mel for ``mel_out_fvae``.

        The extra ``mel_out_fvae`` outputs are only written when ``global_step > 0`` and
        ``hparams['use_post_flow']`` is set.
        """
        super().save_valid_result(sample, batch_idx, model_out)
        if self.global_step <= 0 or not hparams['use_post_flow']:
            return
        sr = hparams['audio_sample_rate']
        f0_gt = None
        if sample.get('f0') is not None:
            f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
        # save FVAE result
        wav_pred = self.vocoder.spec2wav(model_out['mel_out_fvae'][0].cpu(), f0=f0_gt)
        self.logger.add_audio(f'wav_fvae_{batch_idx}', wav_pred, self.global_step, sr)
        self.plot_mel(batch_idx, sample['mels'], model_out['mel_out_fvae'][0],
                      f'mel_fvae_{batch_idx}', f0s=f0_gt)

    def build_optimizer(self, model):
        """Create AdamW optimizer(s) and store them on ``self``.

        ``model`` is accepted for interface compatibility; parameters are taken from
        ``self.model``.

        Returns:
            ``[main, post_flow]`` in two-stage post-flow mode: ``main`` covers every
            parameter whose name lacks ``'post_flow'`` and ``post_flow`` uses
            ``hparams['post_flow_lr']``. Otherwise a single optimizer over all parameters.
        """
        betas = (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])
        weight_decay = hparams['weight_decay']
        if hparams['two_stage'] and hparams['use_post_flow']:
            self.optimizer = torch.optim.AdamW(
                [p for name, p in self.model.named_parameters() if 'post_flow' not in name],
                lr=hparams['lr'], betas=betas, weight_decay=weight_decay)
            self.post_flow_optimizer = torch.optim.AdamW(
                self.model.post_flow.parameters(),
                lr=hparams['post_flow_lr'], betas=betas, weight_decay=weight_decay)
            return [self.optimizer, self.post_flow_optimizer]
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=hparams['lr'], betas=betas, weight_decay=weight_decay)
        return [self.optimizer]

    def build_scheduler(self, optimizer):
        """Build the LR scheduler for the main optimizer only (``optimizer[0]``).

        Dispatches to ``FastSpeechTask.build_scheduler`` explicitly; a post-flow optimizer,
        if present, is not scheduled here.
        """
        return FastSpeechTask.build_scheduler(self, optimizer[0])