vincentclaes commited on
Commit
2b6b509
1 Parent(s): a043b64

initial commit

Browse files
README.md CHANGED
@@ -1,12 +1,19 @@
1
  ---
2
- title: Document Qa Comparator
3
- emoji: 🐠
4
- colorFrom: pink
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.20.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
1
  ---
2
+ title: DocumentQAComparator
3
+ emoji: 🤖🦾⚙️
4
+ colorFrom: white
5
+ colorTo: white
6
  sdk: gradio
7
+ sdk_version: 3.18.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ ## Setup + Run
16
+ ```
17
+ pip install -r requirements.txt
18
+ python app.py
19
+ ```
app.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import boto3
4
+ import traceback
5
+ import re
6
+ import logging
7
+
8
+ import gradio as gr
9
+ from PIL import Image, ImageDraw
10
+
11
+ from docquery.document import load_document, ImageDocument
12
+ from docquery.ocr_reader import get_ocr_reader
13
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
14
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
15
+ from transformers import pipeline
16
+
17
+ # avoid ssl errors
18
+ import ssl
19
+
20
+ ssl._create_default_https_context = ssl._create_unverified_context
21
+
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
+
24
+ logging.basicConfig(level=logging.DEBUG)
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Init models
28
+
29
+ layoutlm_pipeline = pipeline(
30
+ "document-question-answering",
31
+ model="impira/layoutlm-document-qa",
32
+ )
33
+ lilt_tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-infoxlm-base")
34
+ lilt_model = AutoModelForQuestionAnswering.from_pretrained(
35
+ "nielsr/lilt-xlm-roberta-base"
36
+ )
37
+
38
+ donut_processor = DonutProcessor.from_pretrained(
39
+ "naver-clova-ix/donut-base-finetuned-docvqa"
40
+ )
41
+ donut_model = VisionEncoderDecoderModel.from_pretrained(
42
+ "naver-clova-ix/donut-base-finetuned-docvqa"
43
+ )
44
+
45
+ TEXTRACT = "Textract Query"
46
+ LAYOUTLM = "LayoutLM"
47
+ DONUT = "Donut"
48
+ LILT = "LiLT"
49
+
50
+
51
+ def image_to_byte_array(image: Image) -> bytes:
52
+ image_as_byte_array = io.BytesIO()
53
+ image.save(image_as_byte_array, format="PNG")
54
+ image_as_byte_array = image_as_byte_array.getvalue()
55
+ return image_as_byte_array
56
+
57
+
58
+ def run_textract(question, document):
59
+ logger.info(f"Running Textract model.")
60
+ image_as_byte_base64 = image_to_byte_array(image=document.b)
61
+ response = boto3.client("textract").analyze_document(
62
+ Document={
63
+ "Bytes": image_as_byte_base64,
64
+ },
65
+ FeatureTypes=[
66
+ "QUERIES",
67
+ ],
68
+ QueriesConfig={
69
+ "Queries": [
70
+ {
71
+ "Text": question,
72
+ "Pages": [
73
+ "*",
74
+ ],
75
+ },
76
+ ]
77
+ },
78
+ )
79
+ logger.info(f"Output of Textract model {response}.")
80
+ for element in response["Blocks"]:
81
+ if element["BlockType"] == "QUERY_RESULT":
82
+ return {
83
+ "score": element["Confidence"],
84
+ "answer": element["Text"],
85
+ # "word_ids": element
86
+ }
87
+ else:
88
+ Exception("No QUERY_RESULT found in the response from Textract.")
89
+
90
+
91
+ def run_layoutlm(question, document):
92
+ logger.info(f"Running layoutlm model.")
93
+ result = layoutlm_pipeline(document.context["image"][0][0], question)[0]
94
+ logger.info(f"Output of layoutlm model {result}.")
95
+ # [{'score': 0.9999411106109619, 'answer': 'LETTER OF CREDIT', 'start': 106, 'end': 108}]
96
+ return {
97
+ "score": result["score"],
98
+ "answer": result["answer"],
99
+ "word_ids": [result["start"], result["end"]],
100
+ "page": 0,
101
+ }
102
+
103
+
104
+ def run_lilt(question, document):
105
+ logger.info(f"Running lilt model.")
106
+ # use this model + tokenizer
107
+ processed_document = document.context["image"][0][1]
108
+ words = [x[0] for x in processed_document]
109
+ boxes = [x[1] for x in processed_document]
110
+
111
+ encoding = lilt_tokenizer(
112
+ text=question,
113
+ text_pair=words,
114
+ boxes=boxes,
115
+ add_special_tokens=True,
116
+ return_tensors="pt",
117
+ )
118
+ outputs = lilt_model(**encoding)
119
+ logger.info(f"Output for lilt model {outputs}.")
120
+
121
+ answer_start_index = outputs.start_logits.argmax()
122
+ answer_end_index = outputs.end_logits.argmax()
123
+
124
+ predict_answer_tokens = encoding.input_ids[
125
+ 0, answer_start_index: answer_end_index + 1
126
+ ]
127
+ predict_answer = lilt_tokenizer.decode(
128
+ predict_answer_tokens, skip_special_tokens=True
129
+ )
130
+ return {
131
+ "score": "n/a",
132
+ "answer": predict_answer,
133
+ # "word_ids": element
134
+ }
135
+
136
+
137
+ def run_donut(question, document):
138
+ logger.info(f"Running donut model.")
139
+ # prepare encoder inputs
140
+ pixel_values = donut_processor(
141
+ document.context["image"][0][0], return_tensors="pt"
142
+ ).pixel_values
143
+
144
+ # prepare decoder inputs
145
+ task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
146
+ prompt = task_prompt.replace("{user_input}", question)
147
+ decoder_input_ids = donut_processor.tokenizer(
148
+ prompt, add_special_tokens=False, return_tensors="pt"
149
+ ).input_ids
150
+
151
+ # generate answer
152
+ outputs = donut_model.generate(
153
+ pixel_values,
154
+ decoder_input_ids=decoder_input_ids,
155
+ max_length=donut_model.decoder.config.max_position_embeddings,
156
+ early_stopping=True,
157
+ pad_token_id=donut_processor.tokenizer.pad_token_id,
158
+ eos_token_id=donut_processor.tokenizer.eos_token_id,
159
+ use_cache=True,
160
+ num_beams=1,
161
+ bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
162
+ return_dict_in_generate=True,
163
+ )
164
+ logger.info(f"Output for donut {outputs}")
165
+ sequence = donut_processor.batch_decode(outputs.sequences)[0]
166
+ sequence = sequence.replace(donut_processor.tokenizer.eos_token, "").replace(
167
+ donut_processor.tokenizer.pad_token, ""
168
+ )
169
+ sequence = re.sub(
170
+ r"<.*?>", "", sequence, count=1
171
+ ).strip() # remove first task start token
172
+
173
+ result = donut_processor.token2json(sequence)
174
+ return {
175
+ "score": "n/a",
176
+ "answer": result["answer"],
177
+ # "word_ids": element
178
+ }
179
+
180
+
181
+ def process_path(path):
182
+ error = None
183
+ if path:
184
+ try:
185
+ document = load_document(path)
186
+ return (
187
+ document,
188
+ gr.update(visible=True, value=document.preview),
189
+ gr.update(visible=True),
190
+ gr.update(visible=False, value=None),
191
+ gr.update(visible=False, value=None),
192
+ None,
193
+ )
194
+ except Exception as e:
195
+ traceback.print_exc()
196
+ error = str(e)
197
+ return (
198
+ None,
199
+ gr.update(visible=False, value=None),
200
+ gr.update(visible=False),
201
+ gr.update(visible=False, value=None),
202
+ gr.update(visible=False, value=None),
203
+ gr.update(visible=True, value=error) if error is not None else None,
204
+ None,
205
+ )
206
+
207
+
208
+ def process_upload(file):
209
+ if file:
210
+ return process_path(file.name)
211
+ else:
212
+ return (
213
+ None,
214
+ gr.update(visible=False, value=None),
215
+ gr.update(visible=False),
216
+ gr.update(visible=False, value=None),
217
+ gr.update(visible=False, value=None),
218
+ None,
219
+ )
220
+
221
+
222
+ def lift_word_boxes(document, page):
223
+ return document.context["image"][page][1]
224
+
225
+
226
+ def expand_bbox(word_boxes):
227
+ if len(word_boxes) == 0:
228
+ return None
229
+
230
+ min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
231
+ min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
232
+ return [min_x, min_y, max_x, max_y]
233
+
234
+
235
+ # LayoutLM boxes are normalized to 0, 1000
236
+ def normalize_bbox(box, width, height, padding=0.005):
237
+ min_x, min_y, max_x, max_y = [c / 1000 for c in box]
238
+ if padding != 0:
239
+ min_x = max(0, min_x - padding)
240
+ min_y = max(0, min_y - padding)
241
+ max_x = min(max_x + padding, 1)
242
+ max_y = min(max_y + padding, 1)
243
+ return [min_x * width, min_y * height, max_x * width, max_y * height]
244
+
245
+
246
+ MODELS = {
247
+ LAYOUTLM: run_layoutlm,
248
+ DONUT: run_donut,
249
+ # LILT: run_lilt,
250
+ TEXTRACT: run_textract,
251
+ }
252
+
253
+
254
+ def process_question(question, document, model=list(MODELS.keys())[0]):
255
+ if not question or document is None:
256
+ return None, None, None
257
+ logger.info(f"Running for model {model}")
258
+ prediction = MODELS[model](question=question, document=document)
259
+ logger.info(f"Got prediction {prediction}")
260
+ pages = [x.copy().convert("RGB") for x in document.preview]
261
+ text_value = prediction["answer"]
262
+ if "word_ids" in prediction:
263
+ logger.info(f"Setting bounding boxes.")
264
+ image = pages[prediction["page"]]
265
+ draw = ImageDraw.Draw(image, "RGBA")
266
+ word_boxes = lift_word_boxes(document, prediction["page"])
267
+ x1, y1, x2, y2 = normalize_bbox(
268
+ expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
269
+ image.width,
270
+ image.height,
271
+ )
272
+ draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
273
+
274
+ return (
275
+ gr.update(visible=True, value=pages),
276
+ gr.update(visible=True, value=prediction),
277
+ gr.update(
278
+ visible=True,
279
+ value=text_value,
280
+ ),
281
+ )
282
+
283
+
284
+ def load_example_document(img, question, model):
285
+ if img is not None:
286
+ document = ImageDocument(Image.fromarray(img), get_ocr_reader())
287
+ preview, answer, answer_text = process_question(question, document, model)
288
+ return document, question, preview, gr.update(visible=True), answer, answer_text
289
+ else:
290
+ return None, None, None, gr.update(visible=False), None, None
291
+
292
+
293
+ CSS = """
294
+ #question input {
295
+ font-size: 16px;
296
+ }
297
+ #url-textbox {
298
+ padding: 0 !important;
299
+ }
300
+ #short-upload-box .w-full {
301
+ min-height: 10rem !important;
302
+ }
303
+ /* I think something like this can be used to re-shape
304
+ * the table
305
+ */
306
+ /*
307
+ .gr-samples-table tr {
308
+ display: inline;
309
+ }
310
+ .gr-samples-table .p-2 {
311
+ width: 100px;
312
+ }
313
+ */
314
+ #select-a-file {
315
+ width: 100%;
316
+ }
317
+ #file-clear {
318
+ padding-top: 2px !important;
319
+ padding-bottom: 2px !important;
320
+ padding-left: 8px !important;
321
+ padding-right: 8px !important;
322
+ margin-top: 10px;
323
+ }
324
+ .gradio-container .gr-button-primary {
325
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
326
+ border: 1px solid #B0DCCC;
327
+ border-radius: 8px;
328
+ color: #1B8700;
329
+ }
330
+ .gradio-container.dark button#submit-button {
331
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
332
+ border: 1px solid #B0DCCC;
333
+ border-radius: 8px;
334
+ color: #1B8700
335
+ }
336
+
337
+ table.gr-samples-table tr td {
338
+ border: none;
339
+ outline: none;
340
+ }
341
+
342
+ table.gr-samples-table tr td:first-of-type {
343
+ width: 0%;
344
+ }
345
+
346
+ div#short-upload-box div.absolute {
347
+ display: none !important;
348
+ }
349
+
350
+ gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
351
+ gap: 0px 2%;
352
+ }
353
+
354
+ gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
355
+ gap: 0px;
356
+ }
357
+
358
+ gradio-app h2, .gradio-app h2 {
359
+ padding-top: 10px;
360
+ }
361
+
362
+ #answer {
363
+ overflow-y: scroll;
364
+ color: white;
365
+ background: #666;
366
+ border-color: #666;
367
+ font-size: 20px;
368
+ font-weight: bold;
369
+ }
370
+
371
+ #answer span {
372
+ color: white;
373
+ }
374
+
375
+ #answer textarea {
376
+ color:white;
377
+ background: #777;
378
+ border-color: #777;
379
+ font-size: 18px;
380
+ }
381
+
382
+ #url-error input {
383
+ color: red;
384
+ }
385
+ """
386
+
387
+ examples = [
388
+ [
389
+ "scenario-1.png",
390
+ "What is the final consignee?",
391
+ ],
392
+ [
393
+ "scenario-1.png",
394
+ "What are the payment terms?",
395
+ ],
396
+ [
397
+ "scenario-2.png",
398
+ "What is the actual manufacturer?",
399
+ ],
400
+ [
401
+ "scenario-3.png",
402
+ 'What is the "ship to" destination?',
403
+ ],
404
+ [
405
+ "scenario-4.png",
406
+ "What is the color?",
407
+ ],
408
+ [
409
+ "scenario-5.png",
410
+ 'What is the "said to contain"?',
411
+ ],
412
+ [
413
+ "scenario-5.png",
414
+ 'What is the "Net Weight"?',
415
+ ],
416
+ [
417
+ "scenario-5.png",
418
+ 'What is the "Freight Collect"?',
419
+ ],
420
+ [
421
+ "bill_of_lading_1.png",
422
+ "What is the shipper?",
423
+ ],
424
+ [
425
+ "japanese-invoice.png",
426
+ "What is the total amount?",
427
+ ]
428
+ ]
429
+
430
+ with gr.Blocks(css=CSS) as demo:
431
+ gr.Markdown("# Document Question Answer Comparator")
432
+ gr.Markdown("""
433
+ This space compares some of the latest models that can be used commercially.
434
+ - [LayoutLM](https://huggingface.co/impira/layoutlm-document-qa) uses text/layout and images. Uses tesseract for OCR.
435
+ - [Donut](https://huggingface.co/naver-clova-ix/donut-base-finetuned-docvqa) OCR free document understanding. Uses vision encoder for OCR and a text decoder for providing the answer.
436
+ - [Textract Query](https://docs.aws.amazon.com/textract/latest/dg/what-is.html) OCR + document understanding solution of AWS.
437
+ """)
438
+
439
+ document = gr.Variable()
440
+ example_question = gr.Textbox(visible=False)
441
+ example_image = gr.Image(visible=False)
442
+
443
+ with gr.Row(equal_height=True):
444
+ with gr.Column():
445
+ with gr.Row():
446
+ gr.Markdown("## 1. Select a file", elem_id="select-a-file")
447
+ img_clear_button = gr.Button(
448
+ "Clear", variant="secondary", elem_id="file-clear", visible=False
449
+ )
450
+ image = gr.Gallery(visible=False)
451
+ upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
452
+ gr.Examples(
453
+ examples=examples,
454
+ inputs=[example_image, example_question],
455
+ )
456
+
457
+ with gr.Column() as col:
458
+ gr.Markdown("## 2. Ask a question")
459
+ question = gr.Textbox(
460
+ label="Question",
461
+ placeholder="e.g. What is the invoice number?",
462
+ lines=1,
463
+ max_lines=1,
464
+ )
465
+ model = gr.Radio(
466
+ choices=list(MODELS.keys()),
467
+ value=list(MODELS.keys())[0],
468
+ label="Model",
469
+ )
470
+
471
+ with gr.Row():
472
+ clear_button = gr.Button("Clear", variant="secondary")
473
+ submit_button = gr.Button(
474
+ "Submit", variant="primary", elem_id="submit-button"
475
+ )
476
+ with gr.Column():
477
+ output_text = gr.Textbox(
478
+ label="Top Answer", visible=False, elem_id="answer"
479
+ )
480
+ output = gr.JSON(label="Output", visible=False)
481
+
482
+ for cb in [img_clear_button, clear_button]:
483
+ cb.click(
484
+ lambda _: (
485
+ gr.update(visible=False, value=None),
486
+ None,
487
+ gr.update(visible=False, value=None),
488
+ gr.update(visible=False, value=None),
489
+ gr.update(visible=False),
490
+ None,
491
+ None,
492
+ None,
493
+ gr.update(visible=False, value=None),
494
+ None,
495
+ ),
496
+ inputs=clear_button,
497
+ outputs=[
498
+ image,
499
+ document,
500
+ output,
501
+ output_text,
502
+ img_clear_button,
503
+ example_image,
504
+ upload,
505
+ question,
506
+ ],
507
+ )
508
+
509
+ upload.change(
510
+ fn=process_upload,
511
+ inputs=[upload],
512
+ outputs=[document, image, img_clear_button, output, output_text],
513
+ )
514
+
515
+ question.submit(
516
+ fn=process_question,
517
+ inputs=[question, document, model],
518
+ outputs=[image, output, output_text],
519
+ )
520
+
521
+ submit_button.click(
522
+ process_question,
523
+ inputs=[question, document, model],
524
+ outputs=[image, output, output_text],
525
+ )
526
+
527
+ model.change(
528
+ process_question,
529
+ inputs=[question, document, model],
530
+ outputs=[image, output, output_text],
531
+ )
532
+
533
+ example_image.change(
534
+ fn=load_example_document,
535
+ inputs=[example_image, example_question, model],
536
+ outputs=[document, question, image, img_clear_button, output, output_text],
537
+ )
538
+
539
+ if __name__ == "__main__":
540
+ demo.launch(enable_queue=False)
bill_of_lading_1.png ADDED
japanese-invoice.png ADDED
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ poppler-utils
2
+ tesseract-ocr
3
+ chromium
4
+ chromium-driver
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tesseract-ocr
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torc
2
+ docquery[web,donut]
3
+ transformers
4
+ gradio
5
+ boto3
6
+ pillow
7
+
scenario-1.png ADDED
scenario-2.png ADDED
scenario-3.png ADDED
scenario-4.png ADDED
scenario-5.png ADDED