|
import torch |
|
import torch.nn as nn |
|
import torch.nn.parallel |
|
import torch.optim as optim |
|
from torch.autograd import Variable |
|
|
|
|
|
def weights_init(m):
    """DCGAN-style layer initializer; apply via ``net.apply(weights_init)``.

    Conv / ConvTranspose weights are drawn from N(0, 0.02); BatchNorm
    scales from N(1, 0.02) with biases zeroed. Any other module type is
    left at its default initialization.
    """
    name = type(m).__name__

    # Matches both Conv2d and ConvTranspose2d (conv biases keep defaults).
    if 'Conv' in name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.zero_()
|
|
|
|
|
''' Generator network for 128x128 RGB images ''' |
|
class G(nn.Module):
    """Encoder-decoder generator for 128x128 RGB images.

    The encoder shrinks a 128x128x3 input to a 512-channel 1x1 code with
    six stride-2 convolutions plus a final 2x2 max-pool; the decoder
    mirrors it with transposed convolutions, ending in a Tanh so outputs
    lie in [-1, 1] with the same spatial size as the input.
    """

    def __init__(self):
        super(G, self).__init__()

        # Encoder: each stage halves the spatial size and doubles channels.
        encoder = []
        for cin, cout in [(3, 16), (16, 32), (32, 64), (64, 128), (128, 256)]:
            encoder += [
                nn.Conv2d(cin, cout, 4, 2, 1),
                nn.BatchNorm2d(cout),
                nn.ReLU(True),
            ]
        # Last encoder stage: 4x4 -> 2x2 conv, then pool down to 1x1.
        encoder += [
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.MaxPool2d((2, 2)),
        ]

        # Decoder: first deconv expands the 1x1 code to 4x4, then each
        # stage doubles the spatial size and halves the channels.
        decoder = [
            nn.ConvTranspose2d(512, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        ]
        for cin, cout in [(256, 128), (128, 64), (64, 32), (32, 16)]:
            decoder += [
                nn.ConvTranspose2d(cin, cout, 4, 2, 1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(True),
            ]
        decoder += [
            nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*encoder, *decoder)

    def forward(self, input):
        """Run the full encode-decode pass; output is shaped like the input."""
        return self.main(input)
|
|
|
|
|
''' Discriminator network for 128x128 RGB images ''' |
|
class D(nn.Module):
    """Discriminator for 128x128 RGB images.

    Seven stride-2 convolutions collapse the input to a single 1x1 score
    map per sample, squashed through a sigmoid to a probability.
    """

    def __init__(self):
        super(D, self).__init__()

        # First stage has no BatchNorm (standard DCGAN practice).
        layers = [
            nn.Conv2d(3, 16, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Middle stages: halve spatial size, double channels each time.
        for cin, cout in [(16, 32), (32, 64), (64, 128), (128, 256), (256, 512)]:
            layers += [
                nn.Conv2d(cin, cout, 4, 2, 1),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Final stage: 2x2 -> 1x1 single-channel score, then sigmoid.
        layers += [
            nn.Conv2d(512, 1, 4, 2, 1, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a 1-D tensor of per-sample probabilities in (0, 1)."""
        return self.main(input).view(-1)
|
|