# diffsingerkr/Modules/Modules.py
from argparse import Namespace
import torch
import math
from typing import Union
from .Layer import Conv1d, LayerNorm, LinearAttention
from .Diffusion import Diffusion
class DiffSinger(torch.nn.Module):
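    '''
    Top-level DiffSinger module: the Encoder produces token/note encodings
    plus a coarse linear feature prediction, and the Diffusion module refines
    that prediction (full reverse process in training, DDIM at inference).
    '''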
def __init__(self, hyper_parameters: Namespace):
super().__init__()
self.hp = hyper_parameters
self.encoder = Encoder(self.hp)
self.diffusion = Diffusion(self.hp)
def forward(
self,
tokens: torch.LongTensor,
notes: torch.LongTensor,
durations: torch.LongTensor,
lengths: torch.LongTensor,
genres: torch.LongTensor,
singers: torch.LongTensor,
features: Union[torch.FloatTensor, None]= None,
ddim_steps: Union[int, None]= None
):
encodings, linear_predictions = self.encoder(
tokens= tokens,
notes= notes,
durations= durations,
lengths= lengths,
genres= genres,
singers= singers
) # [Batch, Enc_d, Feature_t]
encodings = torch.cat([encodings, linear_predictions], dim= 1) # [Batch, Enc_d + Feature_d, Feature_t]
        if features is not None or ddim_steps is None or ddim_steps == self.hp.Diffusion.Max_Step:
diffusion_predictions, noises, epsilons = self.diffusion(
encodings= encodings,
features= features,
)
else:
noises, epsilons = None, None
diffusion_predictions = self.diffusion.DDIM(
encodings= encodings,
ddim_steps= ddim_steps
)
return linear_predictions, diffusion_predictions, noises, epsilons
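# Usage sketch (hypothetical tensors; shapes follow the comments above):
#     model = DiffSinger(hp)
#     linear, diffusion, noises, epsilons = model(
#         tokens= tokens, notes= notes, durations= durations,
#         lengths= lengths, genres= genres, singers= singers,
#         features= features    # training: ground-truth features drive the full diffusion pass
#         )
#     # Inference: pass features= None with a reduced ddim_steps to sample via DDIM,
#     # in which case noises and epsilons are returned as None.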
class Encoder(torch.nn.Module):
def __init__(
self,
hyper_parameters: Namespace
):
super().__init__()
self.hp = hyper_parameters
        if self.hp.Feature_Type == 'Mel':
            self.feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1
        else:
            raise ValueError(f'Unsupported feature type: {self.hp.Feature_Type}')
self.token_embedding = torch.nn.Embedding(
num_embeddings= self.hp.Tokens,
embedding_dim= self.hp.Encoder.Size
)
self.note_embedding = torch.nn.Embedding(
num_embeddings= self.hp.Notes,
embedding_dim= self.hp.Encoder.Size
)
self.duration_embedding = Duration_Positional_Encoding(
num_embeddings= self.hp.Durations,
embedding_dim= self.hp.Encoder.Size
)
self.genre_embedding = torch.nn.Embedding(
num_embeddings= self.hp.Genres,
embedding_dim= self.hp.Encoder.Size,
)
self.singer_embedding = torch.nn.Embedding(
num_embeddings= self.hp.Singers,
embedding_dim= self.hp.Encoder.Size,
)
torch.nn.init.xavier_uniform_(self.token_embedding.weight)
torch.nn.init.xavier_uniform_(self.note_embedding.weight)
torch.nn.init.xavier_uniform_(self.genre_embedding.weight)
torch.nn.init.xavier_uniform_(self.singer_embedding.weight)
self.fft_blocks = torch.nn.ModuleList([
FFT_Block(
channels= self.hp.Encoder.Size,
num_head= self.hp.Encoder.ConvFFT.Head,
ffn_kernel_size= self.hp.Encoder.ConvFFT.FFN.Kernel_Size,
dropout_rate= self.hp.Encoder.ConvFFT.Dropout_Rate
)
for _ in range(self.hp.Encoder.ConvFFT.Stack)
])
self.linear_projection = Conv1d(
in_channels= self.hp.Encoder.Size,
out_channels= self.feature_size,
kernel_size= 1,
bias= True,
w_init_gain= 'linear'
)
def forward(
self,
tokens: torch.Tensor,
notes: torch.Tensor,
durations: torch.Tensor,
lengths: torch.Tensor,
genres: torch.Tensor,
singers: torch.Tensor
):
x = \
self.token_embedding(tokens) + \
self.note_embedding(notes) + \
self.duration_embedding(durations) + \
self.genre_embedding(genres).unsqueeze(1) + \
self.singer_embedding(singers).unsqueeze(1)
x = x.permute(0, 2, 1) # [Batch, Enc_d, Enc_t]
for block in self.fft_blocks:
x = block(x, lengths) # [Batch, Enc_d, Enc_t]
linear_predictions = self.linear_projection(x) # [Batch, Feature_d, Enc_t]
return x, linear_predictions
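# Shape sketch (illustrative sizes): with Batch= 2, Enc_t= 100, Encoder.Size= 256
# and Feature_Type == 'Mel' with Mel_Dim= 80, the encoder returns
# encodings [2, 256, 100] and linear_predictions [2, 80, 100]; DiffSinger
# concatenates them to [2, 336, 100] before conditioning the diffusion model.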
class FFT_Block(torch.nn.Module):
def __init__(
self,
channels: int,
num_head: int,
ffn_kernel_size: int,
dropout_rate: float= 0.1,
) -> None:
super().__init__()
self.attention = LinearAttention(
channels= channels,
calc_channels= channels,
num_heads= num_head,
dropout_rate= dropout_rate
)
self.ffn = FFN(
channels= channels,
kernel_size= ffn_kernel_size,
dropout_rate= dropout_rate
)
def forward(
self,
x: torch.Tensor,
lengths: torch.Tensor
) -> torch.Tensor:
'''
x: [Batch, Dim, Time]
'''
        masks = (~Mask_Generate(lengths= lengths, max_length= x.size(2))).unsqueeze(1).float()    # [Batch, 1, Time], 1.0 at valid frames
# Attention + Dropout + LayerNorm
x = self.attention(x)
# FFN + Dropout + LayerNorm
x = self.ffn(x, masks)
return x * masks
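# Note: LinearAttention is applied here without an explicit padding mask; the
# length mask is applied inside the FFN and again to the block output.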
class FFN(torch.nn.Module):
def __init__(
self,
channels: int,
kernel_size: int,
dropout_rate: float= 0.1,
) -> None:
super().__init__()
self.conv_0 = Conv1d(
in_channels= channels,
out_channels= channels,
kernel_size= kernel_size,
padding= (kernel_size - 1) // 2,
w_init_gain= 'relu'
)
self.relu = torch.nn.ReLU()
self.dropout = torch.nn.Dropout(p= dropout_rate)
self.conv_1 = Conv1d(
in_channels= channels,
out_channels= channels,
kernel_size= kernel_size,
padding= (kernel_size - 1) // 2,
w_init_gain= 'linear'
)
self.norm = LayerNorm(
num_features= channels,
)
def forward(
self,
x: torch.Tensor,
masks: torch.Tensor
) -> torch.Tensor:
'''
x: [Batch, Dim, Time]
'''
residuals = x
x = self.conv_0(x * masks)
x = self.relu(x)
x = self.dropout(x)
x = self.conv_1(x * masks)
x = self.dropout(x)
x = self.norm(x + residuals)
return x * masks
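# Design note: FFN masks its input before each convolution so padded frames
# cannot leak into valid positions through the kernel's receptive field, and
# applies the residual connection followed by LayerNorm (post-norm convention).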
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
class Duration_Positional_Encoding(torch.nn.Embedding):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
):
positional_embedding = torch.zeros(num_embeddings, embedding_dim)
position = torch.arange(0, num_embeddings, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
positional_embedding[:, 0::2] = torch.sin(position * div_term)
positional_embedding[:, 1::2] = torch.cos(position * div_term)
super().__init__(
num_embeddings= num_embeddings,
embedding_dim= embedding_dim,
_weight= positional_embedding
)
self.weight.requires_grad = False
self.alpha = torch.nn.Parameter(
data= torch.ones(1) * 0.01,
requires_grad= True
)
def forward(self, durations):
'''
durations: [Batch, Length]
'''
        return self.alpha * super().forward(durations) # [Batch, Length, Dim]
@torch.jit.script
def get_pe(x: torch.Tensor, pe: torch.Tensor):
pe = pe.repeat(1, 1, math.ceil(x.size(2) / pe.size(2)))
return pe[:, :, :x.size(2)]
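# Note: the frozen embedding above stores the standard transformer positional
# encoding, PE(pos, 2i) = sin(pos / 10000^(2i / d)) and
# PE(pos, 2i + 1) = cos(pos / 10000^(2i / d)), looked up by duration value and
# scaled by the learnable alpha. get_pe tiles a precomputed table [1, Dim, PE_t]
# along time to cover an input x longer than the table, then trims to x.size(2).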
def Mask_Generate(lengths: torch.Tensor, max_length: Union[torch.Tensor, int, None]= None):
'''
lengths: [Batch]
    max_length: an int value. If None, max_length == max(lengths)
'''
max_length = max_length or torch.max(lengths)
sequence = torch.arange(max_length)[None, :].to(lengths.device)
return sequence >= lengths[:, None] # [Batch, Time]
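# Example: Mask_Generate(torch.tensor([3, 5]), max_length= 5) returns
#     [[False, False, False,  True,  True],
#      [False, False, False, False, False]]
# i.e. True marks padded positions at or beyond each sequence's length.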