# Source: cyclegan-cut / models / sincut_model.py
# Uploaded by qninhdt ("Upload 68 files", revision ec1cb04, verified).
import torch
from .cut_model import CUTModel
class SinCUTModel(CUTModel):
    """ This class implements the single image translation model (Fig 9) of
    Contrastive Learning for Unpaired Image-to-Image Translation
    Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
    ECCV, 2020

    On top of the losses inherited from CUTModel it adds two terms:
      * an R1 gradient penalty on the discriminator (weight ``--lambda_R1``)
      * an L1 identity-preservation loss on the generator
        (weight ``--lambda_identity``)
    Each term is only evaluated when its weight is strictly positive, which
    matches the condition used in ``__init__`` to register it for logging.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register SinCUT options and override CUT defaults.

        Parameters:
            parser   -- option parser; CUTModel's options are added to it first
            is_train -- selects the training-phase or test-phase defaults

        Returns:
            the modified parser.
        """
        parser = CUTModel.modify_commandline_options(parser, is_train)
        parser.add_argument('--lambda_R1', type=float, default=1.0,
                            help='weight for the R1 gradient penalty')
        parser.add_argument('--lambda_identity', type=float, default=1.0,
                            help='the "identity preservation loss"')

        # Defaults shared by training and testing.
        parser.set_defaults(nce_includes_all_negatives_from_minibatch=True,
                            dataset_mode="singleimage",
                            netG="stylegan2",
                            stylegan2_G_num_downsampling=1,
                            netD="stylegan2",
                            gan_mode="nonsaturating",
                            num_patches=1,
                            nce_layers="0,2,4",
                            lambda_NCE=4.0,
                            ngf=10,
                            ndf=8,
                            lr=0.002,
                            beta1=0.0,
                            beta2=0.99,
                            load_size=1024,
                            crop_size=64,
                            preprocess="zoom_and_patch",
                            )

        # Phase-specific defaults; these override the shared ones above.
        if is_train:
            parser.set_defaults(preprocess="zoom_and_patch",
                                batch_size=16,
                                save_epoch_freq=1,
                                save_latest_freq=20000,
                                n_epochs=8,
                                n_epochs_decay=8,
                                )
        else:
            parser.set_defaults(preprocess="none",  # load the whole image as it is
                                batch_size=1,
                                num_test=1,
                                )

        return parser

    def __init__(self, opt):
        """Initialize the parent model and register the extra loss names.

        Parameters:
            opt -- parsed options; ``lambda_R1`` and ``lambda_identity`` are
                   read here.
        """
        super().__init__(opt)
        if self.isTrain:
            # Only report a loss when its weight makes it active.
            if opt.lambda_R1 > 0.0:
                self.loss_names += ['D_R1']
            if opt.lambda_identity > 0.0:
                self.loss_names += ['idt']

    def compute_D_loss(self):
        """Return the discriminator loss: parent GAN loss plus the R1 penalty.

        Stores the penalty in ``self.loss_D_R1`` and the total in
        ``self.loss_D``. When ``lambda_R1`` is not positive the penalty is
        skipped entirely (stored as 0.0) so that no second-order autograd
        graph is built for a term that would contribute nothing.
        """
        use_r1 = self.opt.lambda_R1 > 0.0
        if use_r1:
            # The penalty differentiates D's output w.r.t. the real image, so
            # the image must track gradients before D is evaluated on it.
            self.real_B.requires_grad_()

        GAN_loss_D = super().compute_D_loss()

        if use_r1:
            # NOTE(review): self.pred_real is presumably set by the parent's
            # compute_D_loss() -- confirm against CUTModel.
            self.loss_D_R1 = self.R1_loss(self.pred_real, self.real_B)
        else:
            self.loss_D_R1 = 0.0

        self.loss_D = GAN_loss_D + self.loss_D_R1
        return self.loss_D

    def compute_G_loss(self):
        """Return the generator loss: parent CUT loss plus the identity loss.

        The identity term is ``l1_loss(idt_B, real_B) * lambda_identity`` and
        is stored in ``self.loss_idt``. When ``lambda_identity`` is not
        positive the term is skipped (stored as 0.0) and ``self.idt_B`` is not
        accessed.
        """
        CUT_loss_G = super().compute_G_loss()

        if self.opt.lambda_identity > 0.0:
            # NOTE(review): self.idt_B is presumably produced by the parent's
            # forward pass -- confirm it is available in this configuration.
            self.loss_idt = torch.nn.functional.l1_loss(self.idt_B, self.real_B) * self.opt.lambda_identity
        else:
            self.loss_idt = 0.0

        return CUT_loss_G + self.loss_idt

    def R1_loss(self, real_pred, real_img):
        """Compute the weighted R1 gradient penalty.

        Parameters:
            real_pred -- discriminator output for ``real_img``
            real_img  -- input tensor with ``requires_grad`` enabled and
                         connected to ``real_pred`` in the autograd graph

        Returns:
            ``0.5 * lambda_R1 * mean_i(sum(grad_i ** 2))`` where ``grad`` is
            the gradient of ``real_pred.sum()`` w.r.t. ``real_img`` and ``i``
            indexes the first dimension.
        """
        # create_graph=True keeps the result differentiable so the penalty
        # itself can be back-propagated.
        grad_real, = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img,
                                         create_graph=True, retain_graph=True)
        # reshape (not view): the gradient may be non-contiguous, in which
        # case view() would raise.
        per_sample_sq_norm = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1)
        grad_penalty = per_sample_sq_norm.mean()
        return grad_penalty * (self.opt.lambda_R1 * 0.5)