themanas021 committed on
Commit
a2b6a70
1 Parent(s): 64c1349

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Gradio demo for a Donut document-understanding model.
#
# Usage: python app.py [--task TASK] [--pretrained_path PATH_OR_HUB_ID]
#
# With --task docvqa the UI takes an image plus a question; for every other
# task it takes only an image. The model prediction is shown as JSON.
import argparse

import gradio as gr
import torch
from PIL import Image

from donut import DonutModel

parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="cord-v2")
parser.add_argument("--pretrained_path", type=str, default="naver-clova-ix/donut-base-finetuned-cord-v2")
# parse_known_args so that extra arguments injected by the hosting
# environment do not abort start-up.
args, left_argv = parser.parse_known_args()

task_name = args.task
is_docvqa = task_name == "docvqa"
if is_docvqa:
    # "{user_input}" is substituted with the user's question per request.
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
else:  # rvlcdip, cord, ...
    task_prompt = f"<s_{task_name}>"

pretrained_model = DonutModel.from_pretrained(args.pretrained_path)

if torch.cuda.is_available():
    # Half precision on GPU.
    pretrained_model.half()
    device = torch.device("cuda")
    pretrained_model.to(device)
else:
    # On CPU only the encoder is cast to bfloat16.
    pretrained_model.encoder.to(torch.bfloat16)

pretrained_model.eval()


def demo_process_vqa(input_img, question):
    """Answer ``question`` about ``input_img`` with the loaded model.

    Args:
        input_img: Image as delivered by the Gradio "image" input
            (assumed to be a NumPy array -- the Gradio default).
        question: Free-text question typed by the user.

    Returns:
        The first entry of the model's ``"predictions"`` list.
    """
    image = Image.fromarray(input_img)
    user_prompt = task_prompt.replace("{user_input}", question)
    result = pretrained_model.inference(image, prompt=user_prompt)
    return result["predictions"][0]


def demo_process(input_img):
    """Run the configured task prompt on ``input_img``.

    Args:
        input_img: Image as delivered by the Gradio "image" input
            (assumed to be a NumPy array -- the Gradio default).

    Returns:
        The first entry of the model's ``"predictions"`` list.
    """
    image = Image.fromarray(input_img)
    result = pretrained_model.inference(image=image, prompt=task_prompt)
    return result["predictions"][0]


demo = gr.Interface(
    fn=demo_process_vqa if is_docvqa else demo_process,
    inputs=["image", "text"] if is_docvqa else "image",
    outputs="json",
    title=f"Donut 🍩 demonstration for `{task_name}` task",
)
demo.launch(debug=True)