"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
"""
import argparse
import json
import os
import re
from pathlib import Path

import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

from donut import DonutModel, JSONParseEvaluator, load_json, save_json


def _build_prompt(task_name, ground_truth):
    """Return the decoder prompt that starts generation for ``task_name``.

    Every task starts with the task token ``<s_{task_name}>``. For DocVQA the
    lower-cased question of the first annotated parse is appended, and the
    prompt is left open at ``<s_answer>`` so the model generates the answer.

    Args:
        task_name: task identifier, e.g. "docvqa", "rvlcdip", "cord-v2".
        ground_truth: decoded ``sample["ground_truth"]`` JSON; for DocVQA it
            must contain ``gt_parses[0]["question"]``.

    Returns:
        str: the prompt string passed to ``DonutModel.inference``.
    """
    task_token = f"<s_{task_name}>"
    if task_name == "docvqa":
        question = ground_truth["gt_parses"][0]["question"].lower()
        return f"{task_token}<s_question>{question}</s_question><s_answer>"
    return task_token


def _score_sample(task_name, output, ground_truth, evaluator):
    """Score one parsed prediction against its ground truth.

    Args:
        task_name: task identifier selecting the scoring rule.
        output: parsed prediction dict returned by ``DonutModel.inference``.
        ground_truth: decoded ``sample["ground_truth"]`` JSON.
        evaluator: ``JSONParseEvaluator`` used for generic parsing tasks.

    Returns:
        tuple: ``(score, gt)`` where ``gt`` is the ground-truth object that is
        recorded alongside the prediction.
    """
    if task_name == "rvlcdip":
        gt = ground_truth["gt_parse"]
        # .get(): a malformed prediction may lack the key; count it as a miss
        # instead of aborting the whole evaluation with a KeyError.
        score = float(output.get("class") == gt["class"])
    elif task_name == "docvqa":
        # Note: we evaluated the model on the official website.
        # In this script, an exact-match based score will be returned instead
        gt = ground_truth["gt_parses"]
        answers = {qa_parse["answer"] for qa_parse in gt}
        score = float(output.get("answer") in answers)
    else:
        gt = ground_truth["gt_parse"]
        score = evaluator.cal_acc(output, gt)
    return score, gt


def test(args):
    """Run Donut inference over a dataset split, print scores, optionally save them.

    Args:
        args: namespace with ``pretrained_model_name_or_path``,
            ``dataset_name_or_path``, ``split``, ``task_name`` and
            ``save_path`` (see the command-line interface below).

    Returns:
        list: one parsed model prediction per dataset sample, in dataset order.
    """
    pretrained_model = DonutModel.from_pretrained(args.pretrained_model_name_or_path)

    if torch.cuda.is_available():
        # Run in half precision on the GPU.
        pretrained_model.half()
        pretrained_model.to("cuda")

    pretrained_model.eval()

    if args.save_path:
        # Create the output directory up front so a bad path fails before the
        # inference loop. Path(...).parent is "." for a bare filename, whereas
        # os.path.dirname returns "" and os.makedirs("") raises.
        Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)

    predictions = []
    ground_truths = []
    accs = []

    evaluator = JSONParseEvaluator()
    dataset = load_dataset(args.dataset_name_or_path, split=args.split)

    for sample in tqdm(dataset, total=len(dataset)):
        # Each sample provides an "image" and a JSON-encoded "ground_truth" string.
        ground_truth = json.loads(sample["ground_truth"])
        output = pretrained_model.inference(
            image=sample["image"],
            prompt=_build_prompt(args.task_name, ground_truth),
        )["predictions"][0]

        score, gt = _score_sample(args.task_name, output, ground_truth, evaluator)
        accs.append(score)
        predictions.append(output)
        ground_truths.append(gt)

    scores = {
        "ted_accuracies": accs,
        "ted_accuracy": np.mean(accs),
        "f1_accuracy": evaluator.cal_f1(predictions, ground_truths),
    }
    print(
        f"Total number of samples: {len(accs)}, Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, F1 accuracy score: {scores['f1_accuracy']}"
    )

    if args.save_path:
        scores["predictions"] = predictions
        scores["ground_truths"] = ground_truths
        save_json(args.save_path, scores)

    return predictions


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--dataset_name_or_path", type=str, required=True)
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--task_name", type=str, default=None)
    parser.add_argument("--save_path", type=str, default=None)
    # Unrecognized command-line arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()

    if args.task_name is None:
        # Default the task name to the last component of the dataset path or
        # hub id (e.g. "org/cord-v2" -> "cord-v2"). normpath drops a trailing
        # separator so "data/cord-v2/" does not yield an empty task name.
        args.task_name = os.path.basename(os.path.normpath(args.dataset_name_or_path))

    predictions = test(args)