import torch
from data_setup import gram_matrix

# Per-layer weight applied to every iterated layer when no explicit
# ``style_weights`` mapping is passed to ``style_loss``.
# NOTE(review): 0.2 presumably means "equal weight over five style layers";
# confirm against how the feature dictionaries are built by the caller.
DEFAULT_STYLE_LAYER_WEIGHT = 0.2


def content_loss(target, content):
    """Return the mean squared error between ``target`` and ``content``.

    Args:
        target: Feature tensor for the image being optimized.
        content: Feature tensor for the content image; must be
            broadcast-compatible with ``target``.

    Returns:
        A scalar tensor: the mean of the element-wise squared difference.
    """
    return torch.mean((target - content) ** 2)


def style_loss(target_features, style_grams, style_weights=None):
    """Return the weighted Gram-matrix style loss summed over layers.

    Per-layer computation:

    1. Compute the Gram matrix of the target feature map with
       ``gram_matrix`` (imported from ``data_setup``).
    2. Compare it with the precomputed style Gram matrix using a mean
       squared error.
    3. Multiply by the layer's weight.
    4. Divide by the feature map's ``c * h * w``.

    Args:
        target_features: Mapping of layer name -> feature tensor for the
            image being optimized. Each tensor's shape must have exactly
            four dimensions; the last three are used as ``c, h, w``.
        style_grams: Mapping of layer name -> precomputed Gram matrix of the
            style image, for every layer that is iterated.
        style_weights: Optional mapping of layer name -> weight. If ``None``
            (default), every layer of ``target_features`` is used, in
            iteration order, with weight ``DEFAULT_STYLE_LAYER_WEIGHT``
            (0.2). This is the original behavior. If given, only the layers
            it names are used, each with its own weight.

    Returns:
        The summed style loss (a tensor), or the integer ``0`` when there
        are no layers to iterate.

    Raises:
        KeyError: If an iterated layer is missing from ``style_grams`` (or
            from ``target_features`` when ``style_weights`` is given).
        ValueError: If a target feature tensor's shape does not have
            exactly four dimensions.
    """
    if style_weights is None:
        # Preserves the original iteration order over target_features.
        style_weights = dict.fromkeys(target_features, DEFAULT_STYLE_LAYER_WEIGHT)

    loss = 0
    for layer, weight in style_weights.items():
        target_f = target_features[layer]
        target_gram = gram_matrix(target_f)
        style_gram = style_grams[layer]
        _, c, h, w = target_f.shape
        layer_loss = weight * torch.mean((target_gram - style_gram) ** 2)
        # Normalize by feature-map size so layers of different resolution
        # contribute on a comparable scale.
        loss += layer_loss / (c * h * w)
    return loss


def total_loss(content_loss, style_loss, alpha, beta):
    """Return the weighted sum ``alpha * content_loss + beta * style_loss``.

    The arguments ``content_loss`` and ``style_loss`` are already-computed
    loss values, not the functions of the same name. They shadow those
    module-level functions inside this body. The names are kept so existing
    keyword-argument callers keep working.

    Args:
        content_loss: Content loss value.
        style_loss: Style loss value.
        alpha: Weight applied to the content loss.
        beta: Weight applied to the style loss.

    Returns:
        The combined loss.
    """
    return alpha * content_loss + beta * style_loss