|
from torch import nn |
|
import torch.nn.functional as F |
|
from facerender.modules.util import kp2gaussian |
|
import torch |
|
|
|
|
|
class DownBlock2d(nn.Module):
    """
    Simple block for processing video (encoder).

    Applies, in order: a 2D convolution (default stride and padding, so no
    padding is added), an optional affine instance normalisation, a leaky
    ReLU with negative slope 0.2 and an optional 2x2 average pooling.

    Args:
        in_features: number of input channels of the convolution.
        out_features: number of output channels of the convolution (and of
            the instance normalisation, when enabled).
        norm: if True, apply ``nn.InstanceNorm2d(out_features, affine=True)``
            right after the convolution.
        kernel_size: kernel size passed to ``nn.Conv2d``.
        pool: if True, finish with ``F.avg_pool2d(out, (2, 2))``.
        sn: if True, wrap the convolution with ``nn.utils.spectral_norm``.
    """

    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super().__init__()
        conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
        # spectral_norm returns the same module with a re-parametrised weight.
        self.conv = nn.utils.spectral_norm(conv) if sn else conv
        # ``None`` marks "no normalisation"; forward() checks for it explicitly.
        self.norm = nn.InstanceNorm2d(out_features, affine=True) if norm else None
        self.pool = pool

    def forward(self, x):
        """
        Run the block on ``x`` and return the resulting tensor.

        ``x`` must be an input accepted by ``nn.Conv2d`` with ``in_features``
        channels.
        """
        out = self.conv(x)
        # Explicit None check instead of relying on nn.Module truthiness.
        if self.norm is not None:
            out = self.norm(out)
        out = F.leaky_relu(out, 0.2)
        if self.pool:
            out = F.avg_pool2d(out, (2, 2))
        return out
|
|
|
|
|
class Discriminator(nn.Module):
    """
    Discriminator similar to Pix2Pix

    A stack of ``num_blocks`` DownBlock2d modules followed by a 1x1
    convolution that produces a single-channel prediction map.

    Block ``i`` maps ``min(max_features, block_expansion * 2 ** i)`` channels
    (``num_channels`` for the first block) to
    ``min(max_features, block_expansion * 2 ** (i + 1))`` channels. Every
    block except the first uses instance normalisation and every block
    except the last ends with average pooling.

    Args:
        num_channels: number of channels of the input.
        block_expansion: base channel width, doubled at every block.
        num_blocks: number of down blocks; must be at least 1.
        max_features: upper bound on the channel width of any block.
        sn: if True, apply spectral normalisation to every convolution,
            including the final 1x1 one.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``num_blocks`` is smaller than 1.
    """

    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
                 sn=False, **kwargs):
        super().__init__()

        # The final convolution takes its width from the last block, so an
        # empty stack cannot work; fail early with a clear message.
        if num_blocks < 1:
            raise ValueError("num_blocks must be at least 1, got %r" % (num_blocks,))

        down_blocks = []
        for i in range(num_blocks):
            in_features = num_channels if i == 0 else min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(
                DownBlock2d(in_features, out_features,
                            norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))

        self.down_blocks = nn.ModuleList(down_blocks)
        conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
        self.conv = nn.utils.spectral_norm(conv) if sn else conv

    def forward(self, x):
        """
        Return ``(feature_maps, prediction_map)`` for the input ``x``.

        ``feature_maps`` is a list holding the output of every down block in
        order; ``prediction_map`` is the final 1x1 convolution applied to the
        last feature map.
        """
        feature_maps = []
        out = x
        for down_block in self.down_blocks:
            out = down_block(out)
            feature_maps.append(out)
        prediction_map = self.conv(out)

        return feature_maps, prediction_map
|
|
|
|
|
class MultiScaleDiscriminator(nn.Module):
    """
    Multi-scale discriminator

    Holds one independent Discriminator per entry of ``scales`` and applies
    each of them to the input registered under the matching scale.

    Args:
        scales: iterable of scales; one Discriminator is created per scale.
        **kwargs: forwarded unchanged to every Discriminator.

    Note:
        Sub-modules are stored under ``str(scale)`` with ``.`` replaced by
        ``-`` (module names may not contain a dot) and forward() maps ``-``
        back to ``.``. A scale whose string form already contains ``-``
        (e.g. ``1e-05``) therefore does not round-trip to its original text.
    """

    def __init__(self, scales=(), **kwargs):
        super().__init__()
        self.scales = scales
        self.discs = nn.ModuleDict(
            {str(scale).replace('.', '-'): Discriminator(**kwargs) for scale in scales}
        )

    def forward(self, x):
        """
        Run every per-scale discriminator.

        Args:
            x: mapping that contains, for each scale, the discriminator
                input under the key ``'prediction_<scale>'``.

        Returns:
            dict with ``'feature_maps_<scale>'`` and
            ``'prediction_map_<scale>'`` entries for every scale.
        """
        out_dict = {}
        for key, disc in self.discs.items():
            # Undo the '.' -> '-' substitution applied to the module name.
            scale = key.replace('-', '.')
            feature_maps, prediction_map = disc(x['prediction_' + scale])
            out_dict['feature_maps_' + scale] = feature_maps
            out_dict['prediction_map_' + scale] = prediction_map
        return out_dict
|
|