HaoFeng2019 commited on
Commit
ea9f27f
1 Parent(s): dee3268

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -3
app.py CHANGED
@@ -1,7 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
 
1
+ #origin
2
+
3
+ from seg import U2NETP
4
+ from GeoTr import GeoTr
5
+ from IllTr import IllTr
6
+ from inference_ill import rec_ill
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import skimage.io as io
12
+ import numpy as np
13
+ import cv2
14
+ #import glob
15
+ import os
16
+ from PIL import Image
17
+ #import argparse
18
+ import warnings
19
+ warnings.filterwarnings('ignore')
20
+
21
+
22
+
23
+
24
+
25
  import gradio as gr
26
 
 
 
27
 
28
+ class GeoTr_Seg(nn.Module):
29
+ def __init__(self):
30
+ super(GeoTr_Seg, self).__init__()
31
+ self.msk = U2NETP(3, 1)
32
+ self.GeoTr = GeoTr(num_attn_layers=6)
33
+
34
+ def forward(self, x):
35
+ msk, _1,_2,_3,_4,_5,_6 = self.msk(x)
36
+ msk = (msk > 0.5).float()
37
+ x = msk * x
38
+
39
+ bm = self.GeoTr(x)
40
+ bm = (2 * (bm / 286.8) - 1) * 0.99
41
+
42
+ return bm
43
+
44
+
45
+ def reload_model(model, path=""):
46
+ if not bool(path):
47
+ return model
48
+ else:
49
+ model_dict = model.state_dict()
50
+ pretrained_dict = torch.load(path, map_location='cuda:0')
51
+ print(len(pretrained_dict.keys()))
52
+ pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
53
+ print(len(pretrained_dict.keys()))
54
+ model_dict.update(pretrained_dict)
55
+ model.load_state_dict(model_dict)
56
+
57
+ return model
58
+
59
+
60
+ def reload_segmodel(model, path=""):
61
+ if not bool(path):
62
+ return model
63
+ else:
64
+ model_dict = model.state_dict()
65
+ pretrained_dict = torch.load(path, map_location='cuda:0')
66
+ print(len(pretrained_dict.keys()))
67
+ pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
68
+ print(len(pretrained_dict.keys()))
69
+ model_dict.update(pretrained_dict)
70
+ model.load_state_dict(model_dict)
71
+
72
+ return model
73
+
74
+
75
+ def rec(opt):
76
+ # print(torch.__version__) # 1.5.1
77
+ img_list = os.listdir(opt.distorrted_path) # distorted images list
78
+
79
+ if not os.path.exists(opt.gsave_path): # create save path
80
+ os.mkdir(opt.gsave_path)
81
+ if not os.path.exists(opt.isave_path): # create save path
82
+ os.mkdir(opt.isave_path)
83
+
84
+ GeoTr_Seg_model = GeoTr_Seg().cuda()
85
+ # reload segmentation model
86
+ reload_segmodel(GeoTr_Seg_model.msk, opt.Seg_path)
87
+ # reload geometric unwarping model
88
+ reload_model(GeoTr_Seg_model.GeoTr, opt.GeoTr_path)
89
+
90
+ IllTr_model = IllTr().cuda()
91
+ # reload illumination rectification model
92
+ reload_model(IllTr_model, opt.IllTr_path)
93
+
94
+ # To eval mode
95
+ GeoTr_Seg_model.eval()
96
+ IllTr_model.eval()
97
+
98
+ for img_path in img_list:
99
+ name = img_path.split('.')[-2] # image name
100
+
101
+ img_path = opt.distorrted_path + img_path # read image and to tensor
102
+ im_ori = np.array(Image.open(img_path))[:, :, :3] / 255.
103
+ h, w, _ = im_ori.shape
104
+ im = cv2.resize(im_ori, (288, 288))
105
+ im = im.transpose(2, 0, 1)
106
+ im = torch.from_numpy(im).float().unsqueeze(0)
107
+
108
+ with torch.no_grad():
109
+ # geometric unwarping
110
+ bm = GeoTr_Seg_model(im.cuda())
111
+ bm = bm.cpu()
112
+ bm0 = cv2.resize(bm[0, 0].numpy(), (w, h)) # x flow
113
+ bm1 = cv2.resize(bm[0, 1].numpy(), (w, h)) # y flow
114
+ bm0 = cv2.blur(bm0, (3, 3))
115
+ bm1 = cv2.blur(bm1, (3, 3))
116
+ lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
117
+
118
+ out = F.grid_sample(torch.from_numpy(im_ori).permute(2,0,1).unsqueeze(0).float(), lbl, align_corners=True)
119
+ img_geo = ((out[0]*255).permute(1, 2, 0).numpy())[:,:,::-1].astype(np.uint8)
120
+ cv2.imwrite(opt.gsave_path + name + '_geo' + '.png', img_geo) # save
121
+
122
+ # illumination rectification
123
+ if opt.ill_rec:
124
+ ill_savep = opt.isave_path + name + '_ill' + '.png'
125
+ rec_ill(IllTr_model, img_geo, saveRecPath=ill_savep)
126
+
127
+ print('Done: ', img_path)
128
+
129
+
130
+
131
+
132
+
133
+
134
+ def process_image(input_image):
135
+ GeoTr_Seg_model = GeoTr_Seg().cuda()
136
+ reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
137
+ reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
138
+
139
+ IllTr_model = IllTr().cuda()
140
+ reload_model(IllTr_model, './model_pretrained/illtr.pth')
141
+
142
+ GeoTr_Seg_model.eval()
143
+ IllTr_model.eval()
144
+
145
+ im_ori = np.array(input_image)[:, :, :3] / 255.
146
+ h, w, _ = im_ori.shape
147
+ im = cv2.resize(im_ori, (288, 288))
148
+ im = im.transpose(2, 0, 1)
149
+ im = torch.from_numpy(im).float().unsqueeze(0)
150
+
151
+ with torch.no_grad():
152
+ bm = GeoTr_Seg_model(im.cuda())
153
+ bm = bm.cpu()
154
+ bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
155
+ bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
156
+ bm0 = cv2.blur(bm0, (3, 3))
157
+ bm1 = cv2.blur(bm1, (3, 3))
158
+ lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
159
+
160
+ out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
161
+ img_geo = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)
162
+
163
+ ill_rec=False
164
+
165
+ if ill_rec:
166
+ img_ill = rec_ill(IllTr_model, img_geo)
167
+ return Image.fromarray(img_ill)
168
+ else:
169
+ return Image.fromarray(img_geo)
170
+
171
+ # Define Gradio interface
172
+ input_image = gr.inputs.Image()
173
+ output_image = gr.outputs.Image(type='pil')
174
+
175
+
176
+ iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="Image Correction")
177
  iface.launch()
178
+