|
import matplotlib |
|
|
|
matplotlib.use('Agg') |
|
|
|
from utils import audio |
|
import matplotlib.pyplot as plt |
|
from preprocessing.data_gen_utils import get_pitch_parselmouth |
|
from training.dataset.fs2_utils import FastSpeechDataset |
|
from utils.cwt import cwt2f0 |
|
from utils.pl_utils import data_loader |
|
import os |
|
from multiprocessing.pool import Pool |
|
from tqdm import tqdm |
|
from modules.fastspeech.tts_modules import mel2ph_to_dur |
|
from utils.hparams import hparams |
|
from utils.plot import spec_to_figure, dur_to_figure, f0_to_figure |
|
from utils.pitch_utils import denorm_f0 |
|
from modules.fastspeech.fs2 import FastSpeech2 |
|
from training.task.tts import TtsTask |
|
import torch |
|
import torch.optim |
|
import torch.utils.data |
|
import torch.nn.functional as F |
|
import utils |
|
import torch.distributions |
|
import numpy as np |
|
from modules.commons.ssim import ssim |
|
|
|
class FastSpeech2Task(TtsTask):
    """FastSpeech2-style acoustic-model task built on top of ``TtsTask``.

    Provides data loaders, loss helpers (mel / duration / pitch / energy),
    TensorBoard plotting helpers and the inference post-processing pipeline.
    The model-specific hooks (``build_tts_model``, ``_training_step``,
    ``validation_step`` and ``run_model``) are placeholders in this class and
    are expected to be overridden by concrete subclasses.
    """

    def __init__(self):
        super(FastSpeech2Task, self).__init__()
        self.dataset_cls = FastSpeechDataset
        self.mse_loss_fn = torch.nn.MSELoss()
        # ``hparams['mel_loss']`` is a '|'-separated list of ``name[:weight]``
        # entries, e.g. "l1:0.5|ssim:0.5"; an entry without a weight gets 1.0.
        self.loss_and_lambda = {}
        for spec in hparams['mel_loss'].split("|"):
            spec = spec.strip()
            if not spec:
                continue
            name, sep, weight = spec.partition(":")
            self.loss_and_lambda[name.strip()] = float(weight) if sep else 1.0
        print("| Mel losses:", self.loss_and_lambda)

    @data_loader
    def train_dataloader(self):
        """Build the shuffled training loader (optionally endless, per ``hparams['endless_ds']``)."""
        train_dataset = self.dataset_cls(hparams['train_set_name'], shuffle=True)
        return self.build_dataloader(train_dataset, True, self.max_tokens, self.max_sentences,
                                     endless=hparams['endless_ds'])

    @data_loader
    def val_dataloader(self):
        """Build the non-shuffled validation loader using the eval batch limits."""
        valid_dataset = self.dataset_cls(hparams['valid_set_name'], shuffle=False)
        return self.build_dataloader(valid_dataset, False, self.max_eval_tokens, self.max_eval_sentences)

    @data_loader
    def test_dataloader(self):
        """Build the non-shuffled test loader; batches are not grouped by size."""
        test_dataset = self.dataset_cls(hparams['test_set_name'], shuffle=False)
        return self.build_dataloader(test_dataset, False, self.max_eval_tokens,
                                     self.max_eval_sentences, batch_by_size=False)

    def build_tts_model(self):
        """Build the acoustic model and assign it to ``self.model``.

        Placeholder: subclasses must override this. The base implementation
        does nothing and returns ``None``.
        """
        return

    def build_model(self):
        """Build the model, optionally load ``hparams['load_ckpt']`` (strictly), print it and return it."""
        self.build_tts_model()
        if hparams['load_ckpt'] != '':
            self.load_ckpt(hparams['load_ckpt'], strict=True)
        utils.print_arch(self.model)
        return self.model

    def _training_step(self, sample, batch_idx, _):
        """Compute the training step for one batch.

        Placeholder: subclasses must override this. Returns ``None`` here.
        """
        return

    def validation_step(self, sample, batch_idx):
        """Compute validation outputs for one batch.

        Placeholder: subclasses must override this. Returns ``None`` here.
        """
        return

    def _validation_end(self, outputs):
        """Aggregate per-batch validation outputs into sample-weighted averages.

        :param outputs: iterable of dicts with keys 'nsamples', 'losses'
            (loss name -> value) and 'total_loss'.
        :return: dict mapping every loss name, plus 'total_loss', to its
            weighted average rounded to 4 decimals.
        """
        all_losses_meter = {
            'total_loss': utils.AvgrageMeter(),
        }
        for output in outputs:
            n = output['nsamples']
            for k, v in output['losses'].items():
                if k not in all_losses_meter:
                    all_losses_meter[k] = utils.AvgrageMeter()
                all_losses_meter[k].update(v, n)
            all_losses_meter['total_loss'].update(output['total_loss'], n)
        return {k: round(v.avg, 4) for k, v in all_losses_meter.items()}

    def run_model(self, model, sample, return_output=False):
        """Run ``model`` on ``sample`` and compute the training losses.

        Placeholder: subclasses must override this. Implementations are
        expected to return the ``losses`` dict, or ``(losses, output)`` when
        ``return_output`` is true (for example, filled in via ``add_mel_loss``,
        ``add_dur_loss``, ``add_pitch_loss`` and ``add_energy_loss``). The base
        implementation returns ``None``.
        """
        return

    def add_mel_loss(self, mel_out, target, losses, postfix='', mel_mix_loss=None):
        """Add every configured mel loss to ``losses`` in place.

        Each ``(name, weight)`` in ``self.loss_and_lambda`` is computed, scaled
        by its weight and stored under ``f'{name}{postfix}'``.

        :raises NotImplementedError: if ``mel_mix_loss`` is given, for the
            declared-but-unimplemented 'mse' / 'gdl' losses, and for any
            unknown loss name.
        """
        if mel_mix_loss is not None:
            raise NotImplementedError
        for loss_name, lbd in self.loss_and_lambda.items():
            if loss_name == 'l1':
                loss = self.l1_loss(mel_out, target)
            elif loss_name == 'ssim':
                loss = self.ssim_loss(mel_out, target)
            elif loss_name in ('mse', 'gdl'):
                raise NotImplementedError
            else:
                # Fail loudly: previously an unknown name reused the prior
                # iteration's loss value (or raised UnboundLocalError).
                raise NotImplementedError(f"Unknown mel loss: {loss_name!r}")
            losses[f'{loss_name}{postfix}'] = loss * lbd

    def l1_loss(self, decoder_output, target):
        """L1 loss averaged over the entries selected by ``weights_nonzero_speech(target)``."""
        l1_loss = F.l1_loss(decoder_output, target, reduction='none')
        weights = self.weights_nonzero_speech(target)
        l1_loss = (l1_loss * weights).sum() / weights.sum()
        return l1_loss

    def ssim_loss(self, decoder_output, target, bias=6.0):
        """``1 - SSIM`` loss, weighted by ``weights_nonzero_speech(target)``.

        A channel axis is inserted and ``bias`` is added to both inputs before
        calling ``ssim``; presumably this shifts log-scale values into a
        positive range (NOTE(review): confirm against the ssim implementation).
        """
        assert decoder_output.shape == target.shape
        weights = self.weights_nonzero_speech(target)
        decoder_output = decoder_output[:, None] + bias
        target = target[:, None] + bias
        ssim_loss = 1 - ssim(decoder_output, target, size_average=False)
        ssim_loss = (ssim_loss * weights).sum() / weights.sum()
        return ssim_loss

    def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, losses=None):
        """Add phone-, word- and sentence-level duration losses to ``losses``.

        :param dur_pred: [B, T], float, log scale
        :param mel2ph: [B, T]
        :param txt_tokens: [B, T]
        :param losses: dict updated in place with 'pdur' and, when
            ``lambda_word_dur`` / ``lambda_sent_dur`` are positive, 'wdur' /
            'sdur'.
        :raises NotImplementedError: when ``hparams['dur_loss'] == 'mog'``.
        :raises ValueError: for any other unknown ``hparams['dur_loss']``.
        """
        B, T = txt_tokens.shape
        nonpadding = (txt_tokens != 0).float()
        dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding
        # Mark silence phones; they delimit "words" for the word-level loss.
        is_sil = torch.zeros_like(txt_tokens).bool()
        for p in self.sil_ph:
            is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0])
        is_sil = is_sil.float()

        if hparams['dur_loss'] == 'mse':
            pdur = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none')
            losses['pdur'] = (pdur * nonpadding).sum() / nonpadding.sum()
            # Back to linear frame counts for the word/sentence losses below.
            dur_pred = (dur_pred.exp() - 1).clamp(min=0)
        elif hparams['dur_loss'] == 'mog':
            # Was `return NotImplementedError`, which silently skipped every
            # duration loss instead of reporting the unsupported option.
            raise NotImplementedError
        elif hparams['dur_loss'] == 'crf':
            losses['pdur'] = -self.model.dur_predictor.crf(
                dur_pred, dur_gt.long().clamp(min=0, max=31), mask=nonpadding > 0, reduction='mean')
        else:
            raise ValueError(f"Unknown dur_loss: {hparams['dur_loss']!r}")
        losses['pdur'] = losses['pdur'] * hparams['lambda_ph_dur']

        if hparams['lambda_word_dur'] > 0:
            # word_id counts preceding silences for non-silence phones and is 0
            # for silences (and for phones before the first silence); id 0 is
            # dropped by the [:, 1:] slice below.
            word_id = (is_sil.cumsum(-1) * (1 - is_sil)).long()
            word_dur_p = dur_pred.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_pred)[:, 1:]
            word_dur_g = dur_gt.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_gt)[:, 1:]
            wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none')
            word_nonpadding = (word_dur_g > 0).float()
            wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum()
            losses['wdur'] = wdur_loss * hparams['lambda_word_dur']
        if hparams['lambda_sent_dur'] > 0:
            sent_dur_p = dur_pred.sum(-1)
            sent_dur_g = dur_gt.sum(-1)
            sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean')
            losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur']

    def add_pitch_loss(self, output, sample, losses):
        """Add pitch-related losses to ``losses`` according to ``hparams['pitch_type']``.

        'ph': phone-level f0 loss (l1, otherwise mse) masked by non-padding tokens.
        'cwt': CWT spectrum loss, optional u/v loss, f0 mean/std L1 losses and,
            if ``cwt_add_f0_loss``, a frame-level f0 loss on the reconstruction.
        'frame': frame-level f0 (and optional u/v) loss via ``add_f0_loss``.
        """
        if hparams['pitch_type'] == 'ph':
            nonpadding = (sample['txt_tokens'] != 0).float()
            pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss
            losses['f0'] = (pitch_loss_fn(output['pitch_pred'][:, :, 0], sample['f0'],
                                          reduction='none') * nonpadding).sum() \
                           / nonpadding.sum() * hparams['lambda_f0']
            return
        mel2ph = sample['mel2ph']
        f0 = sample['f0']
        uv = sample['uv']
        nonpadding = (mel2ph != 0).float()
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            # The first 10 channels are the CWT spectrum; an optional 11th is u/v logits.
            cwt_pred = output['cwt'][:, :, :10]
            f0_mean_pred = output['f0_mean']
            f0_std_pred = output['f0_std']
            losses['C'] = self.cwt_loss(cwt_pred, cwt_spec) * hparams['lambda_f0']
            if hparams['use_uv']:
                assert output['cwt'].shape[-1] == 11
                uv_pred = output['cwt'][:, :, -1]
                losses['uv'] = (F.binary_cross_entropy_with_logits(uv_pred, uv, reduction='none') * nonpadding) \
                                   .sum() / nonpadding.sum() * hparams['lambda_uv']
            losses['f0_mean'] = F.l1_loss(f0_mean_pred, f0_mean) * hparams['lambda_f0']
            losses['f0_std'] = F.l1_loss(f0_std_pred, f0_std) * hparams['lambda_f0']
            if hparams['cwt_add_f0_loss']:
                f0_cwt_ = self.model.cwt2f0_norm(cwt_pred, f0_mean_pred, f0_std_pred, mel2ph)
                self.add_f0_loss(f0_cwt_[:, :, None], f0, uv, losses, nonpadding=nonpadding)
        elif hparams['pitch_type'] == 'frame':
            self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)

    def add_f0_loss(self, p_pred, f0, uv, losses, nonpadding):
        """Add frame-level f0 (and, with ``use_uv``, u/v BCE) losses to ``losses``.

        ``p_pred[..., 0]`` is the f0 prediction and, when ``use_uv`` is set,
        ``p_pred[..., 1]`` holds u/v logits; the f0 loss is then restricted to
        frames where ``uv == 0``.

        :raises NotImplementedError: when ``hparams['pitch_loss'] == 'ssim'``.
        """
        assert p_pred[..., 0].shape == f0.shape
        if hparams['use_uv']:
            assert p_pred[..., 1].shape == uv.shape
            losses['uv'] = (F.binary_cross_entropy_with_logits(
                p_pred[:, :, 1], uv, reduction='none') * nonpadding).sum() \
                           / nonpadding.sum() * hparams['lambda_uv']
            nonpadding = nonpadding * (uv == 0).float()

        f0_pred = p_pred[:, :, 0]
        if hparams['pitch_loss'] in ['l1', 'l2']:
            pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss
            losses['f0'] = (pitch_loss_fn(f0_pred, f0, reduction='none') * nonpadding).sum() \
                           / nonpadding.sum() * hparams['lambda_f0']
        elif hparams['pitch_loss'] == 'ssim':
            # Was `return NotImplementedError`, silently dropping the f0 loss.
            raise NotImplementedError

    def cwt_loss(self, cwt_p, cwt_g):
        """Loss between predicted and target CWT spectra, selected by ``hparams['cwt_loss']``.

        :raises ValueError: for an unknown ``cwt_loss`` (previously ``None``
            was returned, which failed later when scaled by ``lambda_f0``).
        """
        if hparams['cwt_loss'] == 'l1':
            return F.l1_loss(cwt_p, cwt_g)
        if hparams['cwt_loss'] == 'l2':
            return F.mse_loss(cwt_p, cwt_g)
        if hparams['cwt_loss'] == 'ssim':
            return self.ssim_loss(cwt_p, cwt_g, 20)
        raise ValueError(f"Unknown cwt_loss: {hparams['cwt_loss']!r}")

    def add_energy_loss(self, energy_pred, energy, losses):
        """Add an MSE energy loss (over entries where target energy != 0) to ``losses['e']``."""
        nonpadding = (energy != 0).float()
        loss = (F.mse_loss(energy_pred, energy, reduction='none') * nonpadding).sum() / nonpadding.sum()
        loss = loss * hparams['lambda_energy']
        losses['e'] = loss

    def plot_mel(self, batch_idx, spec, spec_out, name=None):
        """Log the first item's target and predicted mels, concatenated on the last axis, as one figure."""
        spec_cat = torch.cat([spec, spec_out], -1)
        name = f'mel_{batch_idx}' if name is None else name
        vmin = hparams['mel_vmin']
        vmax = hparams['mel_vmax']
        self.logger.experiment.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step)

    def plot_dur(self, batch_idx, sample, model_out):
        """Log a ground-truth vs predicted duration figure for the first item of the batch."""
        T_txt = sample['txt_tokens'].shape[1]
        dur_gt = mel2ph_to_dur(sample['mel2ph'], T_txt)[0]
        # NOTE(review): unlike dur_gt, dur_pred is not indexed with [0];
        # confirm dur_to_figure expects the full batch here.
        dur_pred = self.model.dur_predictor.out2dur(model_out['dur']).float()
        txt = self.phone_encoder.decode(sample['txt_tokens'][0].cpu().numpy())
        txt = txt.split(" ")
        self.logger.experiment.add_figure(
            f'dur_{batch_idx}', dur_to_figure(dur_gt, dur_pred, txt), self.global_step)

    def plot_pitch(self, batch_idx, sample, model_out):
        """Log ground-truth vs predicted f0 (and, for 'cwt', the CWT spectra) for the first item."""
        f0 = sample['f0']
        if hparams['pitch_type'] == 'ph':
            mel2ph = sample['mel2ph']
            f0 = self.expand_f0_ph(f0, mel2ph)
            f0_pred = self.expand_f0_ph(model_out['pitch_pred'][:, :, 0], mel2ph)
            self.logger.experiment.add_figure(
                f'f0_{batch_idx}', f0_to_figure(f0[0], None, f0_pred[0]), self.global_step)
            return
        f0 = denorm_f0(f0, sample['uv'], hparams)
        if hparams['pitch_type'] == 'cwt':
            cwt_out = model_out['cwt']
            cwt_spec = cwt_out[:, :, :10]
            cwt = torch.cat([cwt_spec, sample['cwt_spec']], -1)
            self.logger.experiment.add_figure(f'cwt_{batch_idx}', spec_to_figure(cwt[0]), self.global_step)

            f0_pred = cwt2f0(cwt_spec, model_out['f0_mean'], model_out['f0_std'], hparams['cwt_scales'])
            if hparams['use_uv']:
                assert cwt_out.shape[-1] == 11
                # Positive u/v logits mark unvoiced frames; zero their f0.
                uv_pred = cwt_out[:, :, -1] > 0
                f0_pred[uv_pred] = 0
            f0_cwt = denorm_f0(sample['f0_cwt'], sample['uv'], hparams)
            self.logger.experiment.add_figure(
                f'f0_{batch_idx}', f0_to_figure(f0[0], f0_cwt[0], f0_pred[0]), self.global_step)
        elif hparams['pitch_type'] == 'frame':
            pitch_pred = denorm_f0(model_out['pitch_pred'][:, :, 0], sample['uv'], hparams)
            self.logger.experiment.add_figure(
                f'f0_{batch_idx}', f0_to_figure(f0[0], None, pitch_pred[0]), self.global_step)

    def test_step(self, sample, batch_idx):
        """Run inference on one test batch and pass the results to ``after_infer``.

        Unless ``profile_infer`` is set, the model is conditioned on the
        batch's ``mel2ph``, ``f0`` and ``uv``. ``sample`` is updated in place
        with 'outputs', 'mel2ph_pred', a denormalised 'f0' and 'f0_pred'.
        """
        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
        hubert = sample['hubert']
        mel2ph, uv, f0 = None, None, None
        ref_mels = None
        if not hparams['profile_infer']:
            mel2ph = sample['mel2ph']
            f0 = sample['f0']
            uv = sample['uv']

        outputs = self.model(
            hubert, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True)
        sample['outputs'] = self.model.out2mel(outputs['mel_out'])
        sample['mel2ph_pred'] = outputs['mel2ph']
        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            sample['f0'] = self.pe(sample['mels'])['f0_denorm_pred']
            sample['f0_pred'] = self.pe(sample['outputs'])['f0_denorm_pred']
        else:
            sample['f0'] = denorm_f0(sample['f0'], sample['uv'], hparams)
            sample['f0_pred'] = outputs.get('f0_denorm')
        return self.after_infer(sample)

    def after_infer(self, predictions):
        """Vocode and save (or, under ``profile_infer``, only time) a batch of inference results.

        ``predictions`` is a batched dict that is unpacked into one dict per
        item. For each item, frames whose mel row is all zeros are dropped,
        mels are clipped to ``[mel_vmin, mel_vmax]``, audio is generated with
        ``self.vocoder``, and, unless profiling, wav / plot / mel ``.npy``
        files are written asynchronously through a process pool.

        :return: an empty dict.
        """
        if self.saving_result_pool is None and not hparams['profile_infer']:
            # os.cpu_count() may return None; fall back to a single worker.
            n_proc = int(os.getenv('N_PROC', os.cpu_count() or 1))
            self.saving_result_pool = Pool(min(n_proc, 16))
            self.saving_results_futures = []
        predictions = utils.unpack_dict_to_list(predictions)
        t = tqdm(predictions)
        for prediction in t:
            for k, v in prediction.items():
                if type(v) is torch.Tensor:
                    prediction[k] = v.cpu().numpy()

            item_name = prediction.get('item_name')
            text = None
            str_phs = None

            # Strip all-zero (padding) frames from the target and predicted mels.
            mel_gt = prediction["mels"]
            mel_gt_mask = np.abs(mel_gt).sum(-1) > 0
            mel_gt = mel_gt[mel_gt_mask]
            mel2ph_gt = prediction.get("mel2ph")
            mel2ph_gt = mel2ph_gt[mel_gt_mask] if mel2ph_gt is not None else None
            mel_pred = prediction["outputs"]
            mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
            mel_pred = mel_pred[mel_pred_mask]
            mel_gt = np.clip(mel_gt, hparams['mel_vmin'], hparams['mel_vmax'])
            mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])

            mel2ph_pred = prediction.get("mel2ph_pred")
            if mel2ph_pred is not None:
                if len(mel2ph_pred) > len(mel_pred_mask):
                    mel2ph_pred = mel2ph_pred[:len(mel_pred_mask)]
                mel2ph_pred = mel2ph_pred[mel_pred_mask]

            f0_gt = prediction.get("f0")
            # NOTE(review): the target f0 is reused as the "predicted" f0 (the
            # 'f0_pred' entry is not read), so the vocoder is conditioned on
            # target pitch; this looks intentional but should be confirmed.
            f0_pred = f0_gt
            if f0_pred is not None:
                f0_gt = f0_gt[mel_gt_mask]
                if len(f0_pred) > len(mel_pred_mask):
                    f0_pred = f0_pred[:len(mel_pred_mask)]
                f0_pred = f0_pred[mel_pred_mask]

            gen_dir = os.path.join(hparams['work_dir'],
                                   f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
            wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
            if not hparams['profile_infer']:
                os.makedirs(gen_dir, exist_ok=True)
                os.makedirs(f'{gen_dir}/wavs', exist_ok=True)
                os.makedirs(f'{gen_dir}/plot', exist_ok=True)
                os.makedirs(os.path.join(hparams['work_dir'], 'P_mels_npy'), exist_ok=True)
                os.makedirs(os.path.join(hparams['work_dir'], 'G_mels_npy'), exist_ok=True)
                self.saving_results_futures.append(
                    self.saving_result_pool.apply_async(self.save_result, args=[
                        wav_pred, mel_pred, 'P', item_name, text, gen_dir, str_phs, mel2ph_pred, f0_gt, f0_pred]))

                if mel_gt is not None and hparams['save_gt']:
                    wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
                    self.saving_results_futures.append(
                        self.saving_result_pool.apply_async(self.save_result, args=[
                            wav_gt, mel_gt, 'G', item_name, text, gen_dir, str_phs, mel2ph_gt, f0_gt, f0_pred]))
                    if hparams['save_f0']:
                        # Compare the f0 fed to the vocoder with f0 re-extracted from the target wav.
                        f0_gt_, _ = get_pitch_parselmouth(wav_gt, mel_gt, hparams)
                        fig = plt.figure()
                        plt.plot(f0_pred, label=r'$f0_P$')
                        plt.plot(f0_gt_, label=r'$f0_G$')
                        plt.legend()
                        plt.tight_layout()
                        plt.savefig(f'{gen_dir}/plot/[F0][{item_name}]{text}.png', format='png')
                        plt.close(fig)

                t.set_description(
                    f"Pred_shape: {mel_pred.shape}, gt_shape: {mel_gt.shape}")
            else:
                if 'gen_wav_time' not in self.stats:
                    self.stats['gen_wav_time'] = 0
                self.stats['gen_wav_time'] += len(wav_pred) / hparams['audio_sample_rate']
                print('gen_wav_time: ', self.stats['gen_wav_time'])

        return {}

    @staticmethod
    def save_result(wav_out, mel, prefix, item_name, text, gen_dir, str_phs=None, mel2ph=None, gt_f0=None, pred_f0=None):
        """Save one utterance's wav, mel ``.npy`` and a mel/f0 plot.

        Files are written under ``gen_dir/wavs``, ``gen_dir/plot`` and
        ``work_dir/{prefix}_mels_npy``; ``prefix`` is 'P' (prediction) or
        'G' (ground truth) as used by ``after_infer``.
        """
        item_name = item_name.replace('/', '-')
        base_fn = f'[{item_name}][{prefix}]'

        if text is not None:
            base_fn += text
        base_fn += ('-' + hparams['exp_name'])
        np.save(os.path.join(hparams['work_dir'], f'{prefix}_mels_npy', item_name), mel)
        # Use the configured rate (as after_infer does when timing the same
        # vocoder output) instead of a hard-coded 24 kHz.
        audio.save_wav(wav_out, f'{gen_dir}/wavs/{base_fn}.wav', hparams.get('audio_sample_rate', 24000),
                       norm=hparams['out_wav_norm'])
        fig = plt.figure(figsize=(14, 10))
        spec_vmin = hparams['mel_vmin']
        spec_vmax = hparams['mel_vmax']
        heatmap = plt.pcolor(mel.T, vmin=spec_vmin, vmax=spec_vmax)
        fig.colorbar(heatmap)
        # f0 overlays map 100-800 Hz linearly onto 0-80 on the plot's y axis
        # (presumably the mel-bin range); unvoiced (f0 == 0) frames map to 0.
        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            gt_f0 = (gt_f0 - 100) / (800 - 100) * 80 * (gt_f0 > 0)
            pred_f0 = (pred_f0 - 100) / (800 - 100) * 80 * (pred_f0 > 0)
            plt.plot(pred_f0, c='white', linewidth=1, alpha=0.6)
            plt.plot(gt_f0, c='red', linewidth=1, alpha=0.6)
        else:
            f0, _ = get_pitch_parselmouth(wav_out, mel, hparams)
            f0 = (f0 - 100) / (800 - 100) * 80 * (f0 > 0)
            plt.plot(f0, c='white', linewidth=1, alpha=0.6)
        if mel2ph is not None and str_phs is not None:
            decoded_txt = str_phs.split(" ")
            dur = mel2ph_to_dur(torch.LongTensor(mel2ph)[None, :], len(decoded_txt))[0].numpy()
            dur = [0] + list(np.cumsum(dur))
            for i in range(len(dur) - 1):
                shift = (i % 20) + 1
                plt.text(dur[i], shift, decoded_txt[i])
                plt.hlines(shift, dur[i], dur[i + 1], colors='b' if decoded_txt[i] != '|' else 'black')
                plt.vlines(dur[i], 0, 5, colors='b' if decoded_txt[i] != '|' else 'black',
                           alpha=1, linewidth=1)
        plt.tight_layout()
        # NOTE(review): a 14x10 in figure at dpi=1000 yields a ~14000x10000 px image.
        plt.savefig(f'{gen_dir}/plot/{base_fn}.png', format='png', dpi=1000)
        plt.close(fig)

    @staticmethod
    def expand_f0_ph(f0, mel2ph):
        """Expand phone-level f0 to frame level using ``mel2ph`` (index 0 maps to a prepended 0)."""
        f0 = denorm_f0(f0, None, hparams)
        f0 = F.pad(f0, [1, 0])
        f0 = torch.gather(f0, 1, mel2ph)
        return f0
|
|
|
|
|
if "__main__" == __name__:
    # Script entry point: delegate to the task's `start` launcher.
    FastSpeech2Task.start()
|
|