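"""Recall evaluation entry point: load a model, generate predictions for the
recall test set, score the parsed outputs, and save the resulting report."""
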
import argparse
from inference import generate_outputs, load_model, load_tokenizer_and_template
from recall_eval.run_eval import (
    eval_outputs,
    format_inputs,
    make_pred,
    parser_list,
    pprint_format,
    save_result,
)
from utils import load_json


def main(args):
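    """Run the recall evaluation end to end: inference, scoring, and saving results."""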
    # init model
    llm_model = load_model(args.model_path, args.max_model_len, args.gpus_num)
    tokenizer = load_tokenizer_and_template(args.model_path, args.template)
    generate_args = {
        "temperature": args.temperature,
        "max_tokens": args.max_new_tokens,
        "model_type": args.model_type,
    }
    # load test samples, optionally keeping only the first `num` entries
    samples = load_json(args.test_path)
    if args.num is not None:
        samples = samples[: args.num]
    # run inference: format prompts, generate responses, and parse predictions
    msgs = format_inputs(samples)
    resp = generate_outputs(msgs, llm_model, tokenizer, generate_args)
    pred = parser_list(resp)
    # score the parsed predictions against the ground-truth samples
    report = eval_outputs(pred, samples)
    preds = make_pred(samples, pred)
    # print the report and save it together with the per-sample predictions
    pprint_format(report)
    save_result(preds, report, args.test_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="eval recall")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model",
    )
    parser.add_argument(
        "--model_type",
        choices=["base_model", "chat_model"],
        default="chat_model",
        help="Base model or chat model",
    )
    parser.add_argument(
        "--gpus_num", type=int, default=1, help="Number of GPUs to use"
    )
    parser.add_argument(
        "--temperature", type=float, default=0, help="Sampling temperature"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=1024,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--max_model_len", type=int, default=8192, help="Maximum model context length"
    )
    parser.add_argument(
        "--template",
        type=str,
        choices=[None, "llama3", "baichuan", "chatglm"],
        default=None,
        help="The chat template; must be specified if not present in the config file",
    )
    parser.add_argument(
        "--test_path",
        type=str,
        default="table_related_benchmarks/evalset/retrieval_test/recall_set.json",
        help="Path to the test file",
    )
    parser.add_argument(
        "--num", type=int, default=None, help="Number of samples to evaluate (default: all)"
    )
    args = parser.parse_args()
    main(args)
"""
CUDA_VISIBLE_DEVICES=6 python table_related_benchmarks/run_recall_eval.py --model_path /data4/sft_output/qwen2-base-0817
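Illustrative variant (using the flags defined above): evaluate only the first 100 samples.
CUDA_VISIBLE_DEVICES=6 python table_related_benchmarks/run_recall_eval.py --model_path /data4/sft_output/qwen2-base-0817 --num 100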
"""