File size: 2,995 Bytes
240c20c
 
 
 
 
 
 
 
 
 
 
 
 
 
2718a79
240c20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2718a79
 
 
240c20c
2718a79
 
 
 
b7546a7
240c20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2718a79
240c20c
 
 
 
 
 
 
 
 
 
b79aac9
2718a79
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import functools
import glob
import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import gradio as gr

from model import DocGeoNet
from seg import U2NETP


# Suppress every Python warning for the whole process.
# NOTE(review): this also hides deprecation/misconfiguration warnings emitted by
# torch, gradio or cv2; consider narrowing it with a category= or module= filter.
warnings.filterwarnings('ignore')

class Net(nn.Module):
    """Document rectifier: foreground masking followed by DocGeoNet.

    ``msk`` (a ``U2NETP``) produces a mask that is thresholded and used to
    zero out background pixels; the masked image is then passed to ``DocTr``
    (a ``DocGeoNet``) whose third output is returned as a backward map, which
    ``rec`` uses as the sampling grid for ``F.grid_sample``.
    """

    def __init__(self):
        super().__init__()
        # Registration order (msk, then DocTr) fixes the state_dict key order.
        self.msk = U2NETP(3, 1)
        self.DocTr = DocGeoNet()

    def forward(self, x):
        """Predict the backward map for the image batch ``x``.

        Args:
            x: Image batch; ``rec`` passes a 1 x 3 x 256 x 256 float tensor
                with values in [0, 1].

        Returns:
            The third ``DocTr`` output mapped by ``(2 * v / 255 - 1) * 0.99``,
            i.e. [0, 255] onto [-0.99, 0.99] (assumes DocTr emits pixel
            coordinates on the 256-pixel grid -- TODO confirm).
        """
        # U2NETP yields seven maps; only the first is used as the mask.
        seg_out, _, _, _, _, _, _ = self.msk(x)
        foreground = (seg_out > 0.5).float()
        masked = foreground * x

        _, _, raw_bm = self.DocTr(masked)
        # Slightly shrink the range so samples stay inside the image border.
        return (2 * (raw_bm / 255.) - 1) * 0.99

def _load_stripped_state_dict(model, path, n_strip):
    """Copy matching tensors from the checkpoint at ``path`` into ``model``.

    Every checkpoint key has its first ``n_strip`` characters removed. Entries
    whose stripped key names an entry of ``model.state_dict()`` are copied in,
    all other checkpoint entries are ignored, and model entries absent from the
    checkpoint keep their current values.

    Args:
        model: ``nn.Module`` updated in place.
        path: Checkpoint path handed to ``torch.load`` (loaded onto the CPU).
        n_strip: Number of leading characters dropped from each checkpoint key.

    Returns:
        ``model`` itself.

    Raises:
        RuntimeError: If no checkpoint key matches the model after stripping,
            which would otherwise leave ``model`` silently un-initialised.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    model_dict = model.state_dict()
    matched = {
        k[n_strip:]: v for k, v in checkpoint.items() if k[n_strip:] in model_dict
    }
    if not matched:
        raise RuntimeError(
            f"No keys in checkpoint {path!r} match the model after stripping "
            f"{n_strip} leading characters; check the checkpoint format."
        )
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return model


def reload_seg_model(model, path=""):
    """Load segmentation weights from ``path`` into ``model`` (in place).

    Checkpoint keys are assumed to carry a 6-character prefix that is
    stripped before matching (presumably a wrapper name -- TODO confirm).
    An empty/``None`` ``path`` leaves ``model`` untouched.

    Returns:
        ``model`` itself.
    """
    if not path:
        return model
    return _load_stripped_state_dict(model, path, 6)


def reload_rec_model(model, path=""):
    """Load rectification weights from ``path`` into ``model`` (in place).

    Checkpoint keys are assumed to carry a 7-character prefix that is
    stripped before matching (presumably ``'module.'`` from
    ``nn.DataParallel`` -- TODO confirm). An empty/``None`` ``path`` leaves
    ``model`` untouched.

    Returns:
        ``model`` itself.
    """
    if not path:
        return model
    return _load_stripped_state_dict(model, path, 7)

@functools.lru_cache(maxsize=1)
def _load_net():
    """Build ``Net``, load its pretrained weights and switch it to eval mode.

    Cached so the checkpoints are read from disk on the first request only,
    not on every call to ``rec``.
    """
    seg_model_path = './model_pretrained/preprocess.pth'
    rec_model_path = './model_pretrained/DocGeoNet.pth'

    net = Net()
    reload_rec_model(net.DocTr, rec_model_path)
    reload_seg_model(net.msk, seg_model_path)
    net.eval()
    return net


def rec(input_image):
    """Rectify (unwarp) a photographed document image.

    Args:
        input_image: Image from the Gradio input; anything ``np.array`` turns
            into an H x W or H x W x C array of 0-255 values. Channels beyond
            the first three are dropped; single-channel input is replicated
            to three channels.

    Returns:
        ``PIL.Image.Image`` of the rectified document at the input's size.
    """
    net = _load_net()

    im_ori = np.array(input_image)
    if im_ori.ndim == 2:
        # Grayscale input: replicate to the three channels the network expects.
        im_ori = np.stack([im_ori] * 3, axis=-1)
    im_ori = im_ori[:, :, :3] / 255.  # read image 0-255 to 0-1
    h, w, _ = im_ori.shape

    # The network runs on a fixed-size 1 x 3 x 256 x 256 float tensor.
    im = cv2.resize(im_ori, (256, 256)).transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = net(im).cpu()

        # Upsample each backward-map channel to full resolution, smooth it
        # lightly, then sample the full-resolution image through it.
        bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))  # x flow
        bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))  # y flow
        grid = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2
        src = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
        out = F.grid_sample(src, grid, align_corners=True)

    # ``out`` keeps the input's channel order, so no colour conversion is needed.
    img_rec = (out[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8)
    return Image.fromarray(img_rec)


# Example inputs shown under the demo: every .jpg and then every .png
# (any letter case) found in ./distorted.
demo_img_files = [
    path
    for pattern in ('./distorted/*.[jJ][pP][gG]', './distorted/*.[pP][nN][gG]')
    for path in glob.glob(pattern)
]

# Gradio Interface: image in, PIL image out.
input_image = gr.inputs.Image()
output_image = gr.outputs.Image(type='pil')

iface = gr.Interface(
    fn=rec,
    inputs=input_image,
    outputs=output_image,
    title="DocGeoNet",
    examples=demo_img_files,
)

# To serve on a fixed port on all network interfaces instead, use:
#   iface.launch(server_port=8821, server_name="0.0.0.0")
iface.launch()