HaoFeng2019 commited on
Commit
2718a79
1 Parent(s): b7546a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -12,6 +12,7 @@ from model import DocGeoNet
12
  from seg import U2NETP
13
  import glob
14
 
 
15
  warnings.filterwarnings('ignore')
16
 
17
  class Net(nn.Module):
@@ -52,18 +53,15 @@ def reload_rec_model(model, path=""):
52
  model.load_state_dict(model_dict)
53
  return model
54
 
55
- net = Net()
56
- seg_model_path = './model_pretrained/preprocess.pth'
57
- rec_model_path = './model_pretrained/DocGeoNet.pth'
58
- reload_rec_model(net.DocTr, rec_model_path)
59
- reload_seg_model(net.msk, seg_model_path)
60
-
61
- # Compile models (assuming PyTorch 2.0)
62
- net = torch.compile(net)
63
 
64
- net.eval()
 
 
 
65
 
66
- def rec(input_image):
67
  im_ori = np.array(input_image)[:, :, :3] / 255. # read image 0-255 to 0-1
68
  h, w, _ = im_ori.shape
69
  im = cv2.resize(im_ori, (256, 256))
@@ -80,7 +78,7 @@ def rec(input_image):
80
  bm1 = cv2.blur(bm1, (3, 3))
81
  lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
82
  out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
83
- img_rec = ((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1].astype(np.uint8)
84
 
85
  # Convert from BGR to RGB
86
  img_rec = cv2.cvtColor(img_rec, cv2.COLOR_BGR2RGB)
@@ -92,3 +90,10 @@ demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorte
92
  # Gradio Interface
93
  input_image = gr.inputs.Image()
94
  output_image = gr.outputs.Image(type='pil')
 
 
 
 
 
 
 
 
12
  from seg import U2NETP
13
  import glob
14
 
15
+
16
  warnings.filterwarnings('ignore')
17
 
18
  class Net(nn.Module):
 
53
  model.load_state_dict(model_dict)
54
  return model
55
 
56
+ def rec(input_image):
57
+ seg_model_path = './model_pretrained/preprocess.pth'
58
+ rec_model_path = './model_pretrained/DocGeoNet.pth'
 
 
 
 
 
59
 
60
+ net = Net()
61
+ reload_rec_model(net.DocTr, rec_model_path)
62
+ reload_seg_model(net.msk, seg_model_path)
63
+ net.eval()
64
 
 
65
  im_ori = np.array(input_image)[:, :, :3] / 255. # read image 0-255 to 0-1
66
  h, w, _ = im_ori.shape
67
  im = cv2.resize(im_ori, (256, 256))
 
78
  bm1 = cv2.blur(bm1, (3, 3))
79
  lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
80
  out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
81
+ img_rec = ((out[0] * 255).permute(1, 2, 0).numpy())[:,:,::-1].astype(np.uint8)
82
 
83
  # Convert from BGR to RGB
84
  img_rec = cv2.cvtColor(img_rec, cv2.COLOR_BGR2RGB)
 
90
  # Gradio Interface
91
  input_image = gr.inputs.Image()
92
  output_image = gr.outputs.Image(type='pil')
93
+
94
+
95
+
96
+ iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image, title="DocGeoNet",examples=demo_img_files)
97
+
98
+ #iface.launch(server_port=8821, server_name="0.0.0.0")
99
+ iface.launch()