"""Gradio web demo for a fine-tuned Donut 🍩 model.

Run as a script, e.g.::

    python app.py --task cord-v2 --pretrained_path naver-clova-ix/donut-base-finetuned-cord-v2

With ``--task docvqa`` the interface takes an image plus a free-text question;
for every other task (cord-v2, rvlcdip, ...) it takes an image only. The model
output is shown as JSON.
"""
import argparse

import gradio as gr
import torch
from PIL import Image

from donut import DonutModel

DOCVQA_TASK = "docvqa"
# Placeholder in the DocVQA prompt template that is replaced by the user's question.
USER_INPUT_PLACEHOLDER = "{user_input}"


def build_task_prompt(task_name: str) -> str:
    """Return the decoder prompt for *task_name*.

    DocVQA needs a question slot (``{user_input}``) filled in per request;
    every other task is started with its own ``<s_{task_name}>`` token.
    """
    if task_name == DOCVQA_TASK:
        return "<s_docvqa><s_question>" + USER_INPUT_PLACEHOLDER + "</s_question><s_answer>"
    return f"<s_{task_name}>"


def _parse_args() -> argparse.Namespace:
    """Parse ``--task`` and ``--pretrained_path``; unknown CLI args are ignored."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="cord-v2")
    parser.add_argument(
        "--pretrained_path",
        type=str,
        default="naver-clova-ix/donut-base-finetuned-cord-v2",
    )
    # parse_known_args so extra flags (e.g. ones meant for other tooling) don't abort startup.
    args, _unknown = parser.parse_known_args()
    return args


def _load_model(pretrained_path: str) -> DonutModel:
    """Load a Donut checkpoint and move/cast it for the available hardware."""
    model = DonutModel.from_pretrained(pretrained_path)
    if torch.cuda.is_available():
        # fp16 on GPU to reduce memory use and latency.
        model.half()
        model.to(torch.device("cuda"))
    else:
        # NOTE(review): only the encoder is cast to bfloat16 on CPU; confirm this
        # matches how DonutModel.inference handles input dtypes on CPU.
        model.encoder.to(torch.bfloat16)
    model.eval()
    return model


def _to_pil(input_img) -> Image.Image:
    """Convert Gradio's image payload (a numpy array by default) to a PIL image.

    Raises:
        gr.Error: if no image was submitted, so the user sees a message instead of a traceback.
    """
    if input_img is None:
        raise gr.Error("Please upload an image.")
    return Image.fromarray(input_img)


def _predict(model: DonutModel, image: Image.Image, prompt: str):
    """Run Donut inference and return the first (only) prediction."""
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        return model.inference(image=image, prompt=prompt)["predictions"][0]


def main() -> None:
    """Parse CLI args, load the model, and launch the Gradio interface."""
    args = _parse_args()
    task_name = args.task
    task_prompt = build_task_prompt(task_name)
    pretrained_model = _load_model(args.pretrained_path)

    def demo_process(input_img):
        """Handle an image-only request (cord, rvlcdip, ...)."""
        return _predict(pretrained_model, _to_pil(input_img), task_prompt)

    def demo_process_vqa(input_img, question):
        """Handle a DocVQA request: image plus free-text question."""
        user_prompt = task_prompt.replace(USER_INPUT_PLACEHOLDER, question or "")
        return _predict(pretrained_model, _to_pil(input_img), user_prompt)

    is_vqa = task_name == DOCVQA_TASK
    demo = gr.Interface(
        fn=demo_process_vqa if is_vqa else demo_process,
        inputs=["image", "text"] if is_vqa else "image",
        outputs="json",
        title=f"Donut 🍩 demonstration for `{task_name}` task",
    )
    demo.launch(debug=True)


if __name__ == "__main__":
    main()