File size: 4,425 Bytes
4c9fe71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# -*- coding: utf-8 -*-
"""Residual block module in WaveNet.
This code is modified from https://github.com/r9y9/wavenet_vocoder.
"""
import math
import torch
import torch.nn.functional as F
class Conv1d(torch.nn.Conv1d):
    """1D convolution layer that overrides the default parameter initialization."""

    def __init__(self, *args, **kwargs):
        """Build the layer, forwarding every argument to ``torch.nn.Conv1d``."""
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Re-initialize the weight (Kaiming normal, ReLU gain) and zero the bias."""
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
        bias = self.bias
        if bias is None:
            # Layer was built with bias=False; nothing more to initialize.
            return
        torch.nn.init.zeros_(bias)
class Conv1d1x1(Conv1d):
    """Pointwise (kernel size 1) variant of ``Conv1d`` with the same custom initialization."""

    def __init__(self, in_channels, out_channels, bias):
        """Build a 1x1 convolution layer.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bias (bool): Whether to add a bias parameter.

        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding=0,
            dilation=1,
            bias=bias,
        )
class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet.

    The block applies dropout and a dilated convolution, optionally adds a
    projected local-conditioning signal, passes the result through a gated
    activation (``tanh(a) * sigmoid(b)``), and finally produces a residual
    output and a skip output through two separate 1x1 convolutions.
    """

    def __init__(self,
                 kernel_size=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 dropout=0.0,
                 dilation=1,
                 bias=True,
                 use_causal_conv=False
                 ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of output channels of the dilation
                convolution. It is split into two equal halves for the gated
                activation, so it must be even.
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Local conditioning channels i.e. auxiliary input dimension.
                No conditioning layer is created when this is not positive.
            dropout (float): Dropout probability.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            use_causal_conv (bool): Whether to use causal or non-causal convolution.

        Raises:
            ValueError: If gate_channels is odd.
            AssertionError: If kernel_size is even and use_causal_conv is False.

        """
        super(ResidualBlock, self).__init__()
        # The conv output is split into two halves in forward(); an odd channel
        # count cannot be split evenly, so reject it here with a clear message.
        if gate_channels % 2 != 0:
            raise ValueError(
                "gate_channels must be even, got {}.".format(gate_channels)
            )
        self.dropout = dropout
        # no future time stamps available
        if use_causal_conv:
            # Pad by the full receptive field; the surplus steps at the end
            # are trimmed in forward().
            padding = (kernel_size - 1) * dilation
        else:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
            padding = (kernel_size - 1) // 2 * dilation
        self.use_causal_conv = use_causal_conv

        # dilation conv
        self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
                           padding=padding, dilation=dilation, bias=bias)

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
        else:
            self.conv1x1_aux = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)

    def forward(self, x, c=None):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            c (Tensor, optional): Local conditioning auxiliary tensor (B, aux_channels, T).
                Conditioning is skipped when None.

        Returns:
            Tensor: Output tensor for residual connection (B, residual_channels, T).
            Tensor: Output tensor for skip connection (B, skip_channels, T).

        Raises:
            AssertionError: If c is given but the block has no conditioning layer.

        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv(x)

        # remove future time steps if causal conv is used
        if self.use_causal_conv:
            x = x[:, :, :residual.size(-1)]

        # split into two parts for gated activation
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # for skip connection
        s = self.conv1x1_skip(x)

        # for residual connection; the sum of the two branches is scaled by sqrt(0.5)
        x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)

        return x, s
|