""" | |
StarGAN v2 | |
Copyright (c) 2020-present NAVER Corp. | |
This work is licensed under the Creative Commons Attribution-NonCommercial | |
4.0 International License. To view a copy of this license, visit | |
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | |
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | |
""" | |
import os
import os.path as osp
import copy
import math

from munch import Munch
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class DownSample(nn.Layer):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        elif self.layer_type == 'timepreserve':
            # Pool only the frequency axis; the time axis is preserved.
            return F.avg_pool2d(x, (2, 1))
        elif self.layer_type == 'half':
            return F.avg_pool2d(x, 2)
        else:
            raise RuntimeError(
                'Got unexpected downsample type %s, expected one of: none, timepreserve, half' % self.layer_type)
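
# A quick note on shapes (added for clarity, not in the original file): for a
# mel-spectrogram batch of shape [batch, channels, n_mels, frames],
# 'timepreserve' halves only the frequency axis ([.., n_mels // 2, frames]),
# while 'half' halves both spatial axes ([.., n_mels // 2, frames // 2]).
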
class UpSample(nn.Layer):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        elif self.layer_type == 'timepreserve':
            return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
        elif self.layer_type == 'half':
            return F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            raise RuntimeError(
                'Got unexpected upsample type %s, expected one of: none, timepreserve, half' % self.layer_type)

class ResBlk(nn.Layer):
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample='none'):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = DownSample(downsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2D(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2D(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm2D(dim_in)
            self.norm2 = nn.InstanceNorm2D(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2D(dim_in, dim_out, 1, 1, 0, bias_attr=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = self.downsample(x)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        x = self.downsample(x)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / math.sqrt(2)  # unit variance
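
# A minimal shape check, as a hedged sketch (the dims below are illustrative,
# not from the original file):
#
#   blk = ResBlk(48, 96, normalize=True, downsample='half')
#   y = blk(paddle.randn([4, 48, 80, 128]))   # -> [4, 96, 40, 64]
#
# The residual and shortcut paths are summed and divided by sqrt(2) so the
# output variance stays roughly equal to the input variance.
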
class AdaIN(nn.Layer):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2D(num_features, weight_attr=False, bias_attr=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        if len(s.shape) == 1:
            s = s[None]
        h = self.fc(s)
        h = h.reshape((h.shape[0], h.shape[1], 1, 1))
        gamma, beta = paddle.split(h, 2, axis=1)
        return (1 + gamma) * self.norm(x) + beta
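
# AdaIN predicts a per-channel scale and shift from the style code and applies
# them to the instance-normalized input: (1 + gamma) * norm(x) + beta. A small
# illustrative call (shapes are assumptions, not from the original file):
#
#   adain = AdaIN(style_dim=48, num_features=96)
#   y = adain(paddle.randn([4, 96, 40, 64]), paddle.randn([4, 48]))  # same shape as x
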
class AdainResBlk(nn.Layer):
    def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
                 actv=nn.LeakyReLU(0.2), upsample='none'):
        super().__init__()
        self.w_hpf = w_hpf
        self.actv = actv
        self.upsample = UpSample(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2D(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2D(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2D(dim_in, dim_out, 1, 1, 0, bias_attr=False)

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.upsample(x)
        x = self.conv1(x)
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        if self.w_hpf == 0:
            out = (out + self._shortcut(x)) / math.sqrt(2)
        return out
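
# Note (added for clarity, not in the original): when w_hpf != 0 the skip
# connection is intentionally omitted here; in that configuration the
# Generator re-injects encoder features through its HighPass path instead.
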
class HighPass(nn.Layer):
    def __init__(self, w_hpf):
        super(HighPass, self).__init__()
        self.filter = paddle.to_tensor([[-1, -1, -1],
                                        [-1, 8., -1],
                                        [-1, -1, -1]]) / w_hpf

    def forward(self, x):
        filter = self.filter.unsqueeze(0).unsqueeze(1).tile([x.shape[1], 1, 1, 1])
        return F.conv2d(x, filter, padding=1, groups=x.shape[1])
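
# The filter above is a standard 3x3 Laplacian-style high-pass kernel; tiling
# it to [channels, 1, 3, 3] and convolving with groups=channels applies it
# depthwise, i.e. independently to each channel. (Comment added for clarity.)
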
class Generator(nn.Layer):
    def __init__(self, dim_in=48, style_dim=48, max_conv_dim=48*8, w_hpf=1, F0_channel=0):
        super().__init__()
        self.stem = nn.Conv2D(1, dim_in, 3, 1, 1)
        self.encode = nn.LayerList()
        self.decode = nn.LayerList()
        self.to_out = nn.Sequential(
            nn.InstanceNorm2D(dim_in),
            nn.LeakyReLU(0.2),
            nn.Conv2D(dim_in, 1, 1, 1, 0))
        self.F0_channel = F0_channel
        # down/up-sampling blocks
        repeat_num = 4  # int(np.log2(img_size)) - 4
        if w_hpf > 0:
            repeat_num += 1
        for lid in range(repeat_num):
            if lid in [1, 3]:
                _downtype = 'timepreserve'
            else:
                _downtype = 'half'
            dim_out = min(dim_in*2, max_conv_dim)
            self.encode.append(
                ResBlk(dim_in, dim_out, normalize=True, downsample=_downtype))
            # Build the decoder in reverse so that it mirrors the encoder
            # (stack-like): the first block is appended, every later block is
            # prepended in front of it.
            block = AdainResBlk(dim_out, dim_in, style_dim,
                                w_hpf=w_hpf, upsample=_downtype)
            if lid == 0:
                self.decode.append(block)
            else:
                self.decode.insert(0, block)
            dim_in = dim_out
        # bottleneck blocks (encoder)
        for _ in range(2):
            self.encode.append(
                ResBlk(dim_out, dim_out, normalize=True))
        # F0 blocks
        if F0_channel != 0:
            self.decode.insert(
                0, AdainResBlk(dim_out + int(F0_channel / 2), dim_out, style_dim, w_hpf=w_hpf))
        # bottleneck blocks (decoder)
        for _ in range(2):
            self.decode.insert(
                0, AdainResBlk(dim_out + int(F0_channel / 2), dim_out + int(F0_channel / 2), style_dim, w_hpf=w_hpf))
        if F0_channel != 0:
            self.F0_conv = nn.Sequential(
                ResBlk(F0_channel, int(F0_channel / 2), normalize=True, downsample="half"),
            )
        if w_hpf > 0:
            self.hpf = HighPass(w_hpf)

    def forward(self, x, s, masks=None, F0=None):
        x = self.stem(x)
        cache = {}
        for block in self.encode:
            if (masks is not None) and (x.shape[2] in [32, 64, 128]):
                cache[x.shape[2]] = x
            x = block(x)
        if F0 is not None:
            F0 = self.F0_conv(F0)
            F0 = F.adaptive_avg_pool2d(F0, [x.shape[-2], x.shape[-1]])
            x = paddle.concat([x, F0], axis=1)
        for block in self.decode:
            x = block(x, s)
            if (masks is not None) and (x.shape[2] in [32, 64, 128]):
                mask = masks[0] if x.shape[2] in [32] else masks[1]
                # paddle's F.interpolate expects size as a list/tuple, so the
                # scalar target is duplicated for both spatial axes.
                mask = F.interpolate(mask, size=[x.shape[2], x.shape[2]], mode='bilinear')
                x = x + self.hpf(mask * cache[x.shape[2]])
        return self.to_out(x)
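
# A hedged end-to-end sketch of the generator (all shapes and values below are
# illustrative assumptions, not taken from the original file):
#
#   G = Generator(dim_in=48, style_dim=48, max_conv_dim=384, w_hpf=0, F0_channel=256)
#   x = paddle.randn([1, 1, 80, 192])      # [batch, 1, n_mels, frames]
#   s = paddle.randn([1, 48])              # style code
#   f0 = paddle.randn([1, 256, 80, 192])   # feature map from an external F0 network
#   y = G(x, s, F0=f0)                     # -> [1, 1, 80, 192]
#
# With F0_channel != 0 the decoder expects the F0 features concatenated on the
# channel axis, so F0 must be provided at call time.
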
class MappingNetwork(nn.Layer):
    def __init__(self, latent_dim=16, style_dim=48, num_domains=2, hidden_dim=384):
        super().__init__()
        layers = []
        layers += [nn.Linear(latent_dim, hidden_dim)]
        layers += [nn.ReLU()]
        for _ in range(3):
            layers += [nn.Linear(hidden_dim, hidden_dim)]
            layers += [nn.ReLU()]
        self.shared = nn.Sequential(*layers)
        self.unshared = nn.LayerList()
        for _ in range(num_domains):
            self.unshared.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                               nn.ReLU(),
                                               nn.Linear(hidden_dim, hidden_dim),
                                               nn.ReLU(),
                                               nn.Linear(hidden_dim, hidden_dim),
                                               nn.ReLU(),
                                               nn.Linear(hidden_dim, style_dim)))

    def forward(self, z, y):
        h = self.shared(z)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = paddle.stack(out, axis=1)  # (batch, num_domains, style_dim)
        idx = paddle.arange(y.shape[0])
        s = out[idx, y]  # (batch, style_dim)
        return s
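
# Illustrative call (shapes are assumptions, not from the original file): each
# domain has its own output head, and the row matching each sample's domain
# label y is selected from the stacked head outputs.
#
#   M = MappingNetwork(latent_dim=16, style_dim=48, num_domains=2)
#   z = paddle.randn([4, 16])
#   y = paddle.randint(0, 2, [4])
#   s = M(z, y)   # -> [4, 48]
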
class StyleEncoder(nn.Layer):
    def __init__(self, dim_in=48, style_dim=48, num_domains=2, max_conv_dim=384):
        super().__init__()
        blocks = []
        blocks += [nn.Conv2D(1, dim_in, 3, 1, 1)]
        repeat_num = 4
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample='half')]
            dim_in = dim_out
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2D(dim_out, dim_out, 5, 1, 0)]
        blocks += [nn.AdaptiveAvgPool2D(1)]
        blocks += [nn.LeakyReLU(0.2)]
        self.shared = nn.Sequential(*blocks)
        self.unshared = nn.LayerList()
        for _ in range(num_domains):
            self.unshared.append(nn.Linear(dim_out, style_dim))

    def forward(self, x, y):
        h = self.shared(x)
        h = h.reshape((h.shape[0], -1))
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = paddle.stack(out, axis=1)  # (batch, num_domains, style_dim)
        idx = paddle.arange(y.shape[0])
        s = out[idx, y]  # (batch, style_dim)
        return s
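
# Note (added, not in the original): the 5x5 valid convolution after four
# 'half' downsamples assumes the encoded feature map is at least 5x5, i.e.
# inputs of roughly [batch, 1, 80, >=80] for 80-bin mel-spectrograms.
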
class Discriminator(nn.Layer):
    def __init__(self, dim_in=48, num_domains=2, max_conv_dim=384, repeat_num=4):
        super().__init__()
        # real/fake discriminator
        self.dis = Discriminator2D(dim_in=dim_in, num_domains=num_domains,
                                   max_conv_dim=max_conv_dim, repeat_num=repeat_num)
        # adversarial classifier
        self.cls = Discriminator2D(dim_in=dim_in, num_domains=num_domains,
                                   max_conv_dim=max_conv_dim, repeat_num=repeat_num)
        self.num_domains = num_domains

    def forward(self, x, y):
        return self.dis(x, y)

    def classifier(self, x):
        return self.cls.get_feature(x)
class LinearNorm(paddle.nn.Layer):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = paddle.nn.Linear(in_dim, out_dim, bias_attr=bias)
        # calculate_gain is only available in paddle >= 2.3, hence the guard.
        if float('.'.join(paddle.__version__.split('.')[:2])) >= 2.3:
            gain = paddle.nn.initializer.calculate_gain(w_init_gain)
            paddle.nn.initializer.XavierUniform()(self.linear_layer.weight)
            self.linear_layer.weight.set_value(gain * self.linear_layer.weight)

    def forward(self, x):
        return self.linear_layer(x)
class Discriminator2D(nn.Layer):
    def __init__(self, dim_in=48, num_domains=2, max_conv_dim=384, repeat_num=4):
        super().__init__()
        blocks = []
        blocks += [nn.Conv2D(1, dim_in, 3, 1, 1)]
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample='half')]
            dim_in = dim_out
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2D(dim_out, dim_out, 5, 1, 0)]
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.AdaptiveAvgPool2D(1)]
        blocks += [nn.Conv2D(dim_out, num_domains, 1, 1, 0)]
        self.main = nn.Sequential(*blocks)

    def get_feature(self, x):
        out = self.main(x)
        out = out.reshape((out.shape[0], -1))  # (batch, num_domains)
        return out

    def forward(self, x, y):
        out = self.get_feature(x)
        idx = paddle.arange(y.shape[0])
        out = out[idx, y]  # (batch)
        return out
def build_model(args, F0_model, ASR_model):
    generator = Generator(args.dim_in, args.style_dim, args.max_conv_dim,
                          w_hpf=args.w_hpf, F0_channel=args.F0_channel)
    mapping_network = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains,
                                     hidden_dim=args.max_conv_dim)
    style_encoder = StyleEncoder(args.dim_in, args.style_dim, args.num_domains, args.max_conv_dim)
    discriminator = Discriminator(args.dim_in, args.num_domains, args.max_conv_dim, args.n_repeat)
    generator_ema = copy.deepcopy(generator)
    mapping_network_ema = copy.deepcopy(mapping_network)
    style_encoder_ema = copy.deepcopy(style_encoder)
    nets = Munch(generator=generator,
                 mapping_network=mapping_network,
                 style_encoder=style_encoder,
                 discriminator=discriminator,
                 f0_model=F0_model,
                 asr_model=ASR_model)
    nets_ema = Munch(generator=generator_ema,
                     mapping_network=mapping_network_ema,
                     style_encoder=style_encoder_ema)
    return nets, nets_ema
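
# Hedged wiring sketch (the argument values and the F0/ASR networks below are
# placeholders, not taken from the original file; the real values come from
# the training config):
#
#   args = Munch(dim_in=48, style_dim=48, max_conv_dim=384, w_hpf=0,
#                F0_channel=256, latent_dim=16, num_domains=2, n_repeat=4)
#   nets, nets_ema = build_model(args, F0_model=my_f0_net, ASR_model=my_asr_net)
#
# nets_ema holds deep copies of the generator, mapping network, and style
# encoder; the '_ema' suffix suggests they track exponential moving averages
# of the trained weights during training.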