Traly's picture
fix-1
e9b996f
import os
from collections import OrderedDict
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import transforms
from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize
from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam as trainer_ori
# Directory containing this file; checkpoints are resolved relative to it.
# abspath() guards against a relative (or empty) dirname when __file__ is
# relative, which would otherwise make the path depend on the current cwd.
ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
class sam_diffsr_demo:
    """Super-resolution demo wrapper around the SAM-DiffSR trainer.

    Builds the network through ``trainer_ori`` (``SRDiffDf2k_sam``), loads a
    checkpoint into it, and exposes :meth:`infer`, which maps a PIL image to a
    super-resolved PIL image.
    """

    # Torch device used for the network and the input tensors. The class-level
    # default keeps the original CUDA behaviour for instances that bypass
    # ``__init__`` (e.g. created with ``__new__``).
    device = 'cuda'

    def __init__(self, ckpt_path=None, device='cuda'):
        """Parse hyper-parameters, build the model and load its weights.

        Args:
            ckpt_path: Checkpoint file to load. Defaults to
                ``<ROOT_PATH>/weight/model_ckpt_steps_400000.ckpt``.
            device: Torch device for the network and inputs (default ``'cuda'``).
        """
        set_hparams()
        self.device = device
        if ckpt_path is None:
            ckpt_path = os.path.join(ROOT_PATH, 'weight/model_ckpt_steps_400000.ckpt')
        self.model_init(ckpt_path)

    def get_img_data(self, img_PIL, hparams, sr_scale=4):
        """Convert a PIL image into the normalized LR tensors the model samples from.

        Args:
            img_PIL: Input image; it is converted to RGB first.
            hparams: Hyper-parameter mapping; ``hparams['sr_scale']`` is the
                factor passed to ``imresize`` to build the upsampled LR image.
            sr_scale: Scale used by the crop arithmetic below.

        Returns:
            ``(img_lr, img_lr_up)``: float tensors with a leading batch dimension
            of 1. ``img_lr`` is ``(1, 3, h, w)``; ``img_lr_up`` is the
            ``imresize`` output (presumably ``hparams['sr_scale']`` times larger
            spatially — confirm against ``imresize``). Both are mapped through
            ``Normalize(0.5, 0.5)``, i.e. ``x -> 2x - 1``.

        Raises:
            ValueError: If the image is too small to survive cropping
                (fewer than 2 pixels in height or width).
        """
        img_lr = np.uint8(np.asarray(img_PIL.convert('RGB')))
        h, w = img_lr.shape[:2]

        # Crop the virtual HR size down to a multiple of 2 * sr_scale, then map
        # back to LR. For integer sr_scale >= 1 the net effect is cropping the
        # LR image to even height and width.
        h_hr = h * sr_scale - (h * sr_scale) % (sr_scale * 2)
        w_hr = w * sr_scale - (w * sr_scale) % (sr_scale * 2)
        h_l, w_l = h_hr // sr_scale, w_hr // sr_scale
        if h_l == 0 or w_l == 0:
            raise ValueError(
                f'image of size {w}x{h} is too small; need at least 2x2 pixels'
            )
        img_lr = img_lr[:h_l, :w_l]

        # ToTensor maps HWC -> CHW (scaling uint8 by 1/255); Normalize then maps
        # [0, 1] -> [-1, 1].
        to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # NOTE(review): scaled by 1/256 rather than 1/255; kept as-is since the
        # weights presumably expect it — confirm against the training pipeline.
        img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
        img_lr, img_lr_up = (
            to_tensor_norm(x).float().unsqueeze(0) for x in (img_lr, img_lr_up)
        )
        return img_lr, img_lr_up

    def load_checkpoint(self, ckpt_path):
        """Load weights from ``ckpt_path`` into ``self.model.model``.

        Reads ``checkpoint['state_dict']['model']``, strips a leading
        ``module.`` from every key (presumably added by a DataParallel/DDP
        wrapper at save time), loads the result, and moves the network to
        ``self.device``.
        """
        # NOTE(review): depending on the torch version, torch.load may unpickle
        # arbitrary objects — only load trusted checkpoint files.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        print(f'loding check from: {ckpt_path}')
        state_dict = checkpoint['state_dict']['model']
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        self.model.model.load_state_dict(new_state_dict)
        self.model.model.to(self.device)
        # Drop our reference to the loaded checkpoint and release cached GPU memory.
        del checkpoint
        torch.cuda.empty_cache()

    def model_init(self, ckpt_path):
        """Build the trainer's network and load weights from ``ckpt_path``."""
        self.model = trainer_ori()
        self.model.build_model()
        self.load_checkpoint(ckpt_path)
        # Process-wide setting: disables cuDNN autotuning.
        torch.backends.cudnn.benchmark = False

    def infer(self, img_PIL):
        """Super-resolve ``img_PIL`` and return the result as a PIL image."""
        with torch.no_grad():
            self.model.model.eval()
            img_lr, img_lr_up = self.get_img_data(img_PIL, hparams, sr_scale=4)
            img_lr = img_lr.to(self.device)
            img_lr_up = img_lr_up.to(self.device)
            # `sample` returns a pair; only the first element is used as output.
            img_sr, _ = self.model.model.sample(img_lr, img_lr_up, img_lr_up.shape)
            img_sr = img_sr.clamp(-1, 1)
            img_sr = self.model.tensor2img(img_sr)[0]
            return Image.fromarray(img_sr)