import os
import json
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from modules.transformer.Models import Encoder, Decoder
from modules.transformer.Layers import PostNet
from models.tts.jets.alignments import (
    AlignmentModule,
    viterbi_decode,
    average_by_duration,
    make_pad_mask,
    make_non_pad_mask,
    get_random_segments,
)
from models.tts.jets.length_regulator import GaussianUpsampling
from models.vocoders.gan.generator.hifigan import HiFiGAN
from utils.util import load_config


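# Builds a boolean padding mask from a batch of sequence lengths. A small
# illustrative example (values inferred from the implementation below):
#
#   lengths = torch.tensor([5, 3])
#   get_mask_from_lengths(lengths)
#   # tensor([[False, False, False, False, False],
#   #         [False, False, False,  True,  True]])
#
# True marks padded positions, which downstream modules zero out or ignore.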
def get_mask_from_lengths(lengths, max_len=None):
    device = lengths.device
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask


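# Pads a list of 1-D or 2-D tensors along their first (time) dimension to a
# common length and stacks them into one batch tensor. For example, two tensors
# of shapes (3, H) and (5, H) become a single (2, 5, H) tensor when no explicit
# mel_max_length is given (illustrative shapes, not taken from the repo).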
def pad(input_ele, mel_max_length=None):
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])

    out_list = list()
    for i, batch in enumerate(input_ele):
        if len(batch.shape) == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif len(batch.shape) == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded


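# FastSpeech2-style variance adaptor: three VariancePredictor heads estimate
# log-duration, pitch and energy from the encoder output, pitch/energy values
# are quantized into `n_bins` learned embeddings, and a LengthRegulator expands
# phoneme-level features to frame level using predicted or ground-truth
# durations.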
class VarianceAdaptor(nn.Module):
    """Variance Adaptor"""

    def __init__(self, cfg):
        super(VarianceAdaptor, self).__init__()
        self.duration_predictor = VariancePredictor(cfg)
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor(cfg)
        self.energy_predictor = VariancePredictor(cfg)

        if cfg.preprocess.use_frame_pitch:
            self.pitch_feature_level = "frame_level"
            self.pitch_dir = cfg.preprocess.pitch_dir
        else:
            self.pitch_feature_level = "phoneme_level"
            self.pitch_dir = cfg.preprocess.phone_pitch_dir

        if cfg.preprocess.use_frame_energy:
            self.energy_feature_level = "frame_level"
            self.energy_dir = cfg.preprocess.energy_dir
        else:
            self.energy_feature_level = "phoneme_level"
            self.energy_dir = cfg.preprocess.phone_energy_dir

        assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
        assert self.energy_feature_level in ["phoneme_level", "frame_level"]

        pitch_quantization = cfg.model.variance_embedding.pitch_quantization
        energy_quantization = cfg.model.variance_embedding.energy_quantization
        n_bins = cfg.model.variance_embedding.n_bins
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]

        # Energy statistics produced during preprocessing, keyed as
        # "<dataset>_<dataset>" and normalized with the voiced-position mean/std.
        with open(
            os.path.join(
                cfg.preprocess.processed_dir,
                cfg.dataset[0],
                self.energy_dir,
                "statistics.json",
            )
        ) as f:
            stats = json.load(f)
            stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
            mean, std = (
                stats["voiced_positions"]["mean"],
                stats["voiced_positions"]["std"],
            )
            energy_min = (stats["total_positions"]["min"] - mean) / std
            energy_max = (stats["total_positions"]["max"] - mean) / std

        # Pitch statistics, same layout as the energy statistics above.
        with open(
            os.path.join(
                cfg.preprocess.processed_dir,
                cfg.dataset[0],
                self.pitch_dir,
                "statistics.json",
            )
        ) as f:
            stats = json.load(f)
            stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
            mean, std = (
                stats["voiced_positions"]["mean"],
                stats["voiced_positions"]["std"],
            )
            pitch_min = (stats["total_positions"]["min"] - mean) / std
            pitch_max = (stats["total_positions"]["max"] - mean) / std

        # Bin edges for quantizing pitch/energy values into embedding indices.
        if pitch_quantization == "log":
            self.pitch_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
                ),
                requires_grad=False,
            )
        else:
            self.pitch_bins = nn.Parameter(
                torch.linspace(pitch_min, pitch_max, n_bins - 1),
                requires_grad=False,
            )
        if energy_quantization == "log":
            self.energy_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
                ),
                requires_grad=False,
            )
        else:
            self.energy_bins = nn.Parameter(
                torch.linspace(energy_min, energy_max, n_bins - 1),
                requires_grad=False,
            )

        self.pitch_embedding = nn.Embedding(
            n_bins, cfg.model.transformer.encoder_hidden
        )
        self.energy_embedding = nn.Embedding(
            n_bins, cfg.model.transformer.encoder_hidden
        )

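    # Quantization sketch (illustrative): a normalized pitch value is mapped to
    # an embedding index with torch.bucketize, e.g.
    #   idx = torch.bucketize(pitch_value, self.pitch_bins)  # in 0 .. n_bins - 1
    #   emb = self.pitch_embedding(idx)
    # since n_bins - 1 boundaries define n_bins buckets.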
    def get_pitch_embedding(self, x, target, mask, control):
        prediction = self.pitch_predictor(x, mask)
        if target is not None:
            embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
        else:
            prediction = prediction * control
            embedding = self.pitch_embedding(
                torch.bucketize(prediction, self.pitch_bins)
            )
        return prediction, embedding

    def get_energy_embedding(self, x, target, mask, control):
        prediction = self.energy_predictor(x, mask)
        if target is not None:
            embedding = self.energy_embedding(
                torch.bucketize(target, self.energy_bins)
            )
        else:
            prediction = prediction * control
            embedding = self.energy_embedding(
                torch.bucketize(prediction, self.energy_bins)
            )
        return prediction, embedding

    def forward(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
        pitch_embedding=None,
        energy_embedding=None,
    ):
        # Duration is predicted from the plain encoder output, before the
        # pitch/energy embeddings are added.
        log_duration_prediction = self.duration_predictor(x, src_mask)

        x = x + pitch_embedding
        x = x + energy_embedding

        pitch_prediction = self.pitch_predictor(x, src_mask)
        energy_prediction = self.energy_predictor(x, src_mask)

        if duration_target is not None:
            x, mel_len = self.length_regulator(x, duration_target, max_len)
            duration_rounded = duration_target
        else:
            # exp(log_d) - 1 undoes the log(1 + d) transform used for the
            # duration targets; d_control rescales durations at synthesis time.
            duration_rounded = torch.clamp(
                (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
                min=0,
            )
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = get_mask_from_lengths(mel_len)

        return (
            x,
            pitch_prediction,
            energy_prediction,
            log_duration_prediction,
            duration_rounded,
            mel_len,
            mel_mask,
        )

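    # At inference time the adaptor only runs the three predictors; embedding
    # the predicted pitch/energy and expanding to frame level are handled by
    # the caller (see Jets.inference below).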
    def inference(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
        pitch_embedding=None,
        energy_embedding=None,
    ):
        p_outs = self.pitch_predictor(x, src_mask)
        e_outs = self.energy_predictor(x, src_mask)
        d_outs = self.duration_predictor(x, src_mask)

        return p_outs, e_outs, d_outs


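# Repeats each phoneme vector by its (integer) duration. Illustrative example:
# with durations [2, 0, 3], a sequence [v0, v1, v2] is expanded to
# [v0, v0, v2, v2, v2]; a zero duration drops the phoneme entirely.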
class LengthRegulator(nn.Module):
    """Length Regulator"""

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        device = x.device
        output = list()
        mel_len = list()
        for batch, expand_target in zip(x, duration):
            expanded = self.expand(batch, expand_target)
            output.append(expanded)
            mel_len.append(expanded.shape[0])

        if max_len is not None:
            output = pad(output, max_len)
        else:
            output = pad(output)

        return output, torch.LongTensor(mel_len).to(device)

    def expand(self, batch, predicted):
        out = list()

        for i, vec in enumerate(batch):
            expand_size = predicted[i].item()
            out.append(vec.expand(max(int(expand_size), 0), -1))
        out = torch.cat(out, 0)

        return out

    def forward(self, x, duration, max_len):
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len


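# Shared architecture for the duration, pitch and energy predictors:
# Conv1d -> ReLU -> LayerNorm -> Dropout (twice), followed by a linear
# projection to one scalar per position; padded positions are masked to zero.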
class VariancePredictor(nn.Module):
    """Duration, Pitch and Energy Predictor"""

    def __init__(self, cfg):
        super(VariancePredictor, self).__init__()

        self.input_size = cfg.model.transformer.encoder_hidden
        self.filter_size = cfg.model.variance_predictor.filter_size
        self.kernel = cfg.model.variance_predictor.kernel_size
        self.conv_output_size = cfg.model.variance_predictor.filter_size
        self.dropout = cfg.model.variance_predictor.dropout

        self.conv_layer = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1d_1",
                        Conv(
                            self.input_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=(self.kernel - 1) // 2,
                        ),
                    ),
                    ("relu_1", nn.ReLU()),
                    ("layer_norm_1", nn.LayerNorm(self.filter_size)),
                    ("dropout_1", nn.Dropout(self.dropout)),
                    (
                        "conv1d_2",
                        Conv(
                            self.filter_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            # same "same"-padding as conv1d_1 so the output length
                            # matches for any odd kernel size (a fixed padding of 1
                            # only holds for kernel_size == 3)
                            padding=(self.kernel - 1) // 2,
                        ),
                    ),
                    ("relu_2", nn.ReLU()),
                    ("layer_norm_2", nn.LayerNorm(self.filter_size)),
                    ("dropout_2", nn.Dropout(self.dropout)),
                ]
            )
        )

        self.linear_layer = nn.Linear(self.conv_output_size, 1)

    def forward(self, encoder_output, mask):
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = out.squeeze(-1)

        if mask is not None:
            out = out.masked_fill(mask, 0.0)

        return out


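# Thin wrapper around nn.Conv1d that accepts (batch, time, channels) input by
# transposing to (batch, channels, time) and back, so it composes cleanly with
# LayerNorm/Linear layers that operate on the last dimension.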
class Conv(nn.Module):
    """
    Convolution Module
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        w_init="linear",
    ):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. weight-init gain name (kept for API compatibility; unused here).
        """
        super(Conv, self).__init__()

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)

        return x


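# End-to-end JETS model: a Transformer text encoder, an alignment module that
# learns phoneme-to-frame durations, a FastSpeech2-style variance adaptor, a
# Transformer decoder, and a HiFi-GAN generator that turns the decoder's hidden
# features directly into waveforms. During training the generator only sees
# random fixed-size segments of the decoder output.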
class Jets(nn.Module):
    def __init__(self, cfg) -> None:
        super(Jets, self).__init__()
        self.cfg = cfg
        self.encoder = Encoder(cfg.model)
        self.variance_adaptor = VarianceAdaptor(cfg)
        self.decoder = Decoder(cfg.model)
        self.length_regulator_infer = LengthRegulator()
        # Mel projection and postnet kept from the FastSpeech2 recipe; they are
        # not used in the forward/inference passes defined in this file.
        self.mel_linear = nn.Linear(
            cfg.model.transformer.decoder_hidden,
            cfg.preprocess.n_mel,
        )
        self.postnet = PostNet(n_mel_channels=cfg.preprocess.n_mel)

        self.speaker_emb = None
        if cfg.train.multi_speaker_training:
            with open(
                os.path.join(
                    cfg.preprocess.processed_dir, cfg.dataset[0], "spk2id.json"
                ),
                "r",
            ) as f:
                n_speaker = len(json.load(f))
            self.speaker_emb = nn.Embedding(
                n_speaker,
                cfg.model.transformer.encoder_hidden,
            )

        # Unsupervised alignment learning between encoder outputs and mel frames.
        output_dim = cfg.preprocess.n_mel
        attention_dim = 256  # NOTE: must match cfg.model.transformer.encoder_hidden
        self.alignment_module = AlignmentModule(attention_dim, output_dim)

        # Project scalar pitch values to the hidden dimension.
        pitch_embed_kernel_size: int = 9
        pitch_embed_dropout: float = 0.5
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=attention_dim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )

        # Project scalar energy values to the hidden dimension.
        energy_embed_kernel_size: int = 9
        energy_embed_dropout: float = 0.5
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=attention_dim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )

        self.length_regulator = GaussianUpsampling()

        self.segment_size = cfg.train.segment_size

        # The HiFi-GAN generator runs on the decoder's hidden features rather
        # than mel spectrograms, so its input channel count is overridden from
        # n_mel to the hidden size.
        hifi_cfg = load_config("egs/vocoder/gan/hifigan/exp_config.json")
        hifi_cfg.preprocess.n_mel = attention_dim
        self.generator = HiFiGAN(hifi_cfg)

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

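    # Training forward pass. Returns the generated waveform segments together
    # with everything the loss needs: the alignment binarization loss, the log
    # attention map, segment start indices, predicted vs. extracted durations,
    # and predicted vs. target phoneme-level pitch/energy.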
    def forward(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
        speakers = data["spk_id"]
        texts = data["texts"]
        src_lens = data["text_len"]
        max_src_len = max(src_lens)
        feats = data["mel"]
        mel_lens = data["target_len"] if "target_len" in data else None
        feats_lengths = mel_lens
        max_mel_len = max(mel_lens) if "target_len" in data else None
        p_targets = data["pitch"] if "pitch" in data else None
        e_targets = data["energy"] if "energy" in data else None
        src_masks = get_mask_from_lengths(src_lens, max_src_len)
        mel_masks = (
            get_mask_from_lengths(mel_lens, max_mel_len)
            if mel_lens is not None
            else None
        )

        output = self.encoder(texts, src_masks)

        if self.speaker_emb is not None:
            output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
                -1, max_src_len, -1
            )

        # Alignment learning: estimate log attention probabilities between text
        # and mel frames, then extract monotonic durations via Viterbi decoding.
        h_masks = make_pad_mask(src_lens).to(output.device)
        log_p_attn = self.alignment_module(
            output, feats, src_lens, feats_lengths, h_masks
        )
        ds, bin_loss = viterbi_decode(log_p_attn, src_lens, feats_lengths)
        # Average frame-level pitch/energy over each phoneme's duration, then
        # embed the phoneme-level values with the 1-D convolutions.
        ps = average_by_duration(
            ds, p_targets.squeeze(-1), src_lens, feats_lengths
        ).unsqueeze(-1)
        es = average_by_duration(
            ds, e_targets.squeeze(-1), src_lens, feats_lengths
        ).unsqueeze(-1)
        p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)

        # Variance adaptor: add the learned pitch/energy embeddings (mirroring
        # the inference path) and expand the phoneme sequence to frame length
        # with the extracted durations.
        (
            output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            mel_lens,
            mel_masks,
        ) = self.variance_adaptor(
            output,
            src_masks,
            mel_masks,
            max_mel_len,
            p_targets,
            e_targets,
            ds,
            p_control,
            e_control,
            d_control,
            p_embs,
            e_embs,
        )

        zs, _ = self.decoder(output, mel_masks)

        # Pick fixed-size random segments of the decoder output so the HiFi-GAN
        # generator only synthesizes short waveform chunks during training.
        z_segments, z_start_idxs = get_random_segments(
            zs.transpose(1, 2),
            feats_lengths,
            self.segment_size,
        )

        wav = self.generator(z_segments)

        return (
            wav,
            bin_loss,
            log_p_attn,
            z_start_idxs,
            log_d_predictions,
            ds,
            p_predictions,
            ps,
            e_predictions,
            es,
            src_lens,
            feats_lengths,
        )

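    # Inference: predict durations, pitch and energy from text alone, expand to
    # frame level with the predicted durations, decode, and synthesize the
    # complete waveform (no random segmentation).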
    def inference(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
        speakers = data["spk_id"]
        texts = data["texts"]
        src_lens = data["text_len"]
        max_src_len = max(src_lens)
        mel_lens = data["target_len"] if "target_len" in data else None
        feats_lengths = mel_lens
        max_mel_len = max(mel_lens) if "target_len" in data else None
        p_targets = data["pitch"] if "pitch" in data else None
        e_targets = data["energy"] if "energy" in data else None
        d_targets = data["durations"] if "durations" in data else None
        src_masks = get_mask_from_lengths(src_lens, max_src_len)
        mel_masks = (
            get_mask_from_lengths(mel_lens, max_mel_len)
            if mel_lens is not None
            else None
        )

        x_masks = self._source_mask(src_lens)
        hs = self.encoder(texts, src_masks)

        # Predict phoneme-level pitch, energy and log-duration from the encoder output.
        (
            p_outs,
            e_outs,
            d_outs,
        ) = self.variance_adaptor.inference(
            hs,
            src_masks,
        )

        # Embed the predicted pitch/energy and add them to the encoder output.
        p_embs = self.pitch_embed(p_outs.unsqueeze(-1).transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(e_outs.unsqueeze(-1).transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs

        # Convert predicted log-durations back to integer frame counts.
        offset = 1.0
        d_predictions = torch.clamp(
            torch.round(d_outs.exp() - offset), min=0
        ).long()

        # Expand phonemes to frame level, decode, and generate the waveform.
        hs, mel_len = self.length_regulator_infer(hs, d_predictions, max_mel_len)
        mel_mask = get_mask_from_lengths(mel_len)
        zs, _ = self.decoder(hs, mel_mask)

        wav = self.generator(zs.transpose(1, 2))

        return wav, d_predictions