|
import argparse |
|
|
|
from inference import generate_outputs, load_model, load_tokenizer_and_template |
|
from recall_eval.run_eval import (eval_outputs, |
|
format_inputs, |
|
make_pred, |
|
parser_list, |
|
pprint_format, |
|
save_result) |
|
from utils import load_json |
|
|
|
|
|
def main(args):
    """Run the recall evaluation end to end.

    Loads the model and tokenizer, generates outputs for the test set,
    parses and scores them, then pretty-prints and saves the report.

    Args:
        args: parsed CLI namespace (see the argparse setup in ``__main__``).
    """
    # Model / tokenizer setup from CLI options.
    model = load_model(args.model_path, args.max_model_len, args.gpus_num)
    tok = load_tokenizer_and_template(args.model_path, args.template)
    gen_kwargs = {
        "temperature": args.temperature,
        "max_tokens": args.max_new_tokens,
        "model_type": args.model_type,
    }

    # Load the evaluation set; optionally truncate to the first N samples.
    samples = load_json(args.test_path)
    if args.num is not None:
        samples = samples[: args.num]

    # Generate, parse, and score the model outputs.
    prompts = format_inputs(samples)
    responses = generate_outputs(prompts, model, tok, gen_kwargs)
    parsed = parser_list(responses)
    report = eval_outputs(parsed, samples)
    predictions = make_pred(samples, parsed)

    # Report to stdout and persist alongside the test file.
    pprint_format(report)
    save_result(predictions, report, args.test_path)
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser(description="eval recall") |
|
parser.add_argument( |
|
"--model_path", |
|
type=str, |
|
required=True, |
|
default="/home/dev/weights/CodeQwen1.5-7B-Chat", |
|
help="Path to the model", |
|
) |
|
parser.add_argument( |
|
"--model_type", |
|
choices=["base_model", "chat_model"], |
|
default="chat_model", |
|
help="Base model or Chat model", |
|
) |
|
parser.add_argument( |
|
"--gpus_num", type=int, default=1, help="the number of GPUs you want to use." |
|
) |
|
parser.add_argument( |
|
"--temperature", type=float, default=0, help="Temperature setting" |
|
) |
|
parser.add_argument( |
|
"--max_new_tokens", |
|
type=int, |
|
default=1024, |
|
help="Maximum number of output new tokens", |
|
) |
|
parser.add_argument( |
|
"--max_model_len", type=int, default=8192, help="Max model length" |
|
) |
|
parser.add_argument( |
|
"--template", |
|
type=str, |
|
choices=[None, "llama3", "baichuan", "chatglm"], |
|
default=None, |
|
help="The template must be specified if not present in the config file", |
|
) |
|
parser.add_argument( |
|
"--test_path", |
|
type=str, |
|
default="table_related_benchmarks/evalset/retrieval_test/recall_set.json", |
|
help="Test File Path", |
|
) |
|
parser.add_argument("--num", type=int, default=None, help="number of lines to eval") |
|
args = parser.parse_args() |
|
main(args) |
|
|
|
|
|
""" |
|
CUDA_VISIBLE_DEVICES=6 python table_related_benchmarks/run_recall_eval.py --model_path /data4/sft_output/qwen2-base-0817 |
|
""" |