Harsimran19 committed · Commit 4331eba
Parent(s): 6d92fcd

Upload 4 files

- app.py (+40, -0)
- models/document_model/config.json (+63, -0)
- preprocess.py (+83, -0)
- requirements.txt (+82, -0)
app.py
ADDED
@@ -0,0 +1,40 @@
import pytesseract
import torch
import gradio as gr
from transformers import LayoutLMForSequenceClassification
from preprocess import apply_ocr, encode_example

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Point pytesseract at the local Tesseract binary. A raw string needs single
# backslashes; on Linux (e.g., a Space) use the local path such as
# /usr/bin/tesseract, or drop this line if tesseract is on PATH.
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
model = LayoutLMForSequenceClassification.from_pretrained("models/document_model")
model.to(device)
model.eval()
classes = ['questionnaire', 'memo', 'budget', 'file_folder', 'specification', 'invoice', 'resume',
           'advertisement', 'news_article', 'email', 'scientific_publication', 'presentation',
           'letter', 'form', 'handwritten', 'scientific_report']


def predict(image):
    # OCR the image, encode words and boxes, and move every tensor to the
    # same device as the model.
    example = apply_ocr(image)
    encoded_example = encode_example(example)
    input_ids = torch.tensor(encoded_example['input_ids']).unsqueeze(0).to(device)
    bbox = torch.tensor(encoded_example['bbox']).unsqueeze(0).to(device)
    attention_mask = torch.tensor(encoded_example['attention_mask']).unsqueeze(0).to(device)
    token_type_ids = torch.tensor(encoded_example['token_type_ids']).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, bbox=bbox,
                        attention_mask=attention_mask, token_type_ids=token_type_ids)
    classification_results = torch.softmax(outputs.logits, dim=1).tolist()[0]
    max_prob_index = classification_results.index(max(classification_results))
    predicted_class = classes[max_prob_index]
    return predicted_class


title = "Document Image Classification"

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Class"),
    title=title,
)
demo.launch(share=True)
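
A minimal sketch of exercising the classifier outside Gradio (the file name sample_doc.png and a working local Tesseract install are assumptions, not part of this commit; importing app would also run demo.launch(), so the pipeline is reused directly):

import torch
from PIL import Image
from transformers import LayoutLMForSequenceClassification
from preprocess import apply_ocr, encode_example

model = LayoutLMForSequenceClassification.from_pretrained("models/document_model")
model.eval()
image = Image.open("sample_doc.png").convert("RGB")  # hypothetical sample file
enc = encode_example(apply_ocr(image))
with torch.no_grad():
    out = model(input_ids=torch.tensor(enc["input_ids"]).unsqueeze(0),
                bbox=torch.tensor(enc["bbox"]).unsqueeze(0),
                attention_mask=torch.tensor(enc["attention_mask"]).unsqueeze(0),
                token_type_ids=torch.tensor(enc["token_type_ids"]).unsqueeze(0))
print(out.logits.argmax(-1).item())  # index into the classes list in app.py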
models/document_model/config.json
ADDED
@@ -0,0 +1,63 @@
{
  "_name_or_path": "microsoft/layoutlm-base-uncased",
  "architectures": [
    "LayoutLMForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9",
    "10": "LABEL_10",
    "11": "LABEL_11",
    "12": "LABEL_12",
    "13": "LABEL_13",
    "14": "LABEL_14",
    "15": "LABEL_15"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_10": 10,
    "LABEL_11": 11,
    "LABEL_12": 12,
    "LABEL_13": 13,
    "LABEL_14": 14,
    "LABEL_15": 15,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6,
    "LABEL_7": 7,
    "LABEL_8": 8,
    "LABEL_9": 9
  },
  "layer_norm_eps": 1e-12,
  "max_2d_position_embeddings": 1024,
  "max_position_embeddings": 512,
  "model_type": "layoutlm",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "torch_dtype": "float32",
  "transformers_version": "4.30.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
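
The config still carries the generic LABEL_0 … LABEL_15 names, which is why app.py keeps its own hardcoded classes list. A hedged sketch of baking the readable names into the saved config instead (this assumes the order of classes matches the training label indices, which the commit does not state):

from transformers import LayoutLMForSequenceClassification

classes = ['questionnaire', 'memo', 'budget', 'file_folder', 'specification',
           'invoice', 'resume', 'advertisement', 'news_article', 'email',
           'scientific_publication', 'presentation', 'letter', 'form',
           'handwritten', 'scientific_report']

model = LayoutLMForSequenceClassification.from_pretrained("models/document_model")
# Assumption: classes[i] corresponds to LABEL_i from training.
model.config.id2label = {i: name for i, name in enumerate(classes)}
model.config.label2id = {name: i for i, name in enumerate(classes)}
model.save_pretrained("models/document_model")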
preprocess.py
ADDED
@@ -0,0 +1,83 @@
import pytesseract
import numpy as np
from transformers import LayoutLMTokenizer


# Point pytesseract at the local Tesseract binary (single backslashes in a
# raw string; on Linux use e.g. /usr/bin/tesseract, or drop this line if
# tesseract is on PATH).
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")


def normalize_box(box, width, height):
    # Scale pixel coordinates into the 0-1000 range LayoutLM expects.
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]


def apply_ocr(image):
    # image is a PIL image (Gradio passes one in directly)
    width, height = image.size
    example = {}

    # apply OCR to the image, drop empty rows, and cast coordinates to int
    ocr_df = pytesseract.image_to_data(image, output_type='data.frame')
    float_cols = ocr_df.select_dtypes('float').columns
    ocr_df = ocr_df.dropna().reset_index(drop=True)
    ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
    ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
    ocr_df = ocr_df.dropna().reset_index(drop=True)

    # get the words and actual (unnormalized) bounding boxes
    words = [str(w) for w in ocr_df.text]
    coordinates = ocr_df[['left', 'top', 'width', 'height']]
    actual_boxes = []
    for idx, row in coordinates.iterrows():
        x, y, w, h = tuple(row)  # the row comes in (left, top, width, height) format
        actual_boxes.append([x, y, x + w, y + h])  # turn it into (left, top, right, bottom)

    # normalize the bounding boxes
    boxes = [normalize_box(box, width, height) for box in actual_boxes]

    assert len(words) == len(boxes)
    example['words'] = words
    example['bbox'] = boxes
    return example


def encode_example(example, max_seq_length=512, pad_token_box=[0, 0, 0, 0]):
    words = example['words']
    normalized_word_boxes = example['bbox']

    assert len(words) == len(normalized_word_boxes)

    # repeat each word's box once per wordpiece token
    token_boxes = []
    for word, box in zip(words, normalized_word_boxes):
        word_tokens = tokenizer.tokenize(word)
        token_boxes.extend([box] * len(word_tokens))

    # truncate token_boxes, leaving room for the two special tokens
    special_tokens_count = 2
    if len(token_boxes) > max_seq_length - special_tokens_count:
        token_boxes = token_boxes[: (max_seq_length - special_tokens_count)]

    # add bounding boxes for the [CLS] and [SEP] tokens
    token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

    encoding = tokenizer(' '.join(words), padding='max_length', truncation=True)
    # pad token_boxes up to the sequence length; re-tokenize without padding
    # to find the unpadded length
    input_ids = tokenizer(' '.join(words), truncation=True)["input_ids"]
    padding_length = max_seq_length - len(input_ids)
    token_boxes += [pad_token_box] * padding_length
    encoding['bbox'] = token_boxes

    assert len(encoding['input_ids']) == max_seq_length
    assert len(encoding['attention_mask']) == max_seq_length
    assert len(encoding['token_type_ids']) == max_seq_length
    assert len(encoding['bbox']) == max_seq_length

    return encoding
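
A minimal smoke test of the two helpers on a synthetic image (a sketch only; it assumes tesseract_cmd points at a working Tesseract install, and the rendered text is illustrative, not part of the commit):

from PIL import Image, ImageDraw
from preprocess import apply_ocr, encode_example

image = Image.new("RGB", (1000, 1000), "white")
ImageDraw.Draw(image).text((50, 50), "Invoice number 12345", fill="black")

example = apply_ocr(image)           # {'words': [...], 'bbox': [[x0, y0, x1, y1], ...]}
encoding = encode_example(example)   # padded to max_seq_length
print(len(encoding["input_ids"]), len(encoding["bbox"]))  # 512 512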
requirements.txt
ADDED
@@ -0,0 +1,82 @@
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
altair==5.0.1
annotated-types==0.5.0
anyio==3.7.1
async-timeout==4.0.2
attrs==23.1.0
certifi==2023.5.7
charset-normalizer==3.2.0
click==8.1.6
colorama==0.4.6
contourpy==1.1.0
cycler==0.11.0
datasets==2.13.1
dill==0.3.6
exceptiongroup==1.1.2
fastapi==0.100.0
ffmpy==0.3.1
filelock==3.12.2
fonttools==4.41.0
frozenlist==1.4.0
fsspec==2023.6.0
gradio==3.37.0
gradio_client==0.2.10
h11==0.14.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.16.4
idna==3.4
Jinja2==3.1.2
jsonschema==4.18.4
jsonschema-specifications==2023.7.1
kiwisolver==1.4.4
linkify-it-py==2.0.2
markdown-it-py==2.2.0
MarkupSafe==2.1.3
matplotlib==3.7.2
mdit-py-plugins==0.3.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
networkx==3.1
numpy==1.25.1
orjson==3.9.2
packaging==23.1
pandas==2.0.3
Pillow==10.0.0
pyarrow==12.0.1
pydantic==2.0.3
pydantic_core==2.3.0
pydub==0.25.1
pyparsing==3.0.9
pytesseract==0.3.10
python-dateutil==2.8.2
python-multipart==0.0.6
pytz==2023.3
PyYAML==6.0.1
referencing==0.30.0
regex==2023.6.3
requests==2.31.0
rpds-py==0.9.2
safetensors==0.3.1
semantic-version==2.10.0
six==1.16.0
sniffio==1.3.0
starlette==0.27.0
sympy==1.12
tokenizers==0.13.3
toolz==0.12.0
torch==2.0.1
tqdm==4.65.0
transformers==4.31.0
typing_extensions==4.7.1
tzdata==2023.3
uc-micro-py==1.0.2
urllib3==2.0.3
uvicorn==0.23.1
websockets==11.0.3
xxhash==3.2.0
yarl==1.9.2