themanas021's picture
Create app.py
a2b6a70 verified
raw
history blame contribute delete
No virus
1.18 kB
import argparse
import gradio as gr
import torch
from PIL import Image
from donut import DonutModel
import argparse
import gradio as gr
import torch
from PIL import Image
from donut import DonutModel
# Command-line options: which Donut task to demo and which checkpoint to load.
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="cord-v2")
parser.add_argument("--pretrained_path", type=str, default="naver-clova-ix/donut-base-finetuned-cord-v2")
# parse_known_args() returns unrecognised arguments in left_argv instead of
# aborting, so extra launcher-supplied flags do not break start-up.
args, left_argv = parser.parse_known_args()

task_name = args.task
if task_name == "docvqa":
    # Plain string (not an f-string): "{user_input}" stays a literal
    # placeholder in the prompt template.
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
else:  # rvlcdip, cord, ...
    # Every other task uses a start token derived from the task name.
    task_prompt = f"<s_{task_name}>"
# Load the checkpoint named by --pretrained_path.
pretrained_model = DonutModel.from_pretrained(args.pretrained_path)

if torch.cuda.is_available():
    # CUDA is available: convert the model to half precision and move it
    # onto the GPU.
    pretrained_model.half()
    device = torch.device("cuda")
    pretrained_model.to(device)
else:
    # No CUDA: only the encoder sub-module is cast to bfloat16; nothing is
    # moved to another device.
    pretrained_model.encoder.to(torch.bfloat16)

# NOTE(review): eval() is applied on both the CUDA and the CPU path; the
# source had lost its indentation, so confirm this placement is intended.
pretrained_model.eval()
def demo_process_vqa(input_img, question):
    """Run document-VQA inference on one image/question pair.

    Args:
        input_img: Image array handed over by the Gradio ``"image"`` input.
        question: Question text handed over by the Gradio ``"text"`` input.

    Returns:
        The first entry of the model's ``"predictions"`` list.
    """
    # Substitute the user's question into the prompt template's placeholder.
    user_prompt = task_prompt.replace("{user_input}", question)
    pil_image = Image.fromarray(input_img)
    return pretrained_model.inference(pil_image, prompt=user_prompt)["predictions"][0]


def demo_process(input_img):
    """Run inference on one image using the fixed task prompt.

    Args:
        input_img: Image array handed over by the Gradio ``"image"`` input.

    Returns:
        The first entry of the model's ``"predictions"`` list.
    """
    pil_image = Image.fromarray(input_img)
    return pretrained_model.inference(pil_image, prompt=task_prompt)["predictions"][0]


# The docvqa task needs an image plus a question; every other task takes only
# an image. Handler and input widgets are chosen to match.
demo = gr.Interface(
    fn=demo_process_vqa if task_name == "docvqa" else demo_process,
    inputs=["image", "text"] if task_name == "docvqa" else "image",
    outputs="json",
    title=f"Donut 🍩 demonstration for `{task_name}` task",
)
demo.launch(debug=True)