# retouch/inference.py
# Source: uploaded by DuySota via huggingface_hub (commit defb4d0, verified)
import numpy as np
import os
import argparse
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.models import create_model
from basicsr.utils.options import parse
from skimage import img_as_ubyte
from torch.cuda.amp import autocast, GradScaler
def load_img(filepath):
    """Read an image from disk and return it as an RGB numpy array.

    Args:
        filepath: Path to the image file.

    Returns:
        The image as an H x W x 3 uint8 array in RGB channel order.

    Raises:
        FileNotFoundError: If the file is missing or cannot be decoded
            (cv2.imread returns None in that case instead of raising).
    """
    img = cv2.imread(filepath)
    if img is None:
        # cv2.imread silently returns None on failure; fail loudly instead
        # of letting cvtColor crash with a cryptic OpenCV assertion.
        raise FileNotFoundError(f"Could not read image: {filepath}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def save_img(filepath, img):
    """Write an RGB image array to *filepath* (OpenCV expects BGR on disk)."""
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filepath, bgr)
def self_ensemble(x, model):
    """Geometric self-ensemble: average the model output over the 8
    transforms generated by horizontal flip, vertical flip and 90-degree
    rotation, each undone after the forward pass.

    Args:
        x: Input tensor of shape (..., H, W).
        model: Callable mapping a tensor to a tensor of the same shape.

    Returns:
        The mean of the 8 inverse-transformed model outputs.
    """

    def _run(inp, do_h, do_v, do_r):
        # Apply the transforms, run the model, then invert them in
        # reverse order (rotation last applied, so undone first; k=3
        # rotations undo the single k=1 rotation).
        if do_h:
            inp = torch.flip(inp, (-2,))
        if do_v:
            inp = torch.flip(inp, (-1,))
        if do_r:
            inp = torch.rot90(inp, dims=(-2, -1))
        out = model(inp)
        if do_r:
            out = torch.rot90(out, dims=(-2, -1), k=3)
        if do_v:
            out = torch.flip(out, (-1,))
        if do_h:
            out = torch.flip(out, (-2,))
        return out

    outputs = []
    # Decode the 8 flip/rotation combinations from a 3-bit mask; the
    # bit layout reproduces the original hflip/vflip/rotate loop order.
    for mask in range(8):
        outputs.append(_run(x, bool(mask & 4), bool(mask & 2), bool(mask & 1)))
    return torch.stack(outputs).mean(dim=0)
# Set GPU
# NOTE(review): ','.join over the string '0' iterates its characters, so this
# is just '0' — written this way presumably to ease extension to multi-GPU.
gpu_list = ','.join(str(x) for x in '0')
# Must be set before any CUDA context is created for it to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
# Load YAML configuration
opt = parse("RetinexFormer_FiveK.yml", is_train=False)
opt['dist'] = False  # single-process inference, no distributed training
print(opt)
# Load model
# create_model builds the full basicsr model wrapper; only the generator
# network (net_g) is needed for inference.
model_restoration = create_model(opt).net_g
checkpoint = torch.load("pretrained_weights/FiveK.pth")
# Weights must be loaded BEFORE wrapping in DataParallel (which prefixes
# parameter names with 'module.').
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ", "pretrained_weights/FiveK.pth")
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()
def process_image(inp_path, model_restoration, out_dir, factor=4,
                  use_self_ensemble=False):
    """Restore a single image with the given model and save it as a PNG.

    The image is resized to a fixed 1024-px height (aspect ratio kept),
    reflect-padded so both spatial dims are multiples of ``factor``, run
    through the model under mixed precision, unpadded and written to
    ``out_dir`` with the same basename and a ``.png`` extension.

    Args:
        inp_path: Path to the input image file.
        model_restoration: Callable (the restoration network) mapping a
            (1, 3, H, W) float tensor on GPU to a tensor of the same shape.
        out_dir: Directory the restored PNG is written into.
        factor: Spatial-size multiple the network requires (default 4).
        use_self_ensemble: If True, average over the 8 flip/rotation
            geometric self-ensemble transforms (slower, slightly better).
    """
    # Free cached CUDA memory left over from previous images.
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

    img = np.float32(load_img(inp_path)) / 255.

    # Resize image to have height 1024px while maintaining aspect ratio.
    max_height = 1024
    aspect_ratio = img.shape[1] / img.shape[0]
    new_width = int(max_height * aspect_ratio)
    img = cv2.resize(img, (new_width, max_height), interpolation=cv2.INTER_AREA)

    input_ = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).cuda()

    # Reflect-pad so H and W are multiples of `factor` (the network
    # downsamples by that factor internally).
    b, c, h, w = input_.shape
    padh = (factor - h % factor) % factor
    padw = (factor - w % factor) % factor
    input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

    if use_self_ensemble:
        run = lambda t: self_ensemble(t, model_restoration)
    else:
        run = model_restoration

    # Inference only: no_grad avoids tracking gradients (the original
    # instantiated an unused GradScaler — that is training-only machinery).
    with torch.no_grad():
        with autocast():  # mixed precision
            if h < 3000 and w < 3000:
                restored = run(input_)
            else:
                # Very large inputs: process even/odd columns separately
                # to roughly halve peak memory, then interleave results.
                restored = torch.zeros_like(input_)
                restored[:, :, :, 1::2] = run(input_[:, :, :, 1::2])
                restored[:, :, :, 0::2] = run(input_[:, :, :, 0::2])

    # Unpad back to the (resized) input dimensions and convert to HWC numpy.
    restored = restored[:, :, :h, :w]
    restored = torch.clamp(restored, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

    out_name = os.path.splitext(os.path.split(inp_path)[-1])[0] + '.png'
    save_img(os.path.join(out_dir, out_name), img_as_ubyte(restored))