|
|
|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import numpy as np |
|
|
|
def preprocess(img, imsize=196):
    """Turn raw image data into a normalized, batched tensor for the model.

    Args:
        img: Either image data accepted by ``PIL.Image.fromarray`` (typically a
            NumPy array) or an already-loaded ``PIL.Image.Image``.
        imsize: Side length in pixels of the square output. Defaults to 196,
            the value previously hard-coded here.

    Returns:
        A tensor of shape ``(1, 3, imsize, imsize)``. It holds the RGB pixels
        scaled to [0, 1] by ``ToTensor`` and then normalized with the
        per-channel mean/std below.
    """
    # Accept PIL images directly; otherwise build one from the array data.
    image = img if isinstance(img, Image.Image) else Image.fromarray(img)
    # Force exactly three channels so grayscale/RGBA inputs match the
    # 3-element mean/std used by Normalize.
    image = image.convert('RGB')

    transform = transforms.Compose([
        # An (h, w) tuple resizes to exactly that size, ignoring aspect ratio.
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor(),
        # These appear to be the standard ImageNet channel statistics; they
        # must stay in sync with the values deprocess() uses to undo them.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Add a leading batch dimension of size 1.
    return transform(image).unsqueeze(dim=0)
|
|
|
|
|
def deprocess(image):
    """Convert a normalized image tensor back into a displayable NumPy array.

    This undoes the mean/std normalization applied by ``preprocess`` and moves
    the channel axis last.

    Args:
        image: Tensor of shape ``(1, C, H, W)`` or ``(C, H, W)``. Only a
            leading dimension of size 1 is removed. The mean/std arrays have
            three entries, so C is expected to be 3. The tensor may require
            grad and may live on any device.

    Returns:
        A ``float64`` NumPy array of shape ``(H, W, C)`` with values clipped
        to ``[0, 1]``. The input tensor is not modified.
    """
    # detach() drops autograd history. cpu() is needed because Tensor.numpy()
    # only works on CPU tensors, e.g. when the image was optimized on a GPU.
    pixels = image.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    # The arithmetic allocates a new array, so the caller's tensor (whose
    # storage numpy() may share on CPU) is never written to.
    pixels = pixels * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    return pixels.clip(0, 1)
|
|
|
def get_features(image, model, layers=None):
    """Run *image* through *model* child by child and collect chosen outputs.

    Args:
        image: Input tensor fed to the first child module of *model*.
        model: Module whose direct children are applied one after another in
            registration order (e.g. an ``nn.Sequential``).
            NOTE(review): presumably the convolutional ``features`` part of a
            torchvision VGG network; confirm against the caller.
        layers: Optional mapping from child-module name to the key under which
            that child's output is stored. For ``nn.Sequential`` the names are
            the string indices ``'0'``, ``'1'``, and so on. Defaults to
            ``{'0': 'layer_1', '5': 'layer_2', '10': 'layer_3',
            '19': 'layer_4', '28': 'layer_5'}``.

    Returns:
        Dict mapping each requested key to the output tensor of the matching
        child. Names that do not occur in *model* are simply absent.
    """
    if layers is None:
        # NOTE(review): these indices look like conv1_1..conv5_1 of VGG19's
        # feature extractor, as commonly used for style transfer; verify.
        layers = {
            '0': 'layer_1',
            '5': 'layer_2',
            '10': 'layer_3',
            '19': 'layer_4',
            '28': 'layer_5',
        }

    features = {}
    x = image
    # _modules is iterated (rather than named_children(), which skips repeated
    # module instances) so that every registered child runs in order.
    # The loop deliberately continues past the last requested layer. Stored
    # tensors are not copied, so in-place children (e.g. ReLU(inplace=True))
    # may still modify them, and stopping early could change the results.
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features
|
|
|
|
|
def gram_matrix(image):
    """Compute the (unnormalized) Gram matrix of a feature map.

    The feature map is flattened to ``F`` of shape ``(c, h*w)``, and
    ``F @ F.T`` holds the inner products between every pair of channels. No
    division by ``c*h*w`` or similar is applied.

    Args:
        image: Tensor of shape ``(b, c, h, w)``. It need not be contiguous.

    Returns:
        For ``b == 1``, a ``(c, c)`` tensor. For any other batch size, a
        ``(b, c, c)`` tensor with one Gram matrix per batch element.
    """
    b, c, h, w = image.size()
    if b == 1:
        # reshape (unlike view) also handles non-contiguous inputs such as
        # permuted or channels_last tensors.
        flat = image.reshape(c, h * w)
        return torch.mm(flat, flat.t())
    flat = image.reshape(b, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2))
|
|