import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from typing import Tuple |
|
from numba import jit |
|
from scipy.stats import betabinom |
|
|
|
|
|
class AlignmentModule(nn.Module): |
|
"""Alignment Learning Framework proposed for parallel TTS models in: |
|
|
|
https://arxiv.org/abs/2108.10447 |
|
|
|
""" |
|
|
|
def __init__(self, adim, odim, cache_prior=True): |
|
"""Initialize AlignmentModule. |
|
|
|
Args: |
|
adim (int): Dimension of attention. |
|
odim (int): Dimension of feats. |
|
cache_prior (bool): Whether to cache beta-binomial prior. |
|
|
|
""" |
|
super().__init__() |
|
self.cache_prior = cache_prior |
|
self._cache = {} |
|
|
|
self.t_conv1 = nn.Conv1d(adim, adim, kernel_size=3, padding=1) |
|
self.t_conv2 = nn.Conv1d(adim, adim, kernel_size=1, padding=0) |
|
|
|
self.f_conv1 = nn.Conv1d(odim, adim, kernel_size=3, padding=1) |
|
self.f_conv2 = nn.Conv1d(adim, adim, kernel_size=3, padding=1) |
|
self.f_conv3 = nn.Conv1d(adim, adim, kernel_size=1, padding=0) |
|
|
|
def forward(self, text, feats, text_lengths, feats_lengths, x_masks=None): |
|
"""Calculate alignment loss. |
|
|
|
Args: |
|
text (Tensor): Batched text embedding (B, T_text, adim). |
|
feats (Tensor): Batched acoustic feature (B, T_feats, odim). |
|
text_lengths (Tensor): Text length tensor (B,). |
|
feats_lengths (Tensor): Feature length tensor (B,). |
|
x_masks (Tensor): Mask tensor (B, T_text). |
|
|
|
Returns: |
|
Tensor: Log probability of attention matrix (B, T_feats, T_text). |
|
|
|
""" |
|
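        # Text branch: encode text embeddings with two 1-D convs (channel-first).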
text = text.transpose(1, 2) |
|
text = F.relu(self.t_conv1(text)) |
|
text = self.t_conv2(text) |
|
text = text.transpose(1, 2) |
|
|
|
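        # Feature branch: project acoustic features to the attention dimension
        # with three 1-D convs (channel-first).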
feats = feats.transpose(1, 2) |
|
feats = F.relu(self.f_conv1(feats)) |
|
feats = F.relu(self.f_conv2(feats)) |
|
feats = self.f_conv3(feats) |
|
feats = feats.transpose(1, 2) |
|
|
|
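        # Score each (feature frame, text token) pair by the negative Euclidean
        # distance between the projected embeddings (larger score = closer match).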
dist = feats.unsqueeze(2) - text.unsqueeze(1) |
|
dist = torch.norm(dist, p=2, dim=3) |
|
score = -dist |
|
|
|
if x_masks is not None: |
|
x_masks = x_masks.unsqueeze(-2) |
|
score = score.masked_fill(x_masks, -np.inf) |
|
|
|
log_p_attn = F.log_softmax(score, dim=-1) |
|
|
|
|
|
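        # Add a beta-binomial prior that softly biases the alignment towards the
        # diagonal of the (T_feats, T_text) lattice.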
bb_prior = self._generate_prior( |
|
text_lengths, |
|
feats_lengths, |
|
).to(dtype=log_p_attn.dtype, device=log_p_attn.device) |
|
log_p_attn = log_p_attn + bb_prior |
|
|
|
return log_p_attn |
|
|
|
def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor: |
|
"""Generate alignment prior formulated as beta-binomial distribution |
|
|
|
Args: |
|
text_lengths (Tensor): Batch of the lengths of each input (B,). |
|
feats_lengths (Tensor): Batch of the lengths of each target (B,). |
|
            w (float): Scaling factor of the distribution; lower values give a wider prior.
|
|
|
Returns: |
|
Tensor: Batched 2d static prior matrix (B, T_feats, T_text). |
|
|
|
""" |
|
B = len(text_lengths) |
|
T_text = text_lengths.max() |
|
T_feats = feats_lengths.max() |
|
|
|
bb_prior = torch.full((B, T_feats, T_text), fill_value=-np.inf) |
|
for bidx in range(B): |
|
T = feats_lengths[bidx].item() |
|
N = text_lengths[bidx].item() |
|
|
|
key = str(T) + "," + str(N) |
|
if self.cache_prior and key in self._cache: |
|
prob = self._cache[key] |
|
else: |
|
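                # Beta-binomial parameters: alpha increases and beta decreases
                # along the feature axis, concentrating probability mass near
                # the diagonal alignment path.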
alpha = w * np.arange(1, T + 1, dtype=float) |
|
beta = w * np.array([T - t + 1 for t in alpha]) |
|
k = np.arange(N) |
|
batched_k = k[..., None] |
|
prob = betabinom.logpmf(batched_k, N, alpha, beta) |
|
|
|
|
|
if self.cache_prior and key not in self._cache: |
|
self._cache[key] = prob |
|
|
|
prob = torch.from_numpy(prob).transpose(0, 1) |
|
bb_prior[bidx, :T, :N] = prob |
|
|
|
return bb_prior |
|
|
|
|
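# Rough usage sketch (illustrative only, with assumed placeholder names and dims):
# given text embeddings `text_emb` (B, T_text, adim), target features `feats`
# (B, T_feats, odim), a frame-level pitch sequence `pitch` (B, T_feats), and the
# corresponding length tensors, the utilities below are typically combined as:
#
#   alignment_module = AlignmentModule(adim=256, odim=80)  # placeholder dims
#   log_p_attn = alignment_module(text_emb, feats, text_lengths, feats_lengths)
#   durations, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
#   pitch_per_token = average_by_duration(durations, pitch, text_lengths, feats_lengths)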
|
@jit(nopython=True) |
|
def _monotonic_alignment_search(log_p_attn): |
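    """Monotonic alignment search.

    Find the most probable monotonic alignment path via dynamic programming.

    Args:
        log_p_attn (ndarray): Log probability of attention matrix (T_feats, T_text).

    Returns:
        ndarray: Aligned text token index for each feature frame (T_feats,).

    """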
|
|
|
T_mel = log_p_attn.shape[0] |
|
T_inp = log_p_attn.shape[1] |
|
Q = np.full((T_inp, T_mel), fill_value=-np.inf) |
|
|
|
log_prob = log_p_attn.transpose(1, 0) |
|
|
|
for j in range(T_mel): |
|
Q[0, j] = log_prob[0, : j + 1].sum() |
|
|
|
|
|
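    # DP recursion: token i at frame j is reached either by staying on token i
    # or by advancing from token i - 1 at frame j - 1.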
for j in range(1, T_mel): |
|
for i in range(1, min(j + 1, T_inp)): |
|
Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j] |
|
|
|
|
|
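    # Backtracking: start from the last text token and, for each earlier frame,
    # decide whether the alignment stayed on the same token or advanced.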
A = np.full((T_mel,), fill_value=T_inp - 1) |
|
for j in range(T_mel - 2, -1, -1): |
|
|
|
i_a = A[j + 1] - 1 |
|
i_b = A[j + 1] |
|
if i_b == 0: |
|
argmax_i = 0 |
|
elif Q[i_a, j] >= Q[i_b, j]: |
|
argmax_i = i_a |
|
else: |
|
argmax_i = i_b |
|
A[j] = argmax_i |
|
return A |
|
|
|
|
|
def viterbi_decode(log_p_attn, text_lengths, feats_lengths): |
|
"""Extract duration from an attention probability matrix |
|
|
|
Args: |
|
log_p_attn (Tensor): Batched log probability of attention |
|
matrix (B, T_feats, T_text). |
|
text_lengths (Tensor): Text length tensor (B,). |
|
        feats_lengths (Tensor): Feature length tensor (B,).
|
|
|
Returns: |
|
Tensor: Batched token duration extracted from `log_p_attn` (B, T_text). |
|
Tensor: Binarization loss tensor (). |
|
|
|
""" |
|
B = log_p_attn.size(0) |
|
T_text = log_p_attn.size(2) |
|
device = log_p_attn.device |
|
|
|
bin_loss = 0 |
|
ds = torch.zeros((B, T_text), device=device) |
|
for b in range(B): |
|
cur_log_p_attn = log_p_attn[b, : feats_lengths[b], : text_lengths[b]] |
|
viterbi = _monotonic_alignment_search(cur_log_p_attn.detach().cpu().numpy()) |
|
_ds = np.bincount(viterbi) |
|
ds[b, : len(_ds)] = torch.from_numpy(_ds).to(device) |
|
|
|
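        # Binarization loss: encourage the soft attention to agree with the
        # hard monotonic alignment.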
t_idx = torch.arange(feats_lengths[b]) |
|
bin_loss = bin_loss - cur_log_p_attn[t_idx, viterbi].mean() |
|
bin_loss = bin_loss / B |
|
return ds, bin_loss |
|
|
|
|
|
@jit(nopython=True) |
|
def _average_by_duration(ds, xs, text_lengths, feats_lengths): |
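    """Average frame-level features per token (numba helper of average_by_duration).

    Args:
        ds (ndarray): Batched token duration (B, T_text).
        xs (ndarray): Batched feature sequences to be averaged (B, T_feats).
        text_lengths (ndarray): Text length array (B,).
        feats_lengths (ndarray): Feature length array (B,).

    Returns:
        ndarray: Batched feature averaged according to the token duration (B, T_text).

    """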
|
B = ds.shape[0] |
|
xs_avg = np.zeros_like(ds) |
|
ds = ds.astype(np.int32) |
|
for b in range(B): |
|
t_text = text_lengths[b] |
|
t_feats = feats_lengths[b] |
|
d = ds[b, :t_text] |
|
d_cumsum = d.cumsum() |
|
d_cumsum = [0] + list(d_cumsum) |
|
x = xs[b, :t_feats] |
|
for n, (start, end) in enumerate(zip(d_cumsum[:-1], d_cumsum[1:])): |
|
if len(x[start:end]) != 0: |
|
xs_avg[b, n] = x[start:end].mean() |
|
else: |
|
xs_avg[b, n] = 0 |
|
return xs_avg |
|
|
|
|
|
def average_by_duration(ds, xs, text_lengths, feats_lengths): |
|
"""Average frame-level features into token-level according to durations |
|
|
|
Args: |
|
ds (Tensor): Batched token duration (B, T_text). |
|
xs (Tensor): Batched feature sequences to be averaged (B, T_feats). |
|
text_lengths (Tensor): Text length tensor (B,). |
|
feats_lengths (Tensor): Feature length tensor (B,). |
|
|
|
Returns: |
|
Tensor: Batched feature averaged according to the token duration (B, T_text). |
|
|
|
""" |
|
device = ds.device |
|
args = [ds, xs, text_lengths, feats_lengths] |
|
args = [arg.detach().cpu().numpy() for arg in args] |
|
xs_avg = _average_by_duration(*args) |
|
xs_avg = torch.from_numpy(xs_avg).to(device) |
|
return xs_avg |
|
|
|
|
|
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): |
|
"""Make mask tensor containing indices of padded part. |
|
|
|
Args: |
|
lengths (LongTensor or List): Batch of lengths (B,). |
|
xs (Tensor, optional): The reference tensor. |
|
If set, masks will be the same shape as this tensor. |
|
        length_dim (int, optional): Dimension indicator of the above tensor.

            See the example.

        maxlen (int, optional): Maximum length of the mask.

            If set, xs must be None and maxlen must be no smaller than max(lengths).
|
|
|
Returns: |
|
Tensor: Mask tensor containing indices of padded part. |
|
            dtype=torch.uint8 in PyTorch < 1.2

            dtype=torch.bool in PyTorch >= 1.2
|
|
|
Examples: |
|
With only lengths. |
|
|
|
>>> lengths = [5, 3, 2] |
|
>>> make_pad_mask(lengths) |
|
        masks = [[0, 0, 0, 0, 0],
|
[0, 0, 0, 1, 1], |
|
[0, 0, 1, 1, 1]] |
|
|
|
With the reference tensor. |
|
|
|
>>> xs = torch.zeros((3, 2, 4)) |
|
>>> make_pad_mask(lengths, xs) |
|
tensor([[[0, 0, 0, 0], |
|
[0, 0, 0, 0]], |
|
[[0, 0, 0, 1], |
|
[0, 0, 0, 1]], |
|
[[0, 0, 1, 1], |
|
[0, 0, 1, 1]]], dtype=torch.uint8) |
|
>>> xs = torch.zeros((3, 2, 6)) |
|
>>> make_pad_mask(lengths, xs) |
|
tensor([[[0, 0, 0, 0, 0, 1], |
|
[0, 0, 0, 0, 0, 1]], |
|
[[0, 0, 0, 1, 1, 1], |
|
[0, 0, 0, 1, 1, 1]], |
|
[[0, 0, 1, 1, 1, 1], |
|
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) |
|
|
|
With the reference tensor and dimension indicator. |
|
|
|
>>> xs = torch.zeros((3, 6, 6)) |
|
>>> make_pad_mask(lengths, xs, 1) |
|
tensor([[[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[1, 1, 1, 1, 1, 1]], |
|
[[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1]], |
|
[[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) |
|
>>> make_pad_mask(lengths, xs, 2) |
|
tensor([[[0, 0, 0, 0, 0, 1], |
|
[0, 0, 0, 0, 0, 1], |
|
[0, 0, 0, 0, 0, 1], |
|
[0, 0, 0, 0, 0, 1], |
|
[0, 0, 0, 0, 0, 1], |
|
[0, 0, 0, 0, 0, 1]], |
|
[[0, 0, 0, 1, 1, 1], |
|
[0, 0, 0, 1, 1, 1], |
|
[0, 0, 0, 1, 1, 1], |
|
[0, 0, 0, 1, 1, 1], |
|
[0, 0, 0, 1, 1, 1], |
|
[0, 0, 0, 1, 1, 1]], |
|
[[0, 0, 1, 1, 1, 1], |
|
[0, 0, 1, 1, 1, 1], |
|
[0, 0, 1, 1, 1, 1], |
|
[0, 0, 1, 1, 1, 1], |
|
[0, 0, 1, 1, 1, 1], |
|
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) |
|
|
|
""" |
|
if length_dim == 0: |
|
raise ValueError("length_dim cannot be 0: {}".format(length_dim)) |
|
|
|
if not isinstance(lengths, list): |
|
lengths = lengths.tolist() |
|
bs = int(len(lengths)) |
|
if maxlen is None: |
|
if xs is None: |
|
maxlen = int(max(lengths)) |
|
else: |
|
maxlen = xs.size(length_dim) |
|
else: |
|
assert xs is None |
|
assert maxlen >= int(max(lengths)) |
|
|
|
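    # Positions at or beyond each sequence length are padding: compare a
    # broadcast position grid against the per-sequence lengths.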
seq_range = torch.arange(0, maxlen, dtype=torch.int64) |
|
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) |
|
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) |
|
mask = seq_range_expand >= seq_length_expand |
|
|
|
if xs is not None: |
|
assert xs.size(0) == bs, (xs.size(0), bs) |
|
|
|
if length_dim < 0: |
|
length_dim = xs.dim() + length_dim |
|
|
|
ind = tuple( |
|
slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) |
|
) |
|
mask = mask[ind].expand_as(xs).to(xs.device) |
|
return mask |
|
|
|
|
|
def make_non_pad_mask(lengths, xs=None, length_dim=-1): |
|
"""Make mask tensor containing indices of non-padded part. |
|
|
|
Args: |
|
lengths (LongTensor or List): Batch of lengths (B,). |
|
xs (Tensor, optional): The reference tensor. |
|
If set, masks will be the same shape as this tensor. |
|
length_dim (int, optional): Dimension indicator of the above tensor. |
|
See the example. |
|
|
|
Returns: |
|
        ByteTensor: Mask tensor containing indices of non-padded part.
|
            dtype=torch.uint8 in PyTorch < 1.2

            dtype=torch.bool in PyTorch >= 1.2
|
|
|
Examples: |
|
With only lengths. |
|
|
|
>>> lengths = [5, 3, 2] |
|
>>> make_non_pad_mask(lengths) |
|
        masks = [[1, 1, 1, 1, 1],
|
[1, 1, 1, 0, 0], |
|
[1, 1, 0, 0, 0]] |
|
|
|
With the reference tensor. |
|
|
|
>>> xs = torch.zeros((3, 2, 4)) |
|
>>> make_non_pad_mask(lengths, xs) |
|
tensor([[[1, 1, 1, 1], |
|
[1, 1, 1, 1]], |
|
[[1, 1, 1, 0], |
|
[1, 1, 1, 0]], |
|
[[1, 1, 0, 0], |
|
[1, 1, 0, 0]]], dtype=torch.uint8) |
|
>>> xs = torch.zeros((3, 2, 6)) |
|
>>> make_non_pad_mask(lengths, xs) |
|
tensor([[[1, 1, 1, 1, 1, 0], |
|
[1, 1, 1, 1, 1, 0]], |
|
[[1, 1, 1, 0, 0, 0], |
|
[1, 1, 1, 0, 0, 0]], |
|
[[1, 1, 0, 0, 0, 0], |
|
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) |
|
|
|
With the reference tensor and dimension indicator. |
|
|
|
>>> xs = torch.zeros((3, 6, 6)) |
|
>>> make_non_pad_mask(lengths, xs, 1) |
|
tensor([[[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[0, 0, 0, 0, 0, 0]], |
|
[[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0]], |
|
[[1, 1, 1, 1, 1, 1], |
|
[1, 1, 1, 1, 1, 1], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0], |
|
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) |
|
>>> make_non_pad_mask(lengths, xs, 2) |
|
tensor([[[1, 1, 1, 1, 1, 0], |
|
[1, 1, 1, 1, 1, 0], |
|
[1, 1, 1, 1, 1, 0], |
|
[1, 1, 1, 1, 1, 0], |
|
[1, 1, 1, 1, 1, 0], |
|
[1, 1, 1, 1, 1, 0]], |
|
[[1, 1, 1, 0, 0, 0], |
|
[1, 1, 1, 0, 0, 0], |
|
[1, 1, 1, 0, 0, 0], |
|
[1, 1, 1, 0, 0, 0], |
|
[1, 1, 1, 0, 0, 0], |
|
[1, 1, 1, 0, 0, 0]], |
|
[[1, 1, 0, 0, 0, 0], |
|
[1, 1, 0, 0, 0, 0], |
|
[1, 1, 0, 0, 0, 0], |
|
[1, 1, 0, 0, 0, 0], |
|
[1, 1, 0, 0, 0, 0], |
|
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) |
|
|
|
""" |
|
return ~make_pad_mask(lengths, xs, length_dim) |
|
|
|
|
|
def get_random_segments( |
|
x: torch.Tensor, |
|
x_lengths: torch.Tensor, |
|
segment_size: int, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Get random segments. |
|
|
|
Args: |
|
x (Tensor): Input tensor (B, C, T). |
|
x_lengths (Tensor): Length tensor (B,). |
|
segment_size (int): Segment size. |
|
|
|
Returns: |
|
Tensor: Segmented tensor (B, C, segment_size). |
|
Tensor: Start index tensor (B,). |
|
|
|
""" |
|
b, c, t = x.size() |
|
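    # Draw a random start index per batch element within its valid range.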
max_start_idx = x_lengths - segment_size |
|
start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to( |
|
dtype=torch.long, |
|
) |
|
segments = get_segments(x, start_idxs, segment_size) |
|
return segments, start_idxs |
|
|
|
|
|
def get_segments( |
|
x: torch.Tensor, |
|
start_idxs: torch.Tensor, |
|
segment_size: int, |
|
) -> torch.Tensor: |
|
"""Get segments. |
|
|
|
Args: |
|
x (Tensor): Input tensor (B, C, T). |
|
start_idxs (Tensor): Start index tensor (B,). |
|
segment_size (int): Segment size. |
|
|
|
Returns: |
|
Tensor: Segmented tensor (B, C, segment_size). |
|
|
|
""" |
|
b, c, t = x.size() |
|
segments = x.new_zeros(b, c, segment_size) |
|
for i, start_idx in enumerate(start_idxs): |
|
segments[i] = x[i, :, start_idx : start_idx + segment_size] |
|
return segments |
|
|