Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)
"""Parallel WaveGAN Modules."""
import logging | |
import math | |
import torch | |
from torch import nn | |
from modules.parallel_wavegan.layers import Conv1d | |
from modules.parallel_wavegan.layers import Conv1d1x1 | |
from modules.parallel_wavegan.layers import ResidualBlock | |
from modules.parallel_wavegan.layers import upsample | |
from modules.parallel_wavegan import models | |
class ParallelWaveGANGenerator(torch.nn.Module):
    """Parallel WaveGAN Generator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 aux_context_window=2,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 upsample_conditional_features=True,
                 upsample_net="ConvInUpsampleNetwork",
                 upsample_params={"upsample_scales": [4, 4, 4, 4]},
                 use_pitch_embed=False,
                 ):
        """Initialize Parallel WaveGAN Generator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for auxiliary feature conv.
            aux_context_window (int): Context window size for auxiliary feature.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            upsample_conditional_features (bool): Whether to use upsampling network.
            upsample_net (str): Upsampling network architecture.
            upsample_params (dict): Upsampling network parameters.
                The dict is copied before extra keys are added, so neither the
                caller's dict nor the shared default is modified.
            use_pitch_embed (bool): Whether to create a pitch embedding and a
                linear projection that fuses the embedded pitch with the
                auxiliary features in ``forward``.

        """
        super(ParallelWaveGANGenerator, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)

        # define conv + upsampling network
        if upsample_conditional_features:
            # Work on a private copy: updating the argument in place would
            # leak keys (e.g. "aux_channels") into the shared default dict and
            # into every later instance, as well as into the caller's dict.
            upsample_params = dict(upsample_params)
            upsample_params.update({
                "use_causal_conv": use_causal_conv,
            })
            if upsample_net == "MelGANGenerator":
                assert aux_context_window == 0
                upsample_params.update({
                    "use_weight_norm": False,  # not to apply twice
                    "use_final_nonlinear_activation": False,
                })
                self.upsample_net = getattr(models, upsample_net)(**upsample_params)
            else:
                if upsample_net == "ConvInUpsampleNetwork":
                    upsample_params.update({
                        "aux_channels": aux_channels,
                        "aux_context_window": aux_context_window,
                    })
                self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
        else:
            self.upsample_net = None

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            # the dilation doubles within a stack and restarts at 1 for each stack
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        self.use_pitch_embed = use_pitch_embed
        if use_pitch_embed:
            # 300 embedding entries, index 0 is the padding index
            self.pitch_embed = nn.Embedding(300, aux_channels, 0)
            # projects the concatenated [aux features, pitch embedding] back
            # to aux_channels
            self.c_proj = nn.Linear(2 * aux_channels, aux_channels)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, c=None, pitch=None, **kwargs):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, C_in, T).
            c (Tensor): Local conditioning auxiliary features (B, C ,T').
            pitch (Tensor): Local conditioning pitch (B, T').
                Only read when the module was built with ``use_pitch_embed``.

        Returns:
            Tensor: Output tensor (B, C_out, T)

        """
        # perform upsampling
        if c is not None and self.upsample_net is not None:
            if self.use_pitch_embed:
                p = self.pitch_embed(pitch)
                # fuse along the feature axis, then restore the (B, C, T') layout
                c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
            c = self.upsample_net(c)
            assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c)
            skips += h
        # scale the summed skip outputs by 1 / sqrt(number of residual blocks)
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
            # log only after a successful removal; lazy formatting avoids
            # building the module repr when DEBUG logging is disabled
            logging.debug("Weight norm is removed from %s.", m)

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                logging.debug("Weight norm is applied to %s.", m)

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size,
                                  dilation=lambda x: 2 ** x):
        """Return the receptive field size for the given layer configuration.

        Args:
            layers (int): Number of residual block layers.
            stacks (int): Number of dilation cycles; must divide ``layers``.
            kernel_size (int): Kernel size of the dilated convolutions.
            dilation (callable): Maps the layer index within a stack to its
                dilation.

        Returns:
            int: ``(kernel_size - 1) * sum(dilations) + 1``.

        """
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    def receptive_field_size(self):
        """Return receptive field size."""
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
class ParallelWaveGANDiscriminator(torch.nn.Module):
    """Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=10,
                 conv_channels=64,
                 dilation_factor=1,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 bias=True,
                 use_weight_norm=True,
                 ):
        """Initialize Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the conv layers (must be odd).
            layers (int): Number of conv layers.
            conv_channels (int): Number of channels in the hidden conv layers.
            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
                the dilation will be 2, 4, 8, ..., and so on.
                With dilation_factor = 1 the dilation of layer ``i`` is ``i``
                (and 1 for the first layer).
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.
                The dict is only read, never modified.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super(ParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
        assert dilation_factor > 0, "Dilation factor must be > 0."
        self.conv_layers = torch.nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor ** i
            # padding grows with the dilation so the time length is preserved
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                Conv1d(conv_in_channels, conv_channels,
                       kernel_size=kernel_size, padding=padding,
                       dilation=dilation, bias=bias),
                getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
            ]
            self.conv_layers += conv_layer
            # Every layer after this one consumes conv_channels. Updating here
            # (not only when i > 0) keeps the final conv consistent with the
            # previous layer's output when layers == 2.
            conv_in_channels = conv_channels
        padding = (kernel_size - 1) // 2
        last_conv_layer = Conv1d(
            conv_in_channels, out_channels,
            kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T)

        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                logging.debug("Weight norm is applied to %s.", m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
            # log only after a successful removal; lazy formatting avoids
            # building the module repr when DEBUG logging is disabled
            logging.debug("Weight norm is removed from %s.", m)

        self.apply(_remove_weight_norm)
class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
    """Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 ):
        """Initialize Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            nonlinear_activation (str): Name of the ``torch.nn`` activation
                class used after the first conv and in the output layers.
            nonlinear_activation_params (dict): Nonlinear function parameters.
                The dict is only read, never modified.

        """
        super(ResidualParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = torch.nn.Sequential(
            Conv1d1x1(in_channels, residual_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
        )

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            # the dilation doubles within a stack and restarts at 1 for each stack
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                # forward() passes None as the conditioning input;
                # presumably -1 disables the auxiliary conv in ResidualBlock
                # (TODO confirm against its definition)
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T)

        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        # scale the summed skip outputs by 1 / sqrt(number of residual blocks)
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                logging.debug("Weight norm is applied to %s.", m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
            # log only after a successful removal; lazy formatting avoids
            # building the module repr when DEBUG logging is disabled
            logging.debug("Weight norm is removed from %s.", m)

        self.apply(_remove_weight_norm)