import torch
from text_to_speech.modules.tts.portaspeech.portaspeech_flow import PortaSpeechFlow
from tasks.tts.fs import FastSpeechTask
from tasks.tts.ps import PortaSpeechTask
from text_to_speech.utils.audio.pitch.utils import denorm_f0
from text_to_speech.utils.commons.hparams import hparams


class PortaSpeechFlowTask(PortaSpeechTask):
    def __init__(self):
        super().__init__()
        self.training_post_glow = False

    def build_tts_model(self):
        ph_dict_size = len(self.token_encoder)
        word_dict_size = len(self.word_encoder)
        self.model = PortaSpeechFlow(ph_dict_size, word_dict_size, hparams)

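    def _training_step(self, sample, batch_idx, opt_idx):
        # The post-flow (post-glow) starts training once `post_glow_training_start` is reached.
        self.training_post_glow = self.global_step >= hparams['post_glow_training_start'] \
                                  and hparams['use_post_flow']
        # In two-stage mode, optimizer 0 (main model) only steps before that point and
        # optimizer 1 (post-flow) only steps after it; the inactive one is skipped.
        if hparams['two_stage'] and \
                ((opt_idx == 0 and self.training_post_glow) or (opt_idx == 1 and not self.training_post_glow)):
            return None
        loss_output, _ = self.run_model(sample)
        # Only terms that still carry gradients are optimized; detached losses are logged only.
        total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        # Skip the step when the post-flow returned no loss for this batch.
        if 'postflow' in loss_output and loss_output['postflow'] is None:
            return None
        return total_loss, loss_output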

    def run_model(self, sample, infer=False, *args, **kwargs):
        if not infer:
            training_post_glow = self.training_post_glow
            spk_embed = sample.get('spk_embed')
            spk_id = sample.get('spk_ids')
            output = self.model(sample['txt_tokens'],
                                sample['word_tokens'],
                                ph2word=sample['ph2word'],
                                mel2word=sample['mel2word'],
                                mel2ph=sample['mel2ph'],
                                word_len=sample['word_lengths'].max(),
                                tgt_mels=sample['mels'],
                                pitch=sample.get('pitch'),
                                spk_embed=spk_embed,
                                spk_id=spk_id,
                                infer=False,
                                forward_post_glow=training_post_glow,
                                two_stage=hparams['two_stage'],
                                global_step=self.global_step,
                                bert_feats=sample.get('bert_feats'))
            losses = {}
            self.add_mel_loss(output['mel_out'], sample['mels'], losses)
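            if (training_post_glow or not hparams['two_stage']) and hparams['use_post_flow']:
                losses['postflow'] = output['postflow']
                # Keep the mel reconstruction losses for logging but stop their gradients;
                # guard on presence so configs without an 'ssim' (or 'l1') term do not raise KeyError.
                for k in ('l1', 'ssim'):
                    if k in losses:
                        losses[k] = losses[k].detach()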
            if not training_post_glow or not hparams['two_stage'] or not self.training:
                losses['kl'] = output['kl']
                if self.global_step < hparams['kl_start_steps']:
                    losses['kl'] = losses['kl'].detach()
                else:
                    losses['kl'] = torch.clamp(losses['kl'], min=hparams['kl_min'])
                losses['kl'] = losses['kl'] * hparams['lambda_kl']
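                # Restating the KL schedule above as an equation (sg = stop-gradient, i.e. .detach()):
                #   L_kl = lambda_kl * sg(kl)                if global_step <  kl_start_steps
                #   L_kl = lambda_kl * max(kl, kl_min)       otherwise
                # Before kl_start_steps the detached term has no gradient, so _training_step
                # leaves it out of total_loss and it is only logged. Afterwards the clamp gives
                # zero gradient whenever kl < kl_min, similar in spirit to a "free bits" floor.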
                if hparams['dur_level'] == 'word':
                    self.add_dur_loss(
                        output['dur'], sample['mel2word'], sample['word_lengths'], sample['txt_tokens'], losses)
                    self.get_attn_stats(output['attn'], sample, losses)
                else:
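                    # Phone-level durations need FastSpeechTask's add_dur_loss(dur, mel2ph, txt_tokens, losses).
                    # A plain super() resolves to PortaSpeechTask's word-level version, which also takes
                    # word_len (see the call above), so the arguments would be bound to the wrong parameters.
                    super(PortaSpeechTask, self).add_dur_loss(output['dur'], sample['mel2ph'], sample['txt_tokens'], losses)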
            return losses, output
        else:
            use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
            # At inference the post-flow is enabled 1000 steps later than in training (see _training_step).
            forward_post_glow = self.global_step >= hparams['post_glow_training_start'] + 1000 \
                                and hparams['use_post_flow']
            spk_embed = sample.get('spk_embed')
            spk_id = sample.get('spk_ids')
            output = self.model(
                sample['txt_tokens'],
                sample['word_tokens'],
                ph2word=sample['ph2word'],
                word_len=sample['word_lengths'].max(),
                pitch=sample.get('pitch'),
                mel2ph=sample['mel2ph'] if use_gt_dur else None,
                # Use the local flag so the `infer_use_gt_dur` override also applies to word-level alignments.
                mel2word=sample['mel2word'] if hparams['profile_infer'] or use_gt_dur else None,
                infer=True,
                forward_post_glow=forward_post_glow,
                spk_embed=spk_embed,
                spk_id=spk_id,
                two_stage=hparams['two_stage'],
                bert_feats=sample.get('bert_feats'))
            return output

    def validation_step(self, sample, batch_idx):
        self.training_post_glow = self.global_step >= hparams['post_glow_training_start'] \
                                  and hparams['use_post_flow']
        return super().validation_step(sample, batch_idx)

    def save_valid_result(self, sample, batch_idx, model_out):
        super().save_valid_result(sample, batch_idx, model_out)
        sr = hparams['audio_sample_rate']
        f0_gt = None
        if sample.get('f0') is not None:
            f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
        if self.global_step > 0:
            # save FVAE result
            if hparams['use_post_flow']:
                wav_pred = self.vocoder.spec2wav(model_out['mel_out_fvae'][0].cpu(), f0=f0_gt)
                self.logger.add_audio(f'wav_fvae_{batch_idx}', wav_pred, self.global_step, sr)
                self.plot_mel(batch_idx, sample['mels'], model_out['mel_out_fvae'][0],
                              f'mel_fvae_{batch_idx}', f0s=f0_gt)

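    def build_optimizer(self, model):
        # Two-stage post-flow training uses two optimizers: index 0 for every parameter outside
        # `post_flow`, index 1 for the post-flow alone with its own learning rate (`post_flow_lr`).
        if hparams['two_stage'] and hparams['use_post_flow']: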
            self.optimizer = torch.optim.AdamW(
                [p for name, p in self.model.named_parameters() if 'post_flow' not in name],
                lr=hparams['lr'],
                betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
                weight_decay=hparams['weight_decay'])
            self.post_flow_optimizer = torch.optim.AdamW(
                self.model.post_flow.parameters(),
                lr=hparams['post_flow_lr'],
                betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
                weight_decay=hparams['weight_decay'])
            return [self.optimizer, self.post_flow_optimizer]
        else:
            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=hparams['lr'],
                betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
                weight_decay=hparams['weight_decay'])
            return [self.optimizer]

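    def build_scheduler(self, optimizer):
        # Only the main optimizer is passed to the scheduler; the post-flow optimizer is not scheduled here.
        return FastSpeechTask.build_scheduler(self, optimizer[0])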
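

# Illustrative sketch (hypothetical helper, not used by the task above): the gating rule
# from `_training_step`, restated as a pure function so the two-stage schedule can be
# checked in isolation. The name and signature are assumptions; the real task reads these
# values from `hparams` and `self.global_step`.
def should_run_optimizer(opt_idx, global_step, post_glow_training_start,
                         use_post_flow=True, two_stage=True):
    """Return True if optimizer `opt_idx` should take a step at `global_step`."""
    # Post-flow training is active once the start step is reached and the flow is enabled.
    training_post_glow = global_step >= post_glow_training_start and use_post_flow
    if two_stage and ((opt_idx == 0 and training_post_glow)
                      or (opt_idx == 1 and not training_post_glow)):
        return False  # the other stage owns this step
    return True


# Example with an arbitrary start step of 1000: the main model (opt_idx 0) trains first,
# then only the post-flow (opt_idx 1) trains.
#   should_run_optimizer(0, 10, 1000)    -> True
#   should_run_optimizer(1, 10, 1000)    -> False
#   should_run_optimizer(0, 2000, 1000)  -> False
#   should_run_optimizer(1, 2000, 1000)  -> True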