Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# Copyright 2020 Tomoki Hayashi | |
# MIT License (https://opensource.org/licenses/MIT) | |
"""MelGAN Modules.""" | |
import logging | |
import numpy as np | |
import torch | |
from modules.parallel_wavegan.layers import CausalConv1d | |
from modules.parallel_wavegan.layers import CausalConvTranspose1d | |
from modules.parallel_wavegan.layers import ResidualStack | |
class MelGANGenerator(torch.nn.Module): | |
"""MelGAN generator module.""" | |
def __init__(self, | |
in_channels=80, | |
out_channels=1, | |
kernel_size=7, | |
channels=512, | |
bias=True, | |
upsample_scales=[8, 8, 2, 2], | |
stack_kernel_size=3, | |
stacks=3, | |
nonlinear_activation="LeakyReLU", | |
nonlinear_activation_params={"negative_slope": 0.2}, | |
pad="ReflectionPad1d", | |
pad_params={}, | |
use_final_nonlinear_activation=True, | |
use_weight_norm=True, | |
use_causal_conv=False, | |
): | |
"""Initialize MelGANGenerator module. | |
Args: | |
in_channels (int): Number of input channels. | |
out_channels (int): Number of output channels. | |
kernel_size (int): Kernel size of initial and final conv layer. | |
channels (int): Initial number of channels for conv layer. | |
bias (bool): Whether to add bias parameter in convolution layers. | |
upsample_scales (list): List of upsampling scales. | |
stack_kernel_size (int): Kernel size of dilated conv layers in residual stack. | |
stacks (int): Number of stacks in a single residual stack. | |
nonlinear_activation (str): Activation function module name. | |
nonlinear_activation_params (dict): Hyperparameters for activation function. | |
pad (str): Padding function module name before dilated convolution layer. | |
pad_params (dict): Hyperparameters for padding function. | |
use_final_nonlinear_activation (torch.nn.Module): Activation function for the final layer. | |
use_weight_norm (bool): Whether to use weight norm. | |
If set to true, it will be applied to all of the conv layers. | |
use_causal_conv (bool): Whether to use causal convolution. | |
""" | |
super(MelGANGenerator, self).__init__() | |
# check hyper parameters is valid | |
assert channels >= np.prod(upsample_scales) | |
assert channels % (2 ** len(upsample_scales)) == 0 | |
if not use_causal_conv: | |
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." | |
# add initial layer | |
layers = [] | |
if not use_causal_conv: | |
layers += [ | |
getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), | |
torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias), | |
] | |
else: | |
layers += [ | |
CausalConv1d(in_channels, channels, kernel_size, | |
bias=bias, pad=pad, pad_params=pad_params), | |
] | |
for i, upsample_scale in enumerate(upsample_scales): | |
# add upsampling layer | |
layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)] | |
if not use_causal_conv: | |
layers += [ | |
torch.nn.ConvTranspose1d( | |
channels // (2 ** i), | |
channels // (2 ** (i + 1)), | |
upsample_scale * 2, | |
stride=upsample_scale, | |
padding=upsample_scale // 2 + upsample_scale % 2, | |
output_padding=upsample_scale % 2, | |
bias=bias, | |
) | |
] | |
else: | |
layers += [ | |
CausalConvTranspose1d( | |
channels // (2 ** i), | |
channels // (2 ** (i + 1)), | |
upsample_scale * 2, | |
stride=upsample_scale, | |
bias=bias, | |
) | |
] | |
# add residual stack | |
for j in range(stacks): | |
layers += [ | |
ResidualStack( | |
kernel_size=stack_kernel_size, | |
channels=channels // (2 ** (i + 1)), | |
dilation=stack_kernel_size ** j, | |
bias=bias, | |
nonlinear_activation=nonlinear_activation, | |
nonlinear_activation_params=nonlinear_activation_params, | |
pad=pad, | |
pad_params=pad_params, | |
use_causal_conv=use_causal_conv, | |
) | |
] | |
# add final layer | |
layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)] | |
if not use_causal_conv: | |
layers += [ | |
getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), | |
torch.nn.Conv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias), | |
] | |
else: | |
layers += [ | |
CausalConv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, | |
bias=bias, pad=pad, pad_params=pad_params), | |
] | |
if use_final_nonlinear_activation: | |
layers += [torch.nn.Tanh()] | |
# define the model as a single function | |
self.melgan = torch.nn.Sequential(*layers) | |
# apply weight norm | |
if use_weight_norm: | |
self.apply_weight_norm() | |
# reset parameters | |
self.reset_parameters() | |
def forward(self, c): | |
"""Calculate forward propagation. | |
Args: | |
c (Tensor): Input tensor (B, channels, T). | |
Returns: | |
Tensor: Output tensor (B, 1, T ** prod(upsample_scales)). | |
""" | |
return self.melgan(c) | |
def remove_weight_norm(self): | |
"""Remove weight normalization module from all of the layers.""" | |
def _remove_weight_norm(m): | |
try: | |
logging.debug(f"Weight norm is removed from {m}.") | |
torch.nn.utils.remove_weight_norm(m) | |
except ValueError: # this module didn't have weight norm | |
return | |
self.apply(_remove_weight_norm) | |
def apply_weight_norm(self): | |
"""Apply weight normalization module from all of the layers.""" | |
def _apply_weight_norm(m): | |
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): | |
torch.nn.utils.weight_norm(m) | |
logging.debug(f"Weight norm is applied to {m}.") | |
self.apply(_apply_weight_norm) | |
def reset_parameters(self): | |
"""Reset parameters. | |
This initialization follows official implementation manner. | |
https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py | |
""" | |
def _reset_parameters(m): | |
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): | |
m.weight.data.normal_(0.0, 0.02) | |
logging.debug(f"Reset parameters in {m}.") | |
self.apply(_reset_parameters) | |
class MelGANDiscriminator(torch.nn.Module): | |
"""MelGAN discriminator module.""" | |
def __init__(self, | |
in_channels=1, | |
out_channels=1, | |
kernel_sizes=[5, 3], | |
channels=16, | |
max_downsample_channels=1024, | |
bias=True, | |
downsample_scales=[4, 4, 4, 4], | |
nonlinear_activation="LeakyReLU", | |
nonlinear_activation_params={"negative_slope": 0.2}, | |
pad="ReflectionPad1d", | |
pad_params={}, | |
): | |
"""Initilize MelGAN discriminator module. | |
Args: | |
in_channels (int): Number of input channels. | |
out_channels (int): Number of output channels. | |
kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer, | |
and the first and the second kernel sizes will be used for the last two layers. | |
For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15, | |
the last two layers' kernel size will be 5 and 3, respectively. | |
channels (int): Initial number of channels for conv layer. | |
max_downsample_channels (int): Maximum number of channels for downsampling layers. | |
bias (bool): Whether to add bias parameter in convolution layers. | |
downsample_scales (list): List of downsampling scales. | |
nonlinear_activation (str): Activation function module name. | |
nonlinear_activation_params (dict): Hyperparameters for activation function. | |
pad (str): Padding function module name before dilated convolution layer. | |
pad_params (dict): Hyperparameters for padding function. | |
""" | |
super(MelGANDiscriminator, self).__init__() | |
self.layers = torch.nn.ModuleList() | |
# check kernel size is valid | |
assert len(kernel_sizes) == 2 | |
assert kernel_sizes[0] % 2 == 1 | |
assert kernel_sizes[1] % 2 == 1 | |
# add first layer | |
self.layers += [ | |
torch.nn.Sequential( | |
getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params), | |
torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias), | |
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), | |
) | |
] | |
# add downsample layers | |
in_chs = channels | |
for downsample_scale in downsample_scales: | |
out_chs = min(in_chs * downsample_scale, max_downsample_channels) | |
self.layers += [ | |
torch.nn.Sequential( | |
torch.nn.Conv1d( | |
in_chs, out_chs, | |
kernel_size=downsample_scale * 10 + 1, | |
stride=downsample_scale, | |
padding=downsample_scale * 5, | |
groups=in_chs // 4, | |
bias=bias, | |
), | |
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), | |
) | |
] | |
in_chs = out_chs | |
# add final layers | |
out_chs = min(in_chs * 2, max_downsample_channels) | |
self.layers += [ | |
torch.nn.Sequential( | |
torch.nn.Conv1d( | |
in_chs, out_chs, kernel_sizes[0], | |
padding=(kernel_sizes[0] - 1) // 2, | |
bias=bias, | |
), | |
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), | |
) | |
] | |
self.layers += [ | |
torch.nn.Conv1d( | |
out_chs, out_channels, kernel_sizes[1], | |
padding=(kernel_sizes[1] - 1) // 2, | |
bias=bias, | |
), | |
] | |
def forward(self, x): | |
"""Calculate forward propagation. | |
Args: | |
x (Tensor): Input noise signal (B, 1, T). | |
Returns: | |
List: List of output tensors of each layer. | |
""" | |
outs = [] | |
for f in self.layers: | |
x = f(x) | |
outs += [x] | |
return outs | |
class MelGANMultiScaleDiscriminator(torch.nn.Module): | |
"""MelGAN multi-scale discriminator module.""" | |
def __init__(self, | |
in_channels=1, | |
out_channels=1, | |
scales=3, | |
downsample_pooling="AvgPool1d", | |
# follow the official implementation setting | |
downsample_pooling_params={ | |
"kernel_size": 4, | |
"stride": 2, | |
"padding": 1, | |
"count_include_pad": False, | |
}, | |
kernel_sizes=[5, 3], | |
channels=16, | |
max_downsample_channels=1024, | |
bias=True, | |
downsample_scales=[4, 4, 4, 4], | |
nonlinear_activation="LeakyReLU", | |
nonlinear_activation_params={"negative_slope": 0.2}, | |
pad="ReflectionPad1d", | |
pad_params={}, | |
use_weight_norm=True, | |
): | |
"""Initilize MelGAN multi-scale discriminator module. | |
Args: | |
in_channels (int): Number of input channels. | |
out_channels (int): Number of output channels. | |
downsample_pooling (str): Pooling module name for downsampling of the inputs. | |
downsample_pooling_params (dict): Parameters for the above pooling module. | |
kernel_sizes (list): List of two kernel sizes. The sum will be used for the first conv layer, | |
and the first and the second kernel sizes will be used for the last two layers. | |
channels (int): Initial number of channels for conv layer. | |
max_downsample_channels (int): Maximum number of channels for downsampling layers. | |
bias (bool): Whether to add bias parameter in convolution layers. | |
downsample_scales (list): List of downsampling scales. | |
nonlinear_activation (str): Activation function module name. | |
nonlinear_activation_params (dict): Hyperparameters for activation function. | |
pad (str): Padding function module name before dilated convolution layer. | |
pad_params (dict): Hyperparameters for padding function. | |
use_causal_conv (bool): Whether to use causal convolution. | |
""" | |
super(MelGANMultiScaleDiscriminator, self).__init__() | |
self.discriminators = torch.nn.ModuleList() | |
# add discriminators | |
for _ in range(scales): | |
self.discriminators += [ | |
MelGANDiscriminator( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_sizes=kernel_sizes, | |
channels=channels, | |
max_downsample_channels=max_downsample_channels, | |
bias=bias, | |
downsample_scales=downsample_scales, | |
nonlinear_activation=nonlinear_activation, | |
nonlinear_activation_params=nonlinear_activation_params, | |
pad=pad, | |
pad_params=pad_params, | |
) | |
] | |
self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params) | |
# apply weight norm | |
if use_weight_norm: | |
self.apply_weight_norm() | |
# reset parameters | |
self.reset_parameters() | |
def forward(self, x): | |
"""Calculate forward propagation. | |
Args: | |
x (Tensor): Input noise signal (B, 1, T). | |
Returns: | |
List: List of list of each discriminator outputs, which consists of each layer output tensors. | |
""" | |
outs = [] | |
for f in self.discriminators: | |
outs += [f(x)] | |
x = self.pooling(x) | |
return outs | |
def remove_weight_norm(self): | |
"""Remove weight normalization module from all of the layers.""" | |
def _remove_weight_norm(m): | |
try: | |
logging.debug(f"Weight norm is removed from {m}.") | |
torch.nn.utils.remove_weight_norm(m) | |
except ValueError: # this module didn't have weight norm | |
return | |
self.apply(_remove_weight_norm) | |
def apply_weight_norm(self): | |
"""Apply weight normalization module from all of the layers.""" | |
def _apply_weight_norm(m): | |
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): | |
torch.nn.utils.weight_norm(m) | |
logging.debug(f"Weight norm is applied to {m}.") | |
self.apply(_apply_weight_norm) | |
def reset_parameters(self): | |
"""Reset parameters. | |
This initialization follows official implementation manner. | |
https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py | |
""" | |
def _reset_parameters(m): | |
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): | |
m.weight.data.normal_(0.0, 0.02) | |
logging.debug(f"Reset parameters in {m}.") | |
self.apply(_reset_parameters) | |