HaoFeng2019 commited on
Commit
b7546a7
1 Parent(s): b79aac9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -16
app.py CHANGED
@@ -12,7 +12,6 @@ from model import DocGeoNet
12
  from seg import U2NETP
13
  import glob
14
 
15
-
16
  warnings.filterwarnings('ignore')
17
 
18
  class Net(nn.Module):
@@ -53,15 +52,18 @@ def reload_rec_model(model, path=""):
53
  model.load_state_dict(model_dict)
54
  return model
55
 
56
- def rec(input_image):
57
- seg_model_path = './model_pretrained/preprocess.pth'
58
- rec_model_path = './model_pretrained/DocGeoNet.pth'
 
 
59
 
60
- net = Net()
61
- reload_rec_model(net.DocTr, rec_model_path)
62
- reload_seg_model(net.msk, seg_model_path)
63
- net.eval()
64
 
 
 
 
65
  im_ori = np.array(input_image)[:, :, :3] / 255. # read image 0-255 to 0-1
66
  h, w, _ = im_ori.shape
67
  im = cv2.resize(im_ori, (256, 256))
@@ -78,7 +80,7 @@ def rec(input_image):
78
  bm1 = cv2.blur(bm1, (3, 3))
79
  lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
80
  out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
81
- img_rec = ((out[0] * 255).permute(1, 2, 0).numpy())[:,:,::-1].astype(np.uint8)
82
 
83
  # Convert from BGR to RGB
84
  img_rec = cv2.cvtColor(img_rec, cv2.COLOR_BGR2RGB)
@@ -90,10 +92,3 @@ demo_img_files = glob.glob('./distorted/*.[jJ][pP][gG]') + glob.glob('./distorte
90
  # Gradio Interface
91
  input_image = gr.inputs.Image()
92
  output_image = gr.outputs.Image(type='pil')
93
-
94
-
95
-
96
- iface = gr.Interface(fn=rec, inputs=input_image, outputs=output_image, title="DocGeoNet",examples=demo_img_files)
97
-
98
- #iface.launch(server_port=8821, server_name="0.0.0.0")
99
- iface.launch()
 
12
  from seg import U2NETP
13
  import glob
14
 
 
15
  warnings.filterwarnings('ignore')
16
 
17
  class Net(nn.Module):
 
52
  model.load_state_dict(model_dict)
53
  return model
54
 
55
+ net = Net()
56
+ seg_model_path = './model_pretrained/preprocess.pth'
57
+ rec_model_path = './model_pretrained/DocGeoNet.pth'
58
+ reload_rec_model(net.DocTr, rec_model_path)
59
+ reload_seg_model(net.msk, seg_model_path)
60
 
61
+ # Compile models (assuming PyTorch 2.0)
62
+ net = torch.compile(net)
 
 
63
 
64
+ net.eval()
65
+
66
+ def rec(input_image):
67
  im_ori = np.array(input_image)[:, :, :3] / 255. # read image 0-255 to 0-1
68
  h, w, _ = im_ori.shape
69
  im = cv2.resize(im_ori, (256, 256))
 
80
  bm1 = cv2.blur(bm1, (3, 3))
81
  lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
82
  out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
83
+ img_rec = ((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1].astype(np.uint8)
84
 
85
  # Convert from BGR to RGB
86
  img_rec = cv2.cvtColor(img_rec, cv2.COLOR_BGR2RGB)
 
92
  # Gradio Interface
93
  input_image = gr.inputs.Image()
94
  output_image = gr.outputs.Image(type='pil')