Spaces:
Running
on
Zero
Running
on
Zero
import sys | |
import cv2 | |
import utils | |
import numpy as np | |
import torch | |
from PIL import Image | |
from utils import convert_state_dict | |
from models import restormer_arch | |
from data.preprocess.crop_merge_image import stride_integral | |
sys.path.append("./data/MBD/") | |
from data.MBD.infer import net1_net2_infer_single_im | |
def dewarp_prompt(img):
    """Build the masked image and the 3-channel prompt consumed by ``dewarping``.

    The MBD network (checkpoint ``data/MBD/checkpoint/mbd.pkl``) predicts a
    document mask for *img*. Every pixel of *img* where that mask is 0 is set to
    0 **in place** (``dewarping`` passes a copy). The prompt is the base
    coordinate grid from ``utils.getBasecoord(256, 256) / 256`` with the mask,
    resized to 256x256 and scaled by 1/255, appended as the last channel.

    Args:
        img: Input image array; mutated in place. Presumably BGR uint8 —
            TODO confirm.

    Returns:
        ``(img, prompt)``: the masked image and the stacked prompt array.
    """
    doc_mask = net1_net2_infer_single_im(img, "data/MBD/checkpoint/mbd.pkl")
    img[doc_mask == 0] = 0

    coords = utils.getBasecoord(256, 256) / 256
    # NOTE(review): dividing by 255 assumes the MBD mask is 0/255 — confirm.
    mask_small = cv2.resize(doc_mask, (256, 256)) / 255
    prompt = np.concatenate((coords, mask_small[..., np.newaxis]), -1)
    return img, prompt
def deshadow_prompt(img):
    """Estimate the per-channel background image used as the deshadowing prompt.

    The image is resized to a 1024x1024 working copy. Each channel is dilated
    with a 7x7 kernel (a local max filter, which suppresses dark features
    narrower than the kernel) and then median-filtered with a 21-pixel
    aperture. The merged result is resized back to the input size.

    Only the background estimate is returned. The previous implementation also
    built a normalised shadow map but discarded it. That dead computation was
    removed because it cost a full-resolution division and two colour
    conversions per call.

    Args:
        img: Input image array; not modified. Presumably BGR uint8
            (``cv2.medianBlur`` with a 21 aperture requires 8-bit data).

    Returns:
        The background estimate, same height/width as *img*.
    """
    h, w = img.shape[:2]
    work = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)  # loop-invariant; allocate once
    bg_planes = [
        cv2.medianBlur(cv2.dilate(plane, kernel), 21)
        for plane in cv2.split(work)
    ]
    return cv2.resize(cv2.merge(bg_planes), (w, h))
def deblur_prompt(img):
    """Return a grey edge-magnitude map, replicated to 3 channels, as the deblurring prompt.

    The x and y Sobel derivatives are computed at 16-bit signed depth and
    converted back to uint8 absolute values. They are averaged 50/50 and
    collapsed to grey via BGR->GRAY, then expanded back to 3 identical channels.
    """
    grad_x = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))  # back to uint8
    grad_y = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    edges = cv2.addWeighted(grad_x, 0.5, grad_y, 0.5, 0)
    grey = cv2.cvtColor(edges, cv2.COLOR_BGR2GRAY)
    return cv2.cvtColor(grey, cv2.COLOR_GRAY2BGR)
def appearance_prompt(img):
    """Build the appearance-enhancement prompt: a min-max normalised background difference.

    The image is resized to a 1024x1024 working copy. For each channel, a
    background is estimated by a 7x7 dilation followed by a 21-pixel median
    blur. The value ``255 - |plane - background|`` is min-max normalised to
    0..255 (uint8). The merged planes are resized back to the input size.

    Args:
        img: Input image array; not modified. Presumably BGR uint8 — TODO confirm.

    Returns:
        The normalised difference image, same height/width as *img*.
    """
    h, w = img.shape[:2]
    work = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)  # loop-invariant; allocate once
    norm_planes = []
    for plane in cv2.split(work):
        background = cv2.medianBlur(cv2.dilate(plane, kernel), 21)
        diff = 255 - cv2.absdiff(plane, background)
        norm_planes.append(
            cv2.normalize(
                diff,
                None,
                alpha=0,
                beta=255,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_8UC1,
            )
        )
    return cv2.resize(cv2.merge(norm_planes), (w, h))
def binarization_promptv2(img):
    """Build the 3-channel binarization prompt.

    Channels, in order:
      0. the Sauvola threshold map from ``utils.SauvolaModBinarization``, cast to uint8;
      1. a grey Sobel edge-magnitude map (x/y derivatives averaged 50/50);
      2. the Sauvola result hard-thresholded at 155 to {0, 255}.

    Args:
        img: Input image; presumably BGR uint8 — TODO confirm.

    Returns:
        An array of shape ``img.shape[:2] + (3,)``.
    """
    result, thresh = utils.SauvolaModBinarization(img)
    thresh = thresh.astype(np.uint8)
    # Two in-place passes: values above 155 become 255, then the rest become 0.
    result[result > 155] = 255
    result[result <= 155] = 0

    grad_x = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))  # back to uint8
    grad_y = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    edges = cv2.cvtColor(
        cv2.addWeighted(grad_x, 0.5, grad_y, 0.5, 0), cv2.COLOR_BGR2GRAY
    )
    return np.stack((thresh, edges, result), axis=-1)
def dewarping(model, im_org, device):
    """Rectify *im_org* using a backward map predicted by *model*.

    1. ``dewarp_prompt`` masks a copy of the image and builds the
       coordinate+mask prompt.
    2. The masked image, at 256x256 and scaled to [0, 1], is concatenated with
       the prompt along channels and run through *model* in float32.
    3. The first two output channels plus the base coordinate grid form a
       normalised sampling map. The map is box-blurred 15 times (3x3), upscaled
       to the image size, scaled to pixel units and used with ``cv2.remap``.

    Note: ``model.float()`` converts the module in place.

    Returns:
        ``(p0, p1, p2, out_im)``: the three prompt channels as uint8 at
        *im_org*'s resolution, and the rectified image.
    """
    input_size = 256
    im_masked, prompt_org = dewarp_prompt(im_org.copy())
    h, w = im_masked.shape[:2]

    def _to_batch(arr):
        # HWC numpy array -> 1xCxHxW float32 tensor on the target device.
        return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).float().to(device)

    image_t = _to_batch(cv2.resize(im_masked, (input_size, input_size)) / 255.0)
    prompt_t = _to_batch(prompt_org)
    in_im = torch.cat((image_t, prompt_t), dim=1)

    base_coord = utils.getBasecoord(input_size, input_size) / input_size
    model = model.float()
    with torch.no_grad():
        flow = model(in_im)[0][:2].permute(1, 2, 0).cpu().numpy() + base_coord

    # Smooth the predicted map before upsampling it to full resolution.
    for _ in range(15):
        flow = cv2.blur(flow, (3, 3), borderType=cv2.BORDER_REPLICATE)
    flow = (cv2.resize(flow, (w, h)) * (w, h)).astype(np.float32)
    out_im = cv2.remap(im_org, flow[:, :, 0], flow[:, :, 1], cv2.INTER_LINEAR)

    prompt_u8 = cv2.resize(
        (prompt_org * 255).astype(np.uint8), im_org.shape[:2][::-1]
    )
    return prompt_u8[:, :, 0], prompt_u8[:, :, 1], prompt_u8[:, :, 2], out_im
def _prompted_restoration(model, im_org, device, prompt_fn, max_size=1600):
    """Shared pipeline for the prompt-conditioned ``appearance`` / ``deshadowing`` tasks.

    The prompt from ``prompt_fn(im_org)`` is concatenated with the image along
    channels. How the input is prepared depends on its size:

    * Longer side < *max_size*: the input is padded with
      ``stride_integral(..., 8)`` (presumably to a multiple of 8 — TODO
      confirm). The prediction is cropped by the returned top/left padding
      offsets.
    * Otherwise: the input is squashed to ``max_size x max_size``. A per-pixel
      ratio map (resized original / prediction) is upsampled back to the
      original size, and the original image is divided by it. Both divisions
      are guarded against zero.

    Inference runs in half precision. ``model.half()`` converts the module in
    place.

    Returns:
        ``(prompt, out_im)``.
    """
    h, w = im_org.shape[:2]
    prompt = prompt_fn(im_org)
    in_im = np.concatenate((im_org, prompt), -1)

    # Constrain the maximum resolution fed to the network.
    fits = max(w, h) < max_size
    if fits:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        in_im = cv2.resize(in_im, (max_size, max_size))

    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
    in_im = in_im.half().to(device)
    model = model.half()
    with torch.no_grad():
        pred = model(in_im)
        pred = torch.clamp(pred, 0, 1)
        pred = pred[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)

    if fits:
        return prompt, pred[padding_h:, padding_w:]

    # Transfer the low-resolution correction to full resolution as a ratio map.
    pred[pred == 0] = 1
    shadow_map = cv2.resize(im_org, (max_size, max_size)).astype(float) / pred.astype(
        float
    )
    shadow_map = cv2.resize(shadow_map, (w, h))
    shadow_map[shadow_map == 0] = 0.00001
    out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)
    return prompt, out_im


def appearance(model, im_org, device):
    """Run appearance enhancement on *im_org* using ``appearance_prompt``.

    Returns:
        ``(p0, p1, p2, out_im)``: the three prompt channels and the result.
    """
    prompt, out_im = _prompted_restoration(model, im_org, device, appearance_prompt)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def deshadowing(model, im_org, device):
    """Run shadow removal on *im_org* using ``deshadow_prompt``.

    Returns:
        ``(p0, p1, p2, out_im)``: the three prompt channels and the result.
    """
    prompt, out_im = _prompted_restoration(model, im_org, device, deshadow_prompt)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im
def deblurring(model, im_org, device):
    """Deblur *im_org* with *model*, conditioned on a Sobel edge prompt.

    The image is padded with ``stride_integral(..., 8)``. The prompt is
    computed on the padded image, and inference runs in half precision. The
    model is moved to *device*, set to eval mode and converted to half in place.
    The output is cropped by the returned top/left padding offsets.

    Returns:
        ``(p0, p1, p2, out_im)``: the three prompt channels (at padded size)
        and the deblurred image.
    """
    padded, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = deblur_prompt(padded)

    stacked = np.concatenate((padded, prompt), -1) / 255.0
    in_im = torch.from_numpy(stacked.transpose(2, 0, 1)).unsqueeze(0).half().to(device)

    model.to(device)
    model.eval()
    model = model.half()
    with torch.no_grad():
        restored = torch.clamp(model(in_im), 0, 1)
        restored = (restored[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    return (
        prompt[:, :, 0],
        prompt[:, :, 1],
        prompt[:, :, 2],
        restored[padding_h:, padding_w:],
    )
def binarization(model, im_org, device):
    """Binarize *im_org* into a 0/255 uint8 mask.

    The image is padded with ``stride_integral(..., 8)``, concatenated with
    ``binarization_promptv2`` and run through *model* in half precision. The
    model is converted in place. Only the first two output channels are used,
    and the per-pixel class index (via softmax + max) is scaled to 0/255.

    Returns:
        ``(p0, p1, p2, out_im)``: the three prompt channels (at padded size)
        and the cropped binary mask.
    """
    im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = binarization_promptv2(im)
    h, w = im.shape[:2]

    stacked = np.concatenate((im, prompt), -1) / 255.0
    in_im = torch.from_numpy(stacked.transpose(2, 0, 1)).unsqueeze(0).to(device)
    model = model.half()
    in_im = in_im.half()

    with torch.no_grad():
        logits = model(in_im)[:, :2, :, :]
        labels = torch.max(torch.softmax(logits, 1), 1)[1]
        mask = (labels[0].cpu().numpy() * 255).astype(np.uint8)
    mask = cv2.resize(mask, (w, h))

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], mask[padding_h:, padding_w:]
def model_init(model_path, device):
    """Build the 6-in / 3-out Restormer backbone and load weights from *model_path*.

    Args:
        model_path: Checkpoint file containing a ``"model_state"`` entry. That
            entry is passed through ``convert_state_dict`` before loading.
        device: Target device for ``Module.to``. May be a string (``"cpu"``,
            ``"cuda"``, ``"cuda:1"``) or a ``torch.device``.

    Returns:
        The model in eval mode, placed on *device*.
    """
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        dual_pixel_task=True,
    )
    # Always deserialize on the CPU. The old ``device == "cpu"`` check was False
    # for ``torch.device("cpu")``, which sent CPU-only hosts to the hard-coded
    # "cuda:0" mapping and crashed them. That mapping also ignored the requested
    # GPU. The final placement is done once, below, by ``.to(device)``.
    checkpoint = torch.load(model_path, map_location="cpu")
    state = convert_state_dict(checkpoint["model_state"])
    model.load_state_dict(state)
    model.eval()
    return model.to(device)
def resize(image, max_size):
    """Downscale *image* so its longer side is at most *max_size* pixels.

    Aspect ratio is preserved, and the shorter side is truncated to an int.
    Resampling uses Pillow's LANCZOS filter. Images already within the limit
    are returned unchanged (same object).

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (e.g. HxW or HxWx3 uint8).
        max_size: Maximum allowed length of the longer side.

    Returns:
        The (possibly) resized image as a numpy array.
    """
    h, w = image.shape[:2]
    if max(h, w) <= max_size:
        return image

    # Clamp the computed side to >= 1: for extreme aspect ratios the truncated
    # value would be 0, and Pillow rejects zero-sized targets.
    if h > w:
        h_new = max_size
        w_new = max(1, int(w * h_new / h))
    else:
        w_new = max_size
        h_new = max(1, int(h * w_new / w))

    pil_image = Image.fromarray(image)
    pil_image = pil_image.resize((w_new, h_new), Image.Resampling.LANCZOS)
    return np.array(pil_image)
def inference_one_image(model, image, tasks, device):
    """Apply the requested restoration tasks to a single image.

    Tasks always run in a fixed order, whatever their order in *tasks*:
    dewarping, deshadowing, appearance, deblurring, binarization. When
    ``"dewarping"`` is the only task, its result is returned at full
    resolution. Otherwise the image is first capped to 1536 px on its longer
    side with ``resize``.

    Args:
        model: The network shared by all tasks.
        image: Input image; should be in BGR format.
        tasks: Collection of task names, tested with ``in`` and ``len``.
        device: Device passed through to every task.

    Returns:
        The final image (the last element of each task's returned tuple).
    """
    if "dewarping" in tasks:
        image = dewarping(model, image, device)[-1]
        if len(tasks) == 1:
            # Only dewarping was requested: skip the resize.
            return image

    image = resize(image, 1536)
    ordered_stages = (
        ("deshadowing", deshadowing),
        ("appearance", appearance),
        ("deblurring", deblurring),
        ("binarization", binarization),
    )
    for stage_name, stage_fn in ordered_stages:
        if stage_name in tasks:
            image = stage_fn(model, image, device)[-1]
    return image