Spaces:
Runtime error
Runtime error
import os | |
from collections import OrderedDict | |
import numpy as np | |
import torch | |
from PIL import Image | |
from torchvision.transforms import transforms | |
from sam_diffsr.utils_sr.hparams import set_hparams, hparams | |
from sam_diffsr.utils_sr.matlab_resize import imresize | |
from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam as trainer_ori | |
# Absolute directory containing this file. abspath() keeps the path valid even
# when __file__ is relative and the working directory changes later.
ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
class sam_diffsr_demo:
    """Single-image inference wrapper around the ``SRDiffDf2k_sam`` trainer.

    Constructing an instance calls ``set_hparams()``, builds the model and loads
    the checkpoint stored under ``weight/`` next to this file. ``infer`` then
    runs the model's sampler on one PIL image and returns a PIL image.
    """

    def __init__(self):
        set_hparams()
        ckpt_path = os.path.join(ROOT_PATH, 'weight/model_ckpt_steps_400000.ckpt')
        self.model_init(ckpt_path)

    def get_img_data(self, img_PIL, hparams, sr_scale=4):
        """Convert a PIL image into the two tensors handed to the sampler.

        Args:
            img_PIL: Input image; it is converted to RGB first.
            hparams: Mapping that provides ``'sr_scale'``, the factor passed to
                ``imresize``.
            sr_scale: Factor used for the crop arithmetic.
                NOTE(review): the crop uses this argument while the resize uses
                ``hparams['sr_scale']``; the two are presumably expected to be
                equal -- confirm against the training configuration.

        Returns:
            Tuple ``(img_lr, img_lr_up)``: the cropped input and its resized
            copy, both float tensors normalized with mean 0.5 / std 0.5 per
            channel and given a leading batch dimension of size 1.

        Raises:
            ValueError: If the crop would leave an empty image.
        """
        rgb = np.uint8(np.asarray(img_PIL.convert('RGB')))
        height, width = rgb.shape[:2]

        # Round the upscaled size down to a multiple of 2 * sr_scale, then map
        # it back to the input grid (this trims the input to even dimensions).
        multiple = sr_scale * 2
        up_h = height * sr_scale
        up_w = width * sr_scale
        up_h -= up_h % multiple
        up_w -= up_w % multiple
        lr_h = up_h // sr_scale
        lr_w = up_w // sr_scale
        if lr_h == 0 or lr_w == 0:
            raise ValueError(
                f'input image of size {width}x{height} is too small: '
                'cropping to even dimensions leaves no pixels'
            )
        rgb = rgb[:lr_h, :lr_w]

        to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # NOTE(review): the resized branch is scaled by 1/256 whereas ToTensor
        # scales the uint8 branch by 1/255; kept as-is because it presumably
        # mirrors the preprocessing used during training -- confirm.
        rgb_up = imresize(rgb / 256, hparams['sr_scale'])  # np.float [H, W, C]

        img_lr = to_tensor_norm(rgb).float().unsqueeze(0)
        img_lr_up = to_tensor_norm(rgb_up).float().unsqueeze(0)
        return img_lr, img_lr_up

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return an ordered copy of *state_dict* without leading ``module.`` in keys."""
        prefix = 'module.'
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            if key.startswith(prefix):
                key = key[len(prefix):]  # strip the `module.` prefix
            cleaned[key] = value
        return cleaned

    def load_checkpoint(self, ckpt_path):
        """Load weights from *ckpt_path* into ``self.model.model``.

        The file is expected to hold the weights under
        ``checkpoint['state_dict']['model']``. After loading, the model is
        moved to CUDA when available and to the CPU otherwise; the chosen
        device is stored in ``self.device`` for use by ``infer``.

        Args:
            ckpt_path: Path of the checkpoint file.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        print(f'loding check from: {ckpt_path}')
        weights = self._strip_module_prefix(checkpoint['state_dict']['model'])
        self.model.model.load_state_dict(weights)

        # Fall back to the CPU instead of failing on hosts without a GPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.model.to(self.device)

        # Drop the CPU copy of the checkpoint and release cached GPU blocks.
        del checkpoint
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()

    def model_init(self, ckpt_path):
        """Build the trainer's model and load the weights from *ckpt_path*."""
        self.model = trainer_ori()
        self.model.build_model()
        self.load_checkpoint(ckpt_path)
        torch.backends.cudnn.benchmark = False

    def infer(self, img_PIL):
        """Run the sampler on *img_PIL* and return the result as a PIL image.

        Args:
            img_PIL: Input PIL image.

        Returns:
            ``PIL.Image.Image`` built from the first image returned by the
            trainer's ``tensor2img``.
        """
        net = self.model.model
        with torch.no_grad():
            net.eval()
            img_lr, img_lr_up = self.get_img_data(img_PIL, hparams, sr_scale=4)
            img_lr = img_lr.to(self.device)
            img_lr_up = img_lr_up.to(self.device)
            img_sr, _ = net.sample(img_lr, img_lr_up, img_lr_up.shape)
            img_sr = img_sr.clamp(-1, 1)
            img_sr = self.model.tensor2img(img_sr)[0]
        return Image.fromarray(img_sr)