import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils import spectral_norm


class SPADE(nn.Module):
    """Spatially-Adaptive Normalization (SPADE), from Park et al.,
    "Semantic Image Synthesis with Spatially-Adaptive Normalization" (CVPR 2019).

    Applies a parameter-free normalization to the input, then modulates the
    result with a per-pixel scale (gamma) and bias (beta) predicted from the
    segmentation map.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
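        # e.g. config_text 'spadeinstance3x3' parses to param_free_norm_type='instance'
        # and ks=3; 'spadebatch3x3' would select batch norm instead.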

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'syncbatch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')

        # the dimension of the intermediate embedding space, capped at 128
        nhidden = 128 if norm_nc > 128 else norm_nc

        pw = ks // 2
        self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)

    def forward(self, x, segmap):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce per-pixel scale and bias conditioned on the semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply the predicted scale and bias
        out = normalized * gamma + beta

        return out
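

# A minimal usage sketch for SPADE (shapes and channel counts are illustrative
# assumptions, not values mandated by this module):
#   spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=3)
#   x = torch.randn(1, 64, 32, 32)       # feature map to modulate
#   seg = torch.randn(1, 3, 256, 256)    # guidance map; resized to x's size internally
#   out = spade(x, seg)                  # same shape as x: (1, 64, 32, 32)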


class SPADEResnetBlock(nn.Module):
    """
    ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
    it takes in the segmentation map as input, learns the skip connection if necessary,
    and applies normalization first and then convolution.
    This architecture seemed like a standard architecture for unconditional or
    class-conditional GAN architectures using residual blocks.
    The code was inspired by https://github.com/LMescheder/GAN_stability.
    """

    def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
        super().__init__()

        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if 'spectral' in norm_g:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = norm_g.replace('spectral', '')
        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)

    def forward(self, x, seg):
        # the segmentation map is fed to every SPADE normalization layer
        x_s = self.shortcut(x, seg)
        dx = self.conv_0(self.act(self.norm_0(x, seg)))
        dx = self.conv_1(self.act(self.norm_1(dx, seg)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def act(self, x):
        return F.leaky_relu(x, 2e-1)
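

# A minimal usage sketch (channel counts are illustrative assumptions; the
# 'instance' variant is used here to avoid the syncbatch single-GPU fallback):
#   block = SPADEResnetBlock(fin=64, fout=128, norm_g='spectralspadeinstance3x3', semantic_nc=3)
#   x = torch.randn(1, 64, 32, 32)
#   seg = torch.randn(1, 3, 32, 32)
#   out = block(x, seg)  # (1, 128, 32, 32); fin != fout, so the shortcut is learned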


class BaseNetwork(nn.Module):
    """A basis for hifacegan archs with custom initialization."""

    def init_weights(self, init_type='normal', gain=0.02):

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children that define their own initialization
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)

    def forward(self, x):
        pass  # to be implemented by subclasses
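

# A minimal usage sketch (argument values are illustrative assumptions;
# LIPEncoder, defined below, is one BaseNetwork subclass in this module):
#   net = LIPEncoder(input_nc=3, ngf=64, sw=16, sh=16, n_2xdown=3)
#   net.init_weights(init_type='xavier', gain=0.02)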


def lip2d(x, logit, kernel=3, stride=2, padding=1):
    # Local Importance-based Pooling: a weighted average pooling where the
    # strictly positive weights are exp(logit)
    weight = logit.exp()
    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
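

# Each pooling window computes sum(x * exp(logit)) / sum(exp(logit)), i.e. a
# softmax-weighted average of x over the window. A constant logit therefore
# reduces lip2d to plain average pooling (over the valid, unpadded elements):
#   x = torch.randn(1, 8, 32, 32)
#   assert torch.allclose(lip2d(x, torch.zeros_like(x)),
#                         F.avg_pool2d(x, 3, 2, 1, count_include_pad=False), atol=1e-6)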


class SoftGate(nn.Module):
    COEFF = 12.0

    def forward(self, x):
        # bound the logits to (0, COEFF) so that exp(logit) stays numerically safe
        return torch.sigmoid(x).mul(self.COEFF)


class SimplifiedLIP(nn.Module):

    def __init__(self, channels):
        super().__init__()
        self.logit = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
            SoftGate())

    def init_layer(self):
        # zero-init the conv so the logits start out constant, making the
        # initial LIP equivalent to plain average pooling
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        frac = lip2d(x, self.logit(x))
        return frac
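

# A minimal usage sketch (sizes are illustrative assumptions):
#   pool = SimplifiedLIP(64)
#   pool.init_layer()                       # start from plain average pooling
#   y = pool(torch.randn(1, 64, 32, 32))    # downsampled by 2: (1, 64, 16, 16)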


class LIPEncoder(BaseNetwork):
    """Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)."""

    def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16

        kw = 3
        pw = (kw - 1) // 2

        model = [
            nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            # each stage halves the spatial size via LIP, then widens the channels
            model += [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
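

# A minimal usage sketch (argument values are illustrative assumptions):
#   enc = LIPEncoder(input_nc=3, ngf=64, sw=16, sh=16, n_2xdown=3)
#   feat = enc(torch.randn(1, 3, 128, 128))  # 3 LIP stages: (1, 512, 16, 16)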


def get_nonspade_norm_layer(norm_type='instance'):

    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return layer.out_channels
        return layer.weight.size(0)

    # returns a function that wraps a given layer with the requested normalization
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove the bias in the previous layer, since it has no effect
        # after the mean subtraction of the following normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError(f'normalization layer {subnorm_type} is not recognized')

        return nn.Sequential(layer, norm_layer)

    print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
    return add_norm_layer
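

# A minimal usage sketch:
#   norm_layer = get_nonspade_norm_layer('spectralinstance')
#   block = norm_layer(nn.Conv2d(3, 64, 3, padding=1))
#   # -> nn.Sequential(spectral-normalized conv with its bias removed, InstanceNorm2d(64))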