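"""Generate the reproduction matrix for MIRACL: run the retrieval conditions
defined in miracl.yaml, verify scores against the expected values recorded
there, and optionally render the results as an HTML report."""
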
import argparse
import math
import os
import subprocess
import sys
import time
from collections import defaultdict, OrderedDict
from string import Template

import pkg_resources
import yaml

from ._base import run_eval_and_return_metric, ok_str, okish_str, fail_str

languages = [
    ['ar', 'arabic'],
    ['bn', 'bengali'],
    ['en', 'english'],
    ['es', 'spanish'],
    ['fa', 'persian'],
    ['fi', 'finnish'],
    ['fr', 'french'],
    ['hi', 'hindi'],
    ['id', 'indonesian'],
    ['ja', 'japanese'],
    ['ko', 'korean'],
    ['ru', 'russian'],
    ['sw', 'swahili'],
    ['te', 'telugu'],
    ['th', 'thai'],
    ['zh', 'chinese'],
    ['de', 'german'],
    ['yo', 'yoruba']
]

html_display = OrderedDict()
html_display['bm25'] = 'BM25'
html_display['mdpr-tied-pft-msmarco'] = 'mDPR (tied encoders), pre-FT w/ MS MARCO'
html_display['mdpr-tied-pft-msmarco-ft-all'] = 'mDPR (tied encoders), pre-FT w/ MS MARCO then FT w/ all Mr. TyDi'
html_display['bm25-mdpr-tied-pft-msmarco-hybrid'] = 'Hybrid of `bm25` and `mdpr-tied-pft-msmarco`'
html_display['mdpr-tied-pft-msmarco-ft-miracl'] = 'mDPR (tied encoders), pre-FT w/ MS MARCO then in-lang FT w/ MIRACL'
html_display['mcontriever-tied-pft-msmarco'] = 'mContriever (tied encoders), pre-FT w/ MS MARCO'

models = list(html_display)

trec_eval_metric_definitions = {
    'nDCG@10': '-c -M 100 -m ndcg_cut.10',
    'R@100': '-c -m recall.100',
}


def format_run_command(raw):
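    """Format a run command for display, one long option per line."""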
    return raw.replace('--lang', '\\\n --lang') \
        .replace('--encoder', '\\\n --encoder') \
        .replace('--topics', '\\\n --topics') \
        .replace('--index', '\\\n --index') \
        .replace('--output ', '\\\n --output ') \
        .replace('--runs ', '\\\n --runs ') \
        .replace('--batch ', '\\\n --batch ') \
        .replace('--threads 12', '--threads 12 \\\n ')


def format_eval_command(raw):
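    """Format a trec_eval command for display, breaking it across lines."""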
    return raw.replace('-c ', '\\\n -c ') \
        .replace(raw.split()[-1], f'\\\n {raw.split()[-1]}')


def read_file(f):
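    """Read a text file and return its contents."""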
    with open(f, 'r') as fin:
        return fin.read()


def list_conditions():
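    """Print the names of all available conditions and languages."""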
    print('Conditions:\n-----------')
    for condition in html_display:
        print(condition)
    print('\nLanguages\n---------')
    for language in languages:
        print(language[0])


def generate_table_rows(table, row_template, commands, eval_commands, table_id, split, metric):
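    """Render one HTML table row per model, filling the row template with
    per-language scores, run commands, and eval commands. Languages with a
    zero score are displayed as '--' and excluded from the row average."""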
    row_cnt = 1
    html_rows = []

    for model in models:
        s = Template(row_template)

        keys = {}
        used_langs = 0
        for lang in languages:
            keys[lang[0]] = f'{model}.{lang[0]}'
            used_langs += 1 if table[keys[lang[0]]][split][metric] != 0 else 0

        # Average over languages with non-zero scores only; zero scores are
        # rendered as '--' below.
        score_sum = sum(table[keys[lang[0]]][split][metric] for lang in languages)
        avg = score_sum / used_langs

        # The row template's command placeholders are positional (cmd1..cmd18,
        # eval_cmd1..eval_cmd18) and follow the order of the languages list;
        # score placeholders are named by language code.
        substitutions = {'table_cnt': table_id,
                         'row_cnt': row_cnt,
                         'model': html_display[model],
                         'avg': f'{avg:.3f}'}
        for i, lang in enumerate(languages):
            code = lang[0]
            substitutions[code] = f'{table[keys[code]][split][metric]:.3f}'
            substitutions[f'cmd{i + 1}'] = commands[keys[code]]
            substitutions[f'eval_cmd{i + 1}'] = eval_commands[keys[code]][metric]
        s = s.substitute(**substitutions)

        s = s.replace("0.000", "--")
        html_rows.append(s)
        row_cnt += 1

    return html_rows


def print_results(table, metric, split):
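    """Print a plain-text score matrix: one row per model, one column per language."""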
    print(f'Metric = {metric}, Split = {split}')
    print(' ' * 35, end='')
    for lang in languages:
        print(f'{lang[0]:3} ', end='')
    print('')
    for model in models:
        print(f'{model:33}', end='')
        for lang in languages:
            key = f'{model}.{lang[0]}'
            print(f'{table[key][split][metric]:7.3f}', end='')
        print('')
    print('')


def extract_topic_fn_from_cmd(cmd):
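    """Extract the topics filename from a run command's --topics option."""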
    cmd = cmd.split()
    topic_idx = cmd.index('--topics')
    return cmd[topic_idx + 1]


def generate_report(args):
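    """Generate the HTML report. No retrieval runs are executed here: scores,
    run commands, and eval commands are all taken from (or derived from)
    miracl.yaml and substituted into the HTML templates."""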
    table = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.0)))
    commands = defaultdict(lambda: '')
    eval_commands = defaultdict(lambda: defaultdict(lambda: ''))

    html_template = read_file(pkg_resources.resource_filename(__name__, 'miracl_html.template'))
    table_template = read_file(pkg_resources.resource_filename(__name__, 'miracl_html_table.template'))
    row_template = read_file(pkg_resources.resource_filename(__name__, 'miracl_html_table_row.template'))

    with open(pkg_resources.resource_filename(__name__, 'miracl.yaml')) as f:
        yaml_data = yaml.safe_load(f)
        for condition in yaml_data['conditions']:
            name = condition['name']
            eval_key = condition['eval_key']
            cmd_template = condition['command']
            cmd_lst = cmd_template.split()
            lang = name.split('.')[-1]
            is_hybrid_run = 'hybrid' in name

            for splits in condition['splits']:
                split = splits['split']
                if is_hybrid_run:
                    hits = int(cmd_lst[cmd_lst.index('--k') + 1])
                else:
                    hits = int(cmd_lst[cmd_lst.index('--hits') + 1])

                # Name runfiles the way run_conditions does, so that the run
                # and eval commands in the report reference the same files.
                runfile = os.path.join(args.directory, f'run.miracl.{name}.{split}.top{hits}.txt')
                if is_hybrid_run:
                    bm25_output = os.path.join(args.directory,
                                               f'run.miracl.bm25.{lang}.{split}.top{hits}.txt')
                    mdpr_output = os.path.join(args.directory,
                                               f'run.miracl.mdpr-tied-pft-msmarco.{lang}.{split}.top{hits}.txt')
                    expected_args = dict(output=runfile, bm25_output=bm25_output, mdpr_output=mdpr_output)
                else:
                    expected_args = dict(split=split, output=runfile)

                if not all([f"${k}" in cmd_template or f"${{{k}}}" in cmd_template for k in expected_args]):
                    raise ValueError(f"Not all arguments {list(expected_args)} detected from inputs: {cmd_template}.")
                cmd = Template(cmd_template).substitute(**expected_args)
                commands[name] = format_run_command(cmd)

                for expected in splits['scores']:
                    for metric in expected:
                        if str(expected[metric])[-1] == "5":
                            # Add an epsilon to scores ending in 5 so that
                            # formatting to three decimal places rounds up
                            # instead of hitting floating-point
                            # round-half-to-even surprises.
                            expected[metric] += 1e-5
                        table[name][split][metric] = expected[metric]

                        eval_cmd = 'python -m pyserini.eval.trec_eval ' + \
                                   f'{trec_eval_metric_definitions[metric]} {eval_key}-{split} {runfile}'
                        eval_commands[name][metric] = format_eval_command(eval_cmd)

    tables_html = []

    split = 'dev'

    # Table 1: nDCG@10 on the dev split.
    html_rows = generate_table_rows(table, row_template, commands, eval_commands, 1, split, 'nDCG@10')
    all_rows = '\n'.join(html_rows)
    tables_html.append(Template(table_template).substitute(desc=f'nDCG@10, {split} queries', rows=all_rows))

    # Table 2: Recall@100 on the dev split.
    html_rows = generate_table_rows(table, row_template, commands, eval_commands, 2, split, 'R@100')
    all_rows = '\n'.join(html_rows)
    tables_html.append(Template(table_template).substitute(desc=f'Recall@100, {split} queries', rows=all_rows))

    with open(args.output, 'w') as out:
        out.write(Template(html_template).substitute(title='MIRACL', tables=' '.join(tables_html)))


def run_conditions(args):
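    """Run the retrieval conditions selected on the command line, evaluate the
    runfiles with trec_eval, and compare each score against the expected value
    recorded in miracl.yaml."""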
    if args.condition == 'mdpr-tied-pft-msmarco-ft-miracl' and args.language in ['de', 'yo']:
        print('The MIRACL de and yo datasets have no train splits to fine-tune with.')
        return

    start = time.time()

    table = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0.0)))

    with open(pkg_resources.resource_filename(__name__, 'miracl.yaml')) as f:
        yaml_data = yaml.safe_load(f)
        for condition in yaml_data['conditions']:
            name = condition['name']
            encoder = name.split('.')[0]
            lang = name.split('.')[-1]
            if args.all:
                pass
            elif args.condition != encoder:
                continue
            elif args.language and args.language != lang:
                continue
            eval_key = condition['eval_key']
            cmd_template = condition['command']
            cmd_lst = cmd_template.split()

            print(f'condition {name}:')
            is_hybrid_run = 'hybrid' in name

            for splits in condition['splits']:
                split = splits['split']
                if is_hybrid_run:
                    hits = int(cmd_lst[cmd_lst.index('--k') + 1])
                else:
                    hits = int(cmd_lst[cmd_lst.index('--hits') + 1])

                print(f' - split: {split}')

                runfile = os.path.join(args.directory, f'run.miracl.{name}.{split}.top{hits}.txt')
                if is_hybrid_run:
                    bm25_output = os.path.join(args.directory,
                                               f'run.miracl.bm25.{lang}.{split}.top{hits}.txt')
                    mdpr_output = os.path.join(args.directory,
                                               f'run.miracl.mdpr-tied-pft-msmarco.{lang}.{split}.top{hits}.txt')
                    if not os.path.exists(bm25_output):
                        print(f'Missing BM25 file: {bm25_output}')
                        continue
                    if not os.path.exists(mdpr_output):
                        print(f'Missing mDPR file: {mdpr_output}')
                        continue
                    cmd = Template(cmd_template).substitute(split=split, output=runfile, bm25_output=bm25_output,
                                                            mdpr_output=mdpr_output)
                else:
                    cmd = Template(cmd_template).substitute(split=split, output=runfile)

                # Train-split topics are not registered in Pyserini, so point
                # --topics at the local file in tools/topics-and-qrels/ instead.
                if split == 'train':
                    cmd = cmd.replace(f'--topics miracl-v1.0-{lang}-{split}',
                                      f'--topics tools/topics-and-qrels/topics.miracl-v1.0-{lang}-{split}.tsv')

                if args.display_commands:
                    print(f'\n```bash\n{format_run_command(cmd)}\n```\n')

                if not os.path.exists(runfile):
                    if not args.dry_run:
                        rtn = subprocess.run(cmd.split(), capture_output=True)
                        stderr = rtn.stderr.decode()
                        if '--topics' in cmd:
                            topic_fn = extract_topic_fn_from_cmd(cmd)
                            if f'ValueError: Topic {topic_fn} Not Found' in stderr:
                                print(f'Skipping {topic_fn}: file not found.')
                                continue

                for expected in splits['scores']:
                    for metric in expected:
                        if not args.skip_eval:
                            # Likewise, evaluate train splits against the local
                            # qrels file rather than a registered qrels name.
                            if split == 'train':
                                qrels = f'tools/topics-and-qrels/qrels.{eval_key}-train.tsv'
                            else:
                                qrels = f'{eval_key}-{split}'
                            score = float(run_eval_and_return_metric(metric, qrels,
                                                                     trec_eval_metric_definitions[metric], runfile))
                            if math.isclose(score, float(expected[metric])):
                                result_str = ok_str
                            # A few conditions are known to reproduce slightly
                            # off; accept them within a small absolute tolerance.
                            elif (name == 'mdpr-tied-pft-msmarco.hi' and split == 'train'
                                  and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
                                 (name == 'mdpr-tied-pft-msmarco-ft-all.ru'
                                  and split == 'dev' and metric == 'nDCG@10'
                                  and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
                                 (name == 'bm25-mdpr-tied-pft-msmarco-hybrid.te'
                                  and split == 'train' and metric == 'nDCG@10'
                                  and math.isclose(score, float(expected[metric]), abs_tol=2e-4)) or \
                                 (name == 'bm25-mdpr-tied-pft-msmarco-hybrid.zh'
                                  and split == 'dev' and metric == 'nDCG@10'
                                  and math.isclose(score, float(expected[metric]), abs_tol=2e-4)):
                                result_str = okish_str
                            else:
                                result_str = fail_str + f' expected {expected[metric]:.4f}'
                            print(f' {metric:7}: {score:.4f} {result_str}')
                            table[name][split][metric] = score
                        else:
                            table[name][split][metric] = expected[metric]

            print('')

    for metric in ['nDCG@10', 'R@100']:
        for split in ['dev', 'train']:
            print_results(table, metric, split)

    end = time.time()
    print(f'Total elapsed time: {end - start:.0f}s')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Generate regression matrix for MIRACL.')
    parser.add_argument('--condition', type=str, help='Condition to run.', required=False)

    parser.add_argument('--list-conditions', action='store_true', default=False, help='List available conditions.')

    parser.add_argument('--generate-report', action='store_true', default=False, help='Generate report.')
    parser.add_argument('--output', type=str, help='File to store report.', required=False)

    parser.add_argument('--all', action='store_true', default=False, help='Run all conditions and languages.')
    parser.add_argument('--language', type=str, help='Language to run.', required=False)
    parser.add_argument('--directory', type=str, help='Base directory.', default='', required=False)
    parser.add_argument('--dry-run', action='store_true', default=False, help='Print out commands but do not execute.')
    parser.add_argument('--skip-eval', action='store_true', default=False, help='Skip running trec_eval.')
    parser.add_argument('--display-commands', action='store_true', default=False, help='Display run commands.')
    args = parser.parse_args()

    if args.list_conditions:
        list_conditions()
        sys.exit()

    if args.generate_report:
        if not args.output:
            print('Must specify report filename with --output.')
            sys.exit()

        generate_report(args)
        sys.exit()

    if args.all and (args.condition or args.language):
        print('--all already runs every condition and language; do not combine it with --condition or --language.')
        sys.exit()

    run_conditions(args)