import argparse
import os
import sys

import cv2
import numpy as np
import torch

import utils
from utils import convert_state_dict
from models import restormer_arch
from data.preprocess.crop_merge_image import stride_integral

# Make ./data/MBD/ importable before loading data.MBD.infer (kept from the
# original; presumably the MBD code imports its own modules by bare name -
# TODO confirm).
sys.path.append('./data/MBD/')
from data.MBD.infer import net1_net2_infer_single_im  # noqa: E402

# Device used by every inference helper. Defined at import time so the task
# functions also work when this module is imported instead of run as a script.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint of the MBD mask network used by dewarp_prompt().
MBD_CHECKPOINT = 'data/MBD/checkpoint/mbd.pkl'


# ---------------------------------------------------------------------------
# Small I/O and tensor helpers
# ---------------------------------------------------------------------------

def _read_image(im_path):
    """Read an image with OpenCV, raising instead of silently returning None."""
    im = cv2.imread(im_path)
    if im is None:
        raise FileNotFoundError('Could not read image: ' + str(im_path))
    return im


def _write_image(path, im):
    """Write an image with OpenCV, raising if OpenCV reports a failure."""
    if not cv2.imwrite(path, im):
        raise IOError('Could not write image: ' + str(path))


def _inference_dtype():
    """Return float16 on CUDA and float32 on CPU, where half kernels may be missing or slow."""
    return torch.float16 if DEVICE.type == 'cuda' else torch.float32


def _to_batch_tensor(arr, dtype):
    """Convert an HxWxC numpy array into a 1xCxHxW tensor on DEVICE with the given dtype."""
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device=DEVICE, dtype=dtype)


def _to_uint8_image(pred):
    """Clamp a 1xCxHxW network output to [0, 1] and return it as an HxWxC uint8 array."""
    pred = torch.clamp(pred, 0, 1)
    pred = pred[0].permute(1, 2, 0).cpu().numpy()
    return (pred * 255).astype(np.uint8)


# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------

def _background_planes(img):
    """Resize img to 1024x1024 and estimate a background for each channel.

    Returns a list of (plane, background) pairs, one per channel. The
    background is a 7x7 dilation of the plane followed by a 21x21 median blur.
    """
    img = cv2.resize(img, (1024, 1024))
    pairs = []
    for plane in cv2.split(img):
        dilated = cv2.dilate(plane, np.ones((7, 7), np.uint8))
        pairs.append((plane, cv2.medianBlur(dilated, 21)))
    return pairs


def _gradient_gray(img):
    """Return the grayscale average of |Sobel_x| and |Sobel_y| of a BGR image (uint8)."""
    abs_x = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))  # back to uint8
    abs_y = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    high_frequency = cv2.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
    return cv2.cvtColor(high_frequency, cv2.COLOR_BGR2GRAY)


def dewarp_prompt(img):
    """Build the dewarping prompt for a BGR image.

    Runs the MBD network to get a document mask, zeroes pixels where the mask
    is 0 (this MUTATES ``img`` in place; assumes the mask has img's height and
    width), and builds a 256x256 prompt by concatenating the normalized base
    coordinates from ``utils.getBasecoord`` with the resized mask scaled by
    1/255 along the last axis.

    Returns:
        (masked img, prompt array)
    """
    mask = net1_net2_infer_single_im(img, MBD_CHECKPOINT)
    base_coord = utils.getBasecoord(256, 256) / 256
    img[mask == 0] = 0
    mask = cv2.resize(mask, (256, 256)) / 255
    return img, np.concatenate((base_coord, np.expand_dims(mask, -1)), -1)


def deshadow_prompt(img):
    """Return the per-channel background estimate of img, resized back to img's size.

    The background is computed at 1024x1024 by _background_planes().
    """
    h, w = img.shape[:2]
    background = cv2.merge([bg for _, bg in _background_planes(img)])
    return cv2.resize(background, (w, h))


def deblur_prompt(img):
    """Return the Sobel gradient magnitude of img as a 3-channel (gray-replicated) image."""
    return cv2.cvtColor(_gradient_gray(img), cv2.COLOR_GRAY2BGR)


def appearance_prompt(img):
    """Return a background-normalized version of img, resized back to img's size.

    For each channel (at 1024x1024), the plane's absolute difference from its
    estimated background is inverted (255 - diff) and min-max normalized to
    [0, 255].
    """
    h, w = img.shape[:2]
    norm_planes = []
    for plane, bg in _background_planes(img):
        diff = 255 - cv2.absdiff(plane, bg)
        norm_planes.append(cv2.normalize(diff, None, alpha=0, beta=255,
                                         norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1))
    return cv2.resize(cv2.merge(norm_planes), (w, h))


def binarization_promptv2(img):
    """Build the 3-channel binarization prompt: [Sauvola threshold, gradient, hard result].

    The Sauvola result from ``utils.SauvolaModBinarization`` is hard-thresholded
    at 155 to {0, 255}.
    """
    result, thresh = utils.SauvolaModBinarization(img)
    thresh = thresh.astype(np.uint8)
    result[result > 155] = 255
    result[result <= 155] = 0
    high_frequency = _gradient_gray(img)
    return np.stack((thresh, high_frequency, result), -1)


# ---------------------------------------------------------------------------
# Task runners. Each returns (prompt_ch0, prompt_ch1, prompt_ch2, output_image).
# ---------------------------------------------------------------------------

def dewarping(model, im_path):
    """Dewarp the document image at im_path.

    The model predicts a 2-channel offset field at 256x256. This is added to
    the base coordinates, smoothed, upscaled to the input size, and used as a
    cv2.remap map on the original image.
    """
    INPUT_SIZE = 256
    im_org = _read_image(im_path)
    im_masked, prompt_org = dewarp_prompt(im_org.copy())
    h, w = im_masked.shape[:2]
    im_masked = cv2.resize(im_masked, (INPUT_SIZE, INPUT_SIZE)) / 255.0
    im_masked = _to_batch_tensor(im_masked, torch.float32)
    prompt = _to_batch_tensor(prompt_org, torch.float32)
    in_im = torch.cat((im_masked, prompt), dim=1)

    # inference
    base_coord = utils.getBasecoord(INPUT_SIZE, INPUT_SIZE) / INPUT_SIZE
    model = model.float()
    with torch.no_grad():
        pred = model(in_im)
        pred = pred[0][:2].permute(1, 2, 0).cpu().numpy()
        pred = pred + base_coord

    # Smooth the predicted map before upscaling it to a full-resolution remap grid.
    for _ in range(15):
        pred = cv2.blur(pred, (3, 3), borderType=cv2.BORDER_REPLICATE)
    pred = cv2.resize(pred, (w, h)) * (w, h)
    pred = pred.astype(np.float32)
    out_im = cv2.remap(im_org, pred[:, :, 0], pred[:, :, 1], cv2.INTER_LINEAR)

    prompt_org = (prompt_org * 255).astype(np.uint8)
    prompt_org = cv2.resize(prompt_org, im_org.shape[:2][::-1])
    return prompt_org[:, :, 0], prompt_org[:, :, 1], prompt_org[:, :, 2], out_im


def _illumination_restore(model, im_path, prompt_fn, max_size=1600):
    """Shared body of appearance() and deshadowing(); they differ only in prompt_fn.

    Images whose longer side is below max_size are padded with
    stride_integral() and processed at full size. Larger images are processed
    as a max_size x max_size resized copy. The ratio (resized input /
    prediction) is then upsampled, and the full-resolution original is
    divided by it.
    """
    im_org = _read_image(im_path)
    h, w = im_org.shape[:2]
    prompt = prompt_fn(im_org)
    in_im = np.concatenate((im_org, prompt), -1)

    # constrain the max resolution
    small = max(w, h) < max_size
    if small:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        in_im = cv2.resize(in_im, (max_size, max_size))

    # normalize + inference
    dtype = _inference_dtype()
    in_im = _to_batch_tensor(in_im / 255.0, dtype)
    model = model.to(dtype=dtype)
    with torch.no_grad():
        pred = model(in_im)
    pred = _to_uint8_image(pred)

    if small:
        out_im = pred[padding_h:, padding_w:]
    else:
        pred[pred == 0] = 1  # avoid division by zero
        shadow_map = cv2.resize(im_org, (max_size, max_size)).astype(float) / pred.astype(float)
        shadow_map = cv2.resize(shadow_map, (w, h))
        shadow_map[shadow_map == 0] = 0.00001
        out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def appearance(model, im_path):
    """Run appearance enhancement on the image at im_path (prompt: appearance_prompt)."""
    return _illumination_restore(model, im_path, appearance_prompt)


def deshadowing(model, im_path):
    """Run shadow removal on the image at im_path (prompt: deshadow_prompt)."""
    return _illumination_restore(model, im_path, deshadow_prompt)


def deblurring(model, im_path):
    """Run deblurring on the image at im_path (prompt: Sobel gradient magnitude)."""
    im_org = _read_image(im_path)
    in_im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = deblur_prompt(in_im)
    in_im = np.concatenate((in_im, prompt), -1)

    dtype = _inference_dtype()
    in_im = _to_batch_tensor(in_im / 255.0, dtype)
    model = model.to(device=DEVICE, dtype=dtype)
    model.eval()
    with torch.no_grad():
        pred = model(in_im)
    out_im = _to_uint8_image(pred)[padding_h:, padding_w:]
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def binarization(model, im_path):
    """Binarize the image at im_path; the output is a single-channel {0, 255} image."""
    im_org = _read_image(im_path)
    im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = binarization_promptv2(im)
    h, w = im.shape[:2]
    in_im = np.concatenate((im, prompt), -1)

    dtype = _inference_dtype()
    in_im = _to_batch_tensor(in_im / 255.0, dtype)
    model = model.to(dtype=dtype)
    with torch.no_grad():
        pred = model(in_im)
        # Only the first two output channels are used, as two-class scores.
        pred = pred[:, :2, :, :]
        pred = torch.max(torch.softmax(pred, 1), 1)[1]
    pred = pred[0].cpu().numpy()
    pred = (pred * 255).astype(np.uint8)
    pred = cv2.resize(pred, (w, h))
    out_im = pred[padding_h:, padding_w:]
    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


# ---------------------------------------------------------------------------
# CLI, model construction and dispatch
# ---------------------------------------------------------------------------

def get_args():
    """Parse command-line arguments; exits with a usage error on an unsupported task."""
    possible_tasks = ['dewarping', 'deshadowing', 'appearance', 'deblurring', 'binarization', 'end2end']
    parser = argparse.ArgumentParser(description='Params')
    parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',
                        help='Path of the saved checkpoint')
    parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/',
                        help='Path of input document image')
    parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/',
                        help='Folder of the output images')
    parser.add_argument('--task', nargs='?', type=str, default='dewarping', choices=possible_tasks,
                        help='task that need to be executed')
    parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0,
                        help='If non-zero, also save the three prompt images next to the output')
    args = parser.parse_args()
    # `--task` given without a value yields None, which `choices` does not check.
    if args.task not in possible_tasks:
        parser.error('Unsupported task, task must be one of ' + ', '.join(possible_tasks))
    return args


def model_init(args):
    """Build the Restormer model, load weights from args.model_path and move it to DEVICE."""
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True,
    )
    map_location = 'cpu' if DEVICE.type == 'cpu' else 'cuda:0'
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.model_path, map_location=map_location)
    model.load_state_dict(convert_state_dict(checkpoint['model_state']))
    model.eval()
    return model.to(DEVICE)


def inference_one_im(model, im_path, task, work_folder='restorted'):
    """Run `task` on the image at im_path.

    Returns (prompt1, prompt2, prompt3, restored image). 'end2end' chains
    dewarping -> deshadowing -> appearance, writing the JPEG intermediates
    step1.jpg / step2.jpg into work_folder (created if missing), and returns
    the prompts of the final appearance step.

    Raises:
        ValueError: if task is not supported.
    """
    single_step = {
        'dewarping': dewarping,
        'deshadowing': deshadowing,
        'appearance': appearance,
        'deblurring': deblurring,
        'binarization': binarization,
    }
    if task in single_step:
        return single_step[task](model, im_path)
    if task == 'end2end':
        os.makedirs(work_folder, exist_ok=True)
        step1_path = os.path.join(work_folder, 'step1.jpg')
        step2_path = os.path.join(work_folder, 'step2.jpg')
        _, _, _, restorted = dewarping(model, im_path)
        _write_image(step1_path, restorted)
        _, _, _, restorted = deshadowing(model, step1_path)
        _write_image(step2_path, restorted)
        return appearance(model, step2_path)
    raise ValueError('Unsupported task: ' + str(task))


if __name__ == '__main__':
    ## model init
    args = get_args()
    model = model_init(args)
    os.makedirs(args.out_folder, exist_ok=True)

    ## inference
    prompt1, prompt2, prompt3, restorted = inference_one_im(
        model, args.im_path, args.task, work_folder=args.out_folder)

    ## results saving: <out_folder>/<stem>_<task><ext> (+ _prompt{1,2,3})
    stem, ext = os.path.splitext(os.path.basename(args.im_path))
    save_base = os.path.join(args.out_folder, stem + '_' + args.task)
    _write_image(save_base + ext, restorted)
    if args.save_dtsprompt:
        for idx, prompt_im in enumerate((prompt1, prompt2, prompt3), start=1):
            _write_image(save_base + '_prompt' + str(idx) + ext, prompt_im)