Spaces:
Running
Running
File size: 3,288 Bytes
2c9c37b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from .basic_layer import *
class P2CGen(nn.Module):
    """Encoder-decoder generator: an ``RGBEncoder`` feeding an ``RGBDecoder``.

    Args:
        input_dim: channel count of the input tensor.
        output_dim: channel count produced by the decoder.
        dim: base channel width of the encoder's first ConvBlock.
        n_downsample: number of encoder downsampling stages; the same value
            is handed to the decoder as ``n_upsample``.
        n_res: number of residual blocks in the encoder and in the decoder.
        activ: activation name forwarded to both halves.
        pad_type: padding mode forwarded to both halves.

    The encoder is built with norm name "in" and the decoder's residual
    blocks with ``res_norm='in'`` (presumably instance norm -- defined in
    ``basic_layer``).
    """

    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu',
                 pad_type='reflect'):
        super().__init__()
        encoder = RGBEncoder(input_dim, dim, n_downsample, n_res, norm="in", activ=activ,
                             pad_type=pad_type)
        self.RGBEnc = encoder
        # The decoder consumes the encoder's final channel width.
        self.RGBDec = RGBDecoder(encoder.output_dim, output_dim, n_upsample=n_downsample,
                                 n_res=n_res, res_norm='in', activ=activ, pad_type=pad_type)

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for those features."""
        features = self.RGBEnc(x)
        return self.RGBDec(features)
class RGBEncoder(nn.Module):
    """Convolutional encoder: stem ConvBlock, strided ConvBlocks, residual blocks.

    Layers, in order (every block receives the same ``norm``, ``activ`` and
    ``pad_type``):
      1. ``ConvBlock(input_dim, dim, 7, 1, 3)`` -- presumably kernel 7,
         stride 1, padding 3 (see ``basic_layer`` to confirm argument order).
      2. ``n_downsample`` blocks ``ConvBlock(c, 2 * c, 4, 2, 1)``, each
         doubling the channel count (stride 2 presumably halves H and W).
      3. ``ResBlocks(n_res, ...)`` at the final channel width.

    Attributes:
        model: the layers above wrapped in ``nn.Sequential``.
        output_dim: channel count of the output, ``dim * 2 ** n_downsample``
            (just ``dim`` when ``n_downsample <= 0``).
    """

    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
        super().__init__()
        # Width after the stem, then after each downsampling stage.
        widths = [dim] + [dim * 2 ** step for step in range(1, n_downsample + 1)]
        layers = [ConvBlock(input_dim, widths[0], 7, 1, 3, norm=norm, activation=activ,
                            pad_type=pad_type)]
        layers.extend(
            ConvBlock(c_in, c_out, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        )
        layers.append(ResBlocks(n_res, widths[-1], norm=norm, activation=activ,
                                pad_type=pad_type))
        self.model = nn.Sequential(*layers)
        self.output_dim = widths[-1]

    def forward(self, x):
        """Apply the sequential encoder stack to ``x``."""
        return self.model(x)
class RGBDecoder(nn.Module):
    """Decode a feature map into a tensor with ``output_dim`` channels.

    Layers, in order:
      1. ``ResBlocks`` with ``n_res`` blocks at ``dim`` channels.
      2. ``n_upsample`` stages, each a nearest-neighbour ``nn.Upsample``
         (scale_factor=2) followed by ``ConvBlock(c, c // 2, 5, 1, 2)`` with
         norm 'ln', halving the channel count.
      3. ``ConvBlock(c, output_dim, 7, 1, 3)`` with norm 'none' and
         activation 'tanh'.

    Stage ``k`` (1-based) is registered as ``upsample_block{k}`` and
    ``conv_{k}``; the output layer is ``conv_{n_upsample + 1}``. With
    ``n_upsample=2`` this gives the attribute / state_dict names
    ``upsample_block1``, ``conv_1``, ``upsample_block2``, ``conv_2`` and
    ``conv_3``, so checkpoints saved with those keys still load.

    Args:
        dim: channel count of the incoming feature map.
        output_dim: channel count of the result.
        n_upsample: number of x2 upsampling stages. To restore the encoder's
            input resolution this must equal the encoder's ``n_downsample``.
        n_res: number of residual blocks.
        res_norm: normalization name passed to ``ResBlocks``.
        activ: activation name for the residual and upsampling blocks.
        pad_type: padding mode forwarded to every block.

    Raises:
        ValueError: if ``n_upsample`` is negative.
    """

    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu',
                 pad_type='zero'):
        super(RGBDecoder, self).__init__()
        if n_upsample < 0:
            raise ValueError(f"n_upsample must be >= 0, got {n_upsample}")
        # Plain int (not a parameter/buffer); forward() uses it to walk the stages.
        self.n_upsample = n_upsample
        # Keyword names match the ResBlocks call used by RGBEncoder.
        self.Res_Blocks = ResBlocks(n_res, dim, norm=res_norm, activation=activ,
                                    pad_type=pad_type)
        for k in range(1, n_upsample + 1):
            # add_module with numbered names keeps the conv_1/conv_2/... keys stable.
            self.add_module(f"upsample_block{k}", nn.Upsample(scale_factor=2, mode='nearest'))
            self.add_module(f"conv_{k}", ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln',
                                                   activation=activ, pad_type=pad_type))
            dim //= 2
        self.add_module(f"conv_{n_upsample + 1}",
                        ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh',
                                  pad_type=pad_type))

    def forward(self, x):
        """Run the residual blocks, every upsampling stage, then the output conv."""
        x = self.Res_Blocks(x)
        for k in range(1, self.n_upsample + 1):
            x = getattr(self, f"upsample_block{k}")(x)
            x = getattr(self, f"conv_{k}")(x)
        return getattr(self, f"conv_{self.n_upsample + 1}")(x)
|