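"""Run the recall evaluation: load a model and tokenizer, generate responses for
the recall test set, parse and score the predictions, then print and save the
evaluation report."""
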
import argparse

from inference import generate_outputs, load_model, load_tokenizer_and_template
from recall_eval.run_eval import (
    eval_outputs,
    format_inputs,
    make_pred,
    parser_list,
    pprint_format,
    save_result,
)
from utils import load_json


def main(args):
    # Initialize the model, tokenizer, and generation settings
    llm_model = load_model(args.model_path, args.max_model_len, args.gpus_num)
    tokenizer = load_tokenizer_and_template(args.model_path, args.template)
    generate_args = {
        "temperature": args.temperature,
        "max_tokens": args.max_new_tokens,
        "model_type": args.model_type,
    }

    # Load the test samples, optionally keeping only the first `--num` entries
    samples = load_json(args.test_path)
    if args.num is not None:
        samples = samples[: args.num]

    # Build prompts, run inference, and parse the model responses
    msgs = format_inputs(samples)
    resp = generate_outputs(msgs, llm_model, tokenizer, generate_args)
    pred = parser_list(resp)

    # Score the parsed predictions and assemble per-sample prediction records
    report = eval_outputs(pred, samples)
    preds = make_pred(samples, pred)

    # Print the evaluation report and save the predictions and report
    pprint_format(report)
    save_result(preds, report, args.test_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the recall evaluation")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model",
    )
    parser.add_argument(
        "--model_type",
        choices=["base_model", "chat_model"],
        default="chat_model",
        help="Whether the model is a base model or a chat model",
    )
    parser.add_argument(
        "--gpus_num", type=int, default=1, help="Number of GPUs to use"
    )
    parser.add_argument(
        "--temperature", type=float, default=0, help="Sampling temperature"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=1024,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--max_model_len", type=int, default=8192, help="Max model length"
    )
    parser.add_argument(
        "--template",
        type=str,
        choices=[None, "llama3", "baichuan", "chatglm"],
        default=None,
        help="Chat template; must be specified if the model config does not provide one",
    )
    parser.add_argument(
        "--test_path",
        type=str,
        default="table_related_benchmarks/evalset/retrieval_test/recall_set.json",
        help="Path to the test set JSON file",
    )
    parser.add_argument("--num", type=int, default=None, help="number of lines to eval")
    args = parser.parse_args()
    main(args)


"""
CUDA_VISIBLE_DEVICES=6 python table_related_benchmarks/run_recall_eval.py --model_path /data4/sft_output/qwen2-base-0817
"""