""" |
|
This module provides Pyserini's Python translation probability search |
|
interface on MS MARCO dataset. The main entry point is the |
|
``TranslationProbabilitySearcher`` class. |
|
""" |

import json
import math
import os
import pickle
import struct
from multiprocessing.pool import ThreadPool
from typing import Dict

from transformers import AutoTokenizer

from pyserini.pyclass import autoclass
from pyserini.search.lucene import LuceneSearcher
from pyserini.util import download_prebuilt_index, get_cache_home, download_url, download_and_unpack_index
from pyserini.prebuilt_index_info import TF_INDEX_INFO

JQuery = autoclass('org.apache.lucene.search.Query')
JLuceneSearcher = autoclass('io.anserini.search.SimpleSearcher')
JIndexReader = autoclass('io.anserini.index.IndexReaderUtils')
JTerm = autoclass('org.apache.lucene.index.Term')

class LuceneIrstSearcher(object):
    SELF_TRAN = 0.35         # probability mass reserved for a token translating to itself
    MIN_PROB = 0.0025        # translation probabilities at or below this are discarded
    LAMBDA_VALUE = 0.3       # mixing weight of the collection probability in the score
    MIN_COLLECT_PROB = 1e-9  # floor for the collection probability of unseen tokens

    SUPPORTED_INDEXES = ('msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented')

    def __init__(self, index: str, k1: float, b: float, num_threads: int):
        translation_url = 'https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-models/ibm_model_1_bert_tok_20211117.tar.gz'
        translation_directory = os.path.join(get_cache_home(), 'models')
        self.termfreq_dic = self.download_and_load_wp_stats(index)

        self.translation_model = download_and_unpack_index(translation_url, translation_directory)
        self.bm25search = LuceneSearcher.from_prebuilt_index(index)
        self.bm25search.set_bm25(k1, b)

        if index not in self.SUPPORTED_INDEXES:
            raise ValueError(f'We currently only support the following indexes: '
                             f'{", ".join(self.SUPPORTED_INDEXES)}; got {index}.')
        # A prebuilt index unpacks into a directory whose name is the tarball
        # filename with the 'tar.gz' suffix replaced by the index's md5 checksum.
        index_directory = os.path.join(get_cache_home(), 'indexes')
        index_path = os.path.join(index_directory,
                                  TF_INDEX_INFO[index]['filename'][:-6] +
                                  TF_INDEX_INFO[index]['md5'])
        self.object = JLuceneSearcher(index_path)
        self.source_lookup, self.target_lookup, self.tran = self.load_tranprobs_table()
        self.bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.pool = ThreadPool(num_threads)
|
    @classmethod
    def from_prebuilt_index(cls, prebuilt_index_name: str, k1: float = 0.9, b: float = 0.4, num_threads: int = 1):
        """Build a searcher from a pre-built index; download the index if necessary.

        Parameters
        ----------
        prebuilt_index_name : str
            Prebuilt index name.
        k1 : float
            BM25 k1 parameter.
        b : float
            BM25 b parameter.
        num_threads : int
            Number of threads used for scoring.

        Returns
        -------
        LuceneIrstSearcher
            Searcher built from the prebuilt index, or None if the index
            cannot be downloaded.
        """
        print(f'Attempting to initialize pre-built index {prebuilt_index_name}.')
        try:
            # Pre-fetch and validate the index; the constructor resolves it by name.
            download_prebuilt_index(prebuilt_index_name)
        except ValueError as e:
            print(str(e))
            return None

        print(f'Initializing {prebuilt_index_name}...')
        return cls(prebuilt_index_name, k1, b, num_threads)
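    # Hypothetical usage (the default k1/b above follow common Pyserini BM25
    # defaults and may need tuning per collection):
    #   searcher = LuceneIrstSearcher.from_prebuilt_index('msmarco-v1-passage')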

    def download_and_load_wp_stats(self, index: str):
        # Download (if necessary) and load the pickled WordPiece term-frequency
        # dictionary for the collection; it maps each WordPiece token to its
        # collection frequency, with the special key 'TOTAL' holding the total
        # number of tokens in the collection.
        translation_directory = os.path.join(get_cache_home(), 'models')
        os.makedirs(translation_directory, exist_ok=True)

        if index == 'msmarco-v1-passage':
            local_filename = 'bert_wp_term_freq.msmarco-passage.20220411.pickle'
        elif index == 'msmarco-v1-doc':
            local_filename = 'bert_wp_term_freq.msmarco-doc.20220411.pickle'
        elif index == 'msmarco-v1-doc-segmented':
            local_filename = 'bert_wp_term_freq.msmarco-doc-segmented.20220411.pickle'
        else:
            raise ValueError(f'Unsupported index: {index}')

        wp_stats_path = os.path.join(translation_directory, local_filename)
        url = f'https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/data/{local_filename}'

        if os.path.exists(wp_stats_path):
            print(f'{wp_stats_path} already exists, skipping download.')
        else:
            download_url(url, translation_directory, local_filename)
        with open(wp_stats_path, 'rb') as fin:
            termfreq_dic = pickle.load(fin)
        return termfreq_dic
|
    @staticmethod
    def intbits_to_float(b: int):
        # Reinterpret the bits of a 32-bit integer as an IEEE-754 float,
        # mirroring Java's Float.intBitsToFloat.
        s = struct.pack('>l', b)
        return struct.unpack('>f', s)[0]
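    # For example, 0x3f800000 is the IEEE-754 bit pattern for 1.0:
    #   LuceneIrstSearcher.intbits_to_float(0x3f800000)  # -> 1.0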

    def rescale(
        self, source_lookup: Dict[str, int], target_lookup: Dict[str, int],
        tran_lookup: Dict[int, Dict[int, float]],
        target_voc: Dict[int, str], source_voc: Dict[int, str]
    ):
        # Fold the fixed self-translation probability SELF_TRAN into the
        # IBM Model 1 table: scale down every probability for a real
        # (non-NULL) target token, add SELF_TRAN to the identity pair, and
        # insert the identity pair if it is missing.
        for target_id in tran_lookup:
            # target_id 0 is the NULL token, which has no self-translation.
            if target_id > 0:
                adjust_mult = (1 - self.SELF_TRAN)
            else:
                adjust_mult = 1

            for source_id in tran_lookup[target_id].keys():
                tran_prob = tran_lookup[target_id][source_id]
                if source_id > 0:
                    source_word = source_voc[source_id]
                    target_word = target_voc[target_id]
                    tran_prob *= adjust_mult
                    if source_word == target_word:
                        tran_prob += self.SELF_TRAN
                    tran_lookup[target_id][source_id] = tran_prob

            # If the self-translation pair is absent from the row, add it
            # with probability SELF_TRAN.
            target_word = target_voc[target_id]
            if target_word in source_lookup:
                source_id = source_lookup[target_word]
                if source_id not in tran_lookup[target_id]:
                    tran_lookup[target_id][source_id] = self.SELF_TRAN
        return source_lookup, target_lookup, tran_lookup
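    # Worked example with SELF_TRAN = 0.35 (hypothetical tokens and
    # probabilities): if the table row for target token 'dog' is
    # {'dog': 0.40, 'puppy': 0.60}, rescaling yields 0.40 * 0.65 + 0.35 = 0.61
    # for 'dog' and 0.60 * 0.65 = 0.39 for 'puppy'; the row still sums to 1.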

    def load_tranprobs_table(self):
        # Load the IBM Model 1 vocabularies and translation table, then
        # rescale the table to account for self-translation.
        dir_path = self.translation_model
        source_path = os.path.join(dir_path, 'source.vcb')
        source_lookup = {}
        source_voc = {}
        with open(source_path) as f:
            for line in f:
                term_id, term, _freq = line.split(' ')
                source_voc[int(term_id)] = term
                source_lookup[term] = int(term_id)

        target_path = os.path.join(dir_path, 'target.vcb')
        target_lookup = {}
        target_voc = {}
        with open(target_path) as f:
            for line in f:
                term_id, term, _freq = line.split(' ')
                target_voc[int(term_id)] = term
                target_lookup[term] = int(term_id)

        # Each record in the binary table is 12 bytes, big-endian: a 4-byte
        # source id, a 4-byte target id, and a 4-byte float probability.
        tran_path = os.path.join(dir_path, 'output.t1.5.bin')
        tran_lookup = {}
        with open(tran_path, 'rb') as file:
            byte = file.read(4)
            while byte:
                source_id = int.from_bytes(byte, 'big')
                assert source_id == 0 or source_id in source_voc
                byte = file.read(4)
                target_id = int.from_bytes(byte, 'big')
                assert target_id in target_voc
                byte = file.read(4)
                tran_prob = self.intbits_to_float(int.from_bytes(byte, 'big'))
                # Keep only translation probabilities above MIN_PROB.
                if tran_prob > self.MIN_PROB:
                    tran_lookup.setdefault(target_id, {})[source_id] = tran_prob
                byte = file.read(4)
        return self.rescale(
            source_lookup, target_lookup,
            tran_lookup, target_voc, source_voc)
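    # For reference, a 12-byte record in the layout read above could be
    # synthesized for testing with struct (hypothetical ids and probability):
    #   record = struct.pack('>iif', 12, 34, 0.25)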

    def get_ibm_score(self, arguments):
        # Score a single (query, document) pair with the IBM Model 1
        # translation language model, smoothed by collection probabilities.
        (query_text_lst, test_doc, searcher, source_lookup,
         target_lookup, tran, collect_probs, max_sim) = arguments

        raw_doc = searcher.doc_raw(test_doc)
        if raw_doc is None:
            raise ValueError(f'{test_doc} is not found in searcher')
        contents = json.loads(raw_doc)['contents']
        doc_token_lst = self.bert_tokenizer.tokenize(contents.lower(), truncation=True)
        total_query_prob = 0
        doc_size = len(doc_token_lst)
        query_size = len(query_text_lst)
        for querytoken in query_text_lst:
            total_tran_prob = 0
            collect_prob = collect_probs[querytoken]
            max_sim_score = 0
            if querytoken in target_lookup:
                query_word_id = target_lookup[querytoken]
                if query_word_id in tran:
                    target_map = tran[query_word_id]
                    for doctoken in doc_token_lst:
                        tran_prob = 0
                        # Unknown document tokens fall back to id 0 (the NULL token).
                        doc_word_id = source_lookup.get(doctoken, 0)
                        if doc_word_id in target_map:
                            tran_prob = target_map[doc_word_id]
                        max_sim_score = max(tran_prob, max_sim_score)
                        total_tran_prob += tran_prob / doc_size
            if max_sim:
                query_word_prob = math.log(
                    (1 - self.LAMBDA_VALUE) * max_sim_score + self.LAMBDA_VALUE * collect_prob)
            else:
                query_word_prob = math.log(
                    (1 - self.LAMBDA_VALUE) * total_tran_prob + self.LAMBDA_VALUE * collect_prob)
            total_query_prob += query_word_prob
        return total_query_prob / query_size
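    # The value returned above is the mean log-probability
    #   (1/|q|) * sum_{w in q} log((1 - LAMBDA_VALUE) * P_tr(w|d) + LAMBDA_VALUE * P_coll(w))
    # where P_tr(w|d) is the averaged (or, with max_sim, the maximum)
    # translation probability of query token w over the document's tokens.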

    def search(self, query_text, query_field_text, max_sim, bm25_results):
        origin_scores = [bm25_result.score for bm25_result in bm25_results]
        test_docs = [bm25_result.docid for bm25_result in bm25_results]
        if not test_docs:
            print(f'No BM25 results to rescore for query: {query_text}')

        query_field_text_lst = query_field_text.split(' ')
        total_term_freq = self.termfreq_dic['TOTAL']
        # Collection probability of each query token, floored at MIN_COLLECT_PROB.
        collect_probs = {}
        for querytoken in query_field_text_lst:
            if querytoken in self.termfreq_dic:
                collect_probs[querytoken] = max(self.termfreq_dic[querytoken] / total_term_freq, self.MIN_COLLECT_PROB)
            else:
                collect_probs[querytoken] = self.MIN_COLLECT_PROB
        arguments = [(
            query_field_text_lst, test_doc, self.object,
            self.source_lookup, self.target_lookup,
            self.tran, collect_probs, max_sim)
            for test_doc in test_docs]

        rank_scores = self.pool.map(self.get_ibm_score, arguments)
        return test_docs, rank_scores, origin_scores

    def rerank(self, query_text, query_field_text, baseline, max_sim, tf_table):
        # Like search(), but rescores an existing baseline run given as a
        # (docids, scores) pair. The tf_table argument is currently unused.
        test_docs, origin_scores = baseline
        if not test_docs:
            print(f'No baseline results to rescore for query: {query_text}')

        query_field_text_lst = query_field_text.split(' ')
        total_term_freq = self.termfreq_dic['TOTAL']
        collect_probs = {}
        for querytoken in query_field_text_lst:
            if querytoken in self.termfreq_dic:
                collect_probs[querytoken] = max(self.termfreq_dic[querytoken] / total_term_freq, self.MIN_COLLECT_PROB)
            else:
                collect_probs[querytoken] = self.MIN_COLLECT_PROB
        arguments = [(
            query_field_text_lst, test_doc, self.object,
            self.source_lookup, self.target_lookup,
            self.tran, collect_probs, max_sim)
            for test_doc in test_docs]

        rank_scores = self.pool.map(self.get_ibm_score, arguments)
        return test_docs, rank_scores, origin_scores
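    # Hypothetical rerank call (docids and scores are illustrative; tf_table is
    # accepted for interface compatibility and ignored):
    #   baseline = (['7187158', '7187157'], [10.25, 9.87])
    #   docids, irst_scores, bm25_scores = searcher.rerank(
    #       query, tokenized_query, baseline, max_sim=False, tf_table=None)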
|
|