import numpy as np
import os
import argparse
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.models import create_model
from basicsr.utils.options import parse
from skimage import img_as_ubyte
from torch.cuda.amp import autocast

def load_img(filepath):
    # OpenCV reads images as BGR; convert to RGB for the model.
    return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)

def save_img(filepath, img):
    # Convert back to BGR before writing with OpenCV.
    cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

def self_ensemble(x, model):
    # Geometric self-ensemble (x8 test-time augmentation): run the model on
    # every flip/rotation variant of the input, undo each transform on the
    # output, and average the eight results.
    def forward_transformed(x, hflip, vflip, rotate, model):
        if hflip:
            x = torch.flip(x, (-1,))  # flip the width axis (horizontal flip)
        if vflip:
            x = torch.flip(x, (-2,))  # flip the height axis (vertical flip)
        if rotate:
            x = torch.rot90(x, dims=(-2, -1))
        x = model(x)
        # Invert the transforms in reverse order.
        if rotate:
            x = torch.rot90(x, dims=(-2, -1), k=3)
        if vflip:
            x = torch.flip(x, (-2,))
        if hflip:
            x = torch.flip(x, (-1,))
        return x

    t = []
    for hflip in [False, True]:
        for vflip in [False, True]:
            for rot in [False, True]:
                t.append(forward_transformed(x, hflip, vflip, rot, model))
    t = torch.stack(t)
    return torch.mean(t, dim=0)
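
# Example (hypothetical tensor; `model_restoration` as loaded below): the
# eight-variant average costs 8x inference time for a small quality gain:
#   x = torch.rand(1, 3, 256, 256).cuda()
#   y = self_ensemble(x, model_restoration)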


gpu_list = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

opt = parse("RetinexFormer_FiveK.yml", is_train=False)
opt['dist'] = False
print(opt)

weights_path = "pretrained_weights/FiveK.pth"
model_restoration = create_model(opt).net_g
# BasicSR checkpoints store the network weights under the 'params' key.
checkpoint = torch.load(weights_path)
model_restoration.load_state_dict(checkpoint['params'])
print("===> Testing using weights:", weights_path)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()


def process_image(inp_path, model_restoration, out_dir, factor=4, use_self_ensemble=False):
    # Free any cached GPU memory left over from the previous image.
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

    img = np.float32(load_img(inp_path)) / 255.

    # Resize to a fixed height of 1024 px (preserving aspect ratio) to bound
    # memory use; note this also upscales images shorter than 1024 px.
    max_height = 1024
    aspect_ratio = img.shape[1] / img.shape[0]
    new_width = int(max_height * aspect_ratio)
    img = cv2.resize(img, (new_width, max_height), interpolation=cv2.INTER_AREA)

    # HWC float image -> NCHW tensor on the GPU.
    img = torch.from_numpy(img).permute(2, 0, 1)
    input_ = img.unsqueeze(0).cuda()

    # Reflect-pad height and width up to the next multiple of `factor` so the
    # network's downsampling stages divide the resolution evenly.
    b, c, h, w = input_.shape
    H = ((h + factor - 1) // factor) * factor
    W = ((w + factor - 1) // factor) * factor
    input_ = F.pad(input_, (0, W - w, 0, H - h), 'reflect')

    # Mixed-precision inference; no gradients are needed.
    with torch.no_grad(), autocast():
        if h < 3000 and w < 3000:
            if use_self_ensemble:
                restored = self_ensemble(input_, model_restoration)
            else:
                restored = model_restoration(input_)
        else:
            # Very large inputs: split the width into odd and even columns,
            # restore each half separately, and re-interleave the results to
            # halve peak memory use.
            input_1 = input_[:, :, :, 1::2]
            input_2 = input_[:, :, :, 0::2]
            if use_self_ensemble:
                restored_1 = self_ensemble(input_1, model_restoration)
                restored_2 = self_ensemble(input_2, model_restoration)
            else:
                restored_1 = model_restoration(input_1)
                restored_2 = model_restoration(input_2)
            restored = torch.zeros_like(input_)
            restored[:, :, :, 1::2] = restored_1
            restored[:, :, :, 0::2] = restored_2

    # Crop the padding, clamp to [0, 1], and convert back to an HWC array.
    restored = restored[:, :, :h, :w]
    restored = torch.clamp(restored, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

    # Save as PNG under the input's base name.
    out_path = os.path.join(out_dir, os.path.splitext(os.path.split(inp_path)[-1])[0] + '.png')
    save_img(out_path, img_as_ubyte(restored))
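

# A minimal CLI sketch using the `argparse` import above; the argument names
# and default paths here are assumptions for illustration only.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RetinexFormer inference')
    parser.add_argument('--input_dir', default='inputs', help='folder of input images (assumed)')
    parser.add_argument('--output_dir', default='results', help='folder for restored PNGs (assumed)')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    for fname in sorted(os.listdir(args.input_dir)):
        if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
            process_image(os.path.join(args.input_dir, fname), model_restoration, args.output_dir)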