pixelization / models /c2pGen.py
NoCrypt's picture
Update models/c2pGen.py
791b9ee verified
from .basic_layer import *
import torchvision.models as models
import os
class AliasNet(nn.Module):
    """Encoder/decoder network: ``AliasRGBEncoder`` followed by ``AliasRGBDecoder``.

    Args:
        input_dim: forwarded to the encoder as its input size
            (presumably the number of image channels -- confirm in basic_layer).
        output_dim: forwarded to the decoder as its output size.
        dim: base width handed to the encoder's first block.
        n_downsample: number of downsampling blocks in the encoder; also passed
            to the decoder as ``n_upsample``.
        n_res: number of residual blocks, passed to both halves.
        activ: activation name forwarded to every block.
        pad_type: padding name forwarded to every block.
    """

    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
        super().__init__()
        # The encoder is built first because the decoder's input width is the
        # encoder's reported ``output_dim``.
        encoder = AliasRGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
        self.RGBEnc = encoder
        self.RGBDec = AliasRGBDecoder(
            encoder.output_dim,
            output_dim,
            n_downsample,
            n_res,
            res_norm='in',
            activ=activ,
            pad_type=pad_type,
        )

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder."""
        return self.RGBDec(self.RGBEnc(x))
class AliasRGBEncoder(nn.Module):
    """Sequential encoder: stem block, downsampling blocks, residual blocks.

    The layers are stored in ``self.model`` (an ``nn.Sequential``) and the
    width after the last downsampling block is exposed as ``self.output_dim``.

    Args:
        input_dim: input size of the stem ``AliasConvBlock``.
        dim: output size of the stem; doubled by every downsampling block.
        n_downsample: how many downsampling ``AliasConvBlock`` stages to add.
        n_res: block count handed to ``AliasResBlocks``.
        norm: normalisation name forwarded to every block.
        activ: activation name forwarded to every block.
        pad_type: padding name forwarded to every block.
    """

    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
        super().__init__()
        width = dim
        # Stem block (positional 7, 1, 3 -- presumably kernel, stride, padding).
        stages = [AliasConvBlock(input_dim, width, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # Each downsampling block doubles the width (positional 4, 2, 1).
        for _ in range(n_downsample):
            stages.append(
                AliasConvBlock(width, 2 * width, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)
            )
            width *= 2
        # Residual blocks operate at the final width.
        stages.append(AliasResBlocks(n_res, width, norm=norm, activation=activ, pad_type=pad_type))
        self.model = nn.Sequential(*stages)
        self.output_dim = width

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        encoded = self.model(x)
        return encoded
class AliasRGBDecoder(nn.Module):
    """Decoder: residual blocks, two (upsample + conv) stages, output conv.

    The stages are written out by hand rather than built in a loop, so exactly
    two nearest-neighbour x2 upsampling stages are always created.

    Args:
        dim: input width; halved by each of the two conv stages.
        output_dim: output size of the final ``AliasConvBlock``.
        n_upsample: accepted for interface compatibility but not read --
            NOTE(review): the layout only matches callers that use 2.
        n_res: block count handed to ``AliasResBlocks``.
        res_norm: normalisation name for the residual blocks.
        activ: activation name for the residual and upsampling conv blocks.
        pad_type: padding name forwarded to every block.
    """

    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
        super().__init__()
        half = dim // 2
        quarter = half // 2

        self.Res_Blocks = AliasResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)

        # Stage 1: x2 nearest upsample, then conv halving the width.
        self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_1 = AliasConvBlock(dim, half, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)

        # Stage 2: same pattern, halving the width again.
        self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_2 = AliasConvBlock(half, quarter, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)

        # Output projection: no normalisation, tanh activation.
        self.conv_3 = AliasConvBlock(quarter, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)

    def forward(self, x):
        """Apply the residual blocks, both upsampling stages and the output conv."""
        pipeline = (
            self.Res_Blocks,
            self.upsample_block1,
            self.conv_1,
            self.upsample_block2,
            self.conv_2,
            self.conv_3,
        )
        for stage in pipeline:
            x = stage(x)
        return x
class C2PGen(nn.Module):
    """Generator that decodes encoded ``clipart`` under a code derived from ``pixelart``.

    Sub-modules (names and creation order are part of the checkpoint layout):
        PBEnc:  ``PixelBlockEncoder`` producing a style code from ``pixelart``.
        RGBEnc: ``RGBEncoder`` producing content features from ``clipart``.
        RGBDec: ``RGBDecoder`` consuming content features plus the MLP output.
        MLP:    maps the style code to the decoder's modulation parameters
                (built with an output size of 2048).

    Args:
        input_dim: input size given to both encoders.
        output_dim: output size of the decoder.
        dim: base width for both encoders.
        n_downsample: downsampling count for ``RGBEncoder``; also passed to
            the decoder as ``n_upsample``.
        n_res: residual block count passed to encoder and decoder.
        style_dim: size of the style code produced by ``PBEnc``.
        mlp_dim: third positional argument of ``MLP``
            (presumably its hidden width -- confirm in basic_layer).
        activ: activation name forwarded to the sub-modules.
        pad_type: padding name forwarded to the sub-modules.
    """

    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, style_dim, mlp_dim, activ='relu', pad_type='reflect'):
        super().__init__()
        self.PBEnc = PixelBlockEncoder(input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
        self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
        self.RGBDec = RGBDecoder(
            self.RGBEnc.output_dim,
            output_dim,
            n_downsample,
            n_res,
            res_norm='adain',
            activ=activ,
            pad_type=pad_type,
        )
        self.MLP = MLP(style_dim, 2048, mlp_dim, 3, norm='none', activ=activ)

    def forward(self, clipart, pixelart, s=1):
        """Return the decoded image for ``clipart`` styled by ``pixelart``.

        ``s`` scales the MLP output before decoding. The scaled parameters are
        computed by ``fuse`` but not returned here; call ``fuse`` directly to
        obtain them (e.g. to visualise the cell-size code).
        """
        # Content encoder runs first, then the style encoder.
        images, _params = self.fuse(self.RGBEnc(clipart), self.PBEnc(pixelart), s)
        return images

    def fuse(self, content, style_code, s=1):
        """Decode ``content`` under ``style_code``.

        Returns:
            ``(images, adain_params)`` where ``adain_params`` is the MLP output
            multiplied by ``s``.
        """
        adain_params = self.MLP(style_code) * s
        return self.RGBDec(content, adain_params), adain_params

    def assign_adain_params(self, adain_params, model):
        """Distribute ``adain_params`` over the AdaIN layers found in ``model``.

        Every module whose class name is ``AdaptiveInstanceNorm2d`` takes the
        next ``num_features`` columns as ``bias`` and the following
        ``num_features`` columns as ``weight`` (both flattened). The consumed
        columns are dropped only while more columns remain afterwards.
        """
        remaining = adain_params
        for module in model.modules():
            if type(module).__name__ != "AdaptiveInstanceNorm2d":
                continue
            count = module.num_features
            shift = remaining[:, :count]
            scale = remaining[:, count:2 * count]
            module.bias = shift.contiguous().view(-1)
            module.weight = scale.contiguous().view(-1)
            if remaining.size(1) > 2 * count:
                remaining = remaining[:, 2 * count:]

    def get_num_adain_params(self, model):
        """Return how many AdaIN parameters ``model`` needs (2 per feature per layer)."""
        return sum(
            2 * module.num_features
            for module in model.modules()
            if type(module).__name__ == "AdaptiveInstanceNorm2d"
        )
class PixelBlockEncoder(nn.Module):
    """Style encoder that mixes its own conv features with frozen VGG19 features.

    A VGG19 whose last classifier layer is replaced by a 7-way ``nn.Linear`` is
    loaded from a checkpoint; only its ``features`` part is kept and frozen.
    Four ``ConvBlock`` stages are interleaved with channel-wise concatenation
    of the VGG activations ``conv1_1``, ``conv2_1``, ``conv3_1`` and
    ``conv4_1``. The result is globally average-pooled and projected to
    ``style_dim`` channels by a 1x1 convolution.

    The checkpoint path is taken from the ``PIX_MODEL`` environment variable
    and falls back to ``./pixelart_vgg19.pth`` when the variable is unset or
    empty.

    Args:
        input_dim: input size of the first ``ConvBlock``.
        dim: output size of the first ``ConvBlock``. The width is doubled after
            every stage to account for the concatenated VGG features, so this
            presumably has to equal the VGG ``conv1_1`` width (the inline
            comments assume 64) -- confirm against the caller.
        style_dim: number of output channels of the final 1x1 convolution.
        norm: normalisation name forwarded to every ``ConvBlock``.
        activ: activation name forwarded to every ``ConvBlock``.
        pad_type: padding name forwarded to every ``ConvBlock``.
    """

    # Checkpoint used when PIX_MODEL does not name one.
    _DEFAULT_VGG_WEIGHTS = './pixelart_vgg19.pth'

    @staticmethod
    def _resolve_weights_path():
        """Return the VGG checkpoint path: ``$PIX_MODEL`` if set and non-empty, else the default."""
        # ``.get`` instead of indexing: a missing variable must select the
        # default rather than raise KeyError.
        return os.environ.get('PIX_MODEL') or PixelBlockEncoder._DEFAULT_VGG_WEIGHTS

    def __init__(self, input_dim, dim, style_dim, norm, activ, pad_type):
        super(PixelBlockEncoder, self).__init__()
        vgg19 = models.vgg.vgg19()
        # Swap the last classifier layer for a 7-way head so the checkpoint's
        # state dict matches, even though only ``features`` is used afterwards.
        vgg19.classifier._modules['6'] = nn.Linear(4096, 7, bias=True)
        vgg19.load_state_dict(
            torch.load(
                self._resolve_weights_path(),
                map_location=torch.device('cpu'),
                weights_only=True,
            )
        )
        self.vgg = vgg19.features
        # The VGG branch is a fixed feature extractor.
        for p in self.vgg.parameters():
            p.requires_grad = False

        # Width doubles after each stage because VGG features are concatenated
        # onto the stage output before the next stage consumes it.
        self.conv1 = ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)  # 3->64, concat
        dim = dim * 2
        self.conv2 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 128->128
        dim = dim * 2
        self.conv3 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 256->256
        dim = dim * 2
        self.conv4 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 512->512
        dim = dim * 2

        self.model = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # global average pooling
            nn.Conv2d(dim, style_dim, 1, 1, 0),
        )
        self.output_dim = dim

    def get_features(self, image, model, layers=None):
        """Run ``image`` through ``model`` and collect selected activations.

        Args:
            image: input passed to the first sub-module.
            model: object whose ``_modules`` mapping holds the layers in order.
            layers: mapping from sub-module name to the key under which its
                output is stored; defaults to the four VGG19 layers used by
                ``componet_enc``.

        Returns:
            Dict mapping each requested key to the corresponding activation.
        """
        if layers is None:
            layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1'}
        features = {}
        x = image
        # model._modules is a dictionary holding each module in the model
        for name, layer in model._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
                # Every requested activation is captured: the remaining layers
                # would only compute values that are thrown away.
                if len(features) == len(layers):
                    break
        return features

    def componet_enc(self, x):
        """Encode ``x`` into a style code, concatenating VGG features at each stage.

        The shape comments below are carried over from the original code
        (batch 16, dim 64, 256x256 input) and are illustrative only.
        """
        # x [16,3,256,256]
        vgg_aux = self.get_features(x, self.vgg)  # x is a 3-channel grayscale image
        x = self.conv1(x)  # 64 256 256
        x = torch.cat([x, vgg_aux['conv1_1']], dim=1)  # 128 256 256
        x = self.conv2(x)  # 128 128 128
        x = torch.cat([x, vgg_aux['conv2_1']], dim=1)  # 256 128 128
        x = self.conv3(x)  # 256 64 64
        x = torch.cat([x, vgg_aux['conv3_1']], dim=1)  # 512 64 64
        x = self.conv4(x)  # 512 32 32
        x = torch.cat([x, vgg_aux['conv4_1']], dim=1)  # 1024 32 32
        x = self.model(x)
        return x

    def forward(self, x):
        """Return the style code for ``x`` (delegates to ``componet_enc``)."""
        code = self.componet_enc(x)
        return code
class RGBEncoder(nn.Module):
    """Sequential content encoder: stem block, downsampling blocks, residual blocks.

    The layers live in ``self.model`` (an ``nn.Sequential``); the width after
    the last downsampling block is exposed as ``self.output_dim``.

    Args:
        input_dim: input size of the stem ``ConvBlock``.
        dim: output size of the stem; doubled by every downsampling block.
        n_downsample: how many downsampling ``ConvBlock`` stages to add.
        n_res: block count handed to ``ResBlocks``.
        norm: normalisation name forwarded to every block.
        activ: activation name forwarded to every block.
        pad_type: padding name forwarded to every block.
    """

    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
        super().__init__()
        width = dim
        # Stem block (positional 7, 1, 3 -- presumably kernel, stride, padding).
        stages = [ConvBlock(input_dim, width, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # Each downsampling block doubles the width (positional 4, 2, 1).
        for _ in range(n_downsample):
            stages.append(
                ConvBlock(width, 2 * width, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)
            )
            width *= 2
        # Residual blocks operate at the final width.
        stages.append(ResBlocks(n_res, width, norm=norm, activation=activ, pad_type=pad_type))
        self.model = nn.Sequential(*stages)
        self.output_dim = width

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        encoded = self.model(x)
        return encoded
class RGBDecoder(nn.Module):
    """Decoder: four modulated residual stages, two (upsample + conv) stages, output conv.

    Eight ``ModulationConvBlock(256, 256, 3)`` modules are registered as
    ``mod_conv_1`` .. ``mod_conv_8``; the remaining layers mirror a plain
    upsampling decoder whose width starts at ``dim`` and is halved twice.

    Args:
        dim: input width of ``conv_1``. The modulation blocks are hard-coded to
            256, so ``dim`` presumably has to be 256 as well -- confirm.
        output_dim: output size of the final ``ConvBlock``.
        n_upsample: accepted but not read; two upsampling stages are always built.
        n_res: accepted but not read; four residual stages are always run.
        res_norm: accepted but not read.
        activ: activation name for the two upsampling conv blocks.
        pad_type: padding name for the three ``ConvBlock`` layers.
    """

    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
        super().__init__()
        # Registered in numeric order so names and creation order stay
        # identical to the hand-written mod_conv_1 .. mod_conv_8 attributes.
        for index in range(1, 9):
            setattr(self, "mod_conv_%d" % index, ModulationConvBlock(256, 256, 3))

        half = dim // 2
        quarter = half // 2

        # Stage 1: x2 nearest upsample, then conv halving the width.
        self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_1 = ConvBlock(dim, half, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)

        # Stage 2: same pattern, halving the width again.
        self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_2 = ConvBlock(half, quarter, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)

        # Output projection: no normalisation, tanh activation.
        self.conv_3 = ConvBlock(quarter, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)

    def forward(self, x, code):
        """Decode ``x`` using ``code`` as per-block modulation input.

        ``code`` is consumed as eight consecutive 256-column slices along
        dim 1, two per residual stage.

        NOTE(review): only ``mod_conv_1`` and ``mod_conv_2`` are ever called;
        ``mod_conv_2`` handles seven of the eight slices and ``mod_conv_3`` ..
        ``mod_conv_8`` are never used. This is kept exactly as-is because
        calling the other blocks would presumably change the output of any
        checkpoint trained with this forward pass -- confirm before altering.
        """
        slice_width = 256
        shared = self.mod_conv_2
        # (first block, second block) for each of the four residual stages.
        residual_stages = (
            (self.mod_conv_1, shared),
            (shared, shared),
            (shared, shared),
            (shared, shared),
        )
        for stage_index, (first_block, second_block) in enumerate(residual_stages):
            skip = x
            start = 2 * stage_index * slice_width
            middle = start + slice_width
            x = first_block(x, code[:, start:middle])
            x = second_block(x, code[:, middle:middle + slice_width])
            x += skip

        x = self.conv_1(self.upsample_block1(x))
        x = self.conv_2(self.upsample_block2(x))
        return self.conv_3(x)