""" |
|
This module provides Pyserini's Python search interface to Anserini. The main entry point is the ``LuceneSearcher``
class, which wraps the Java ``SimpleSearcher`` class in Anserini.
|
""" |
|
|
|
import logging |
|
from typing import Dict, List, Optional, Union |
|
|
|
from pyserini.fusion import FusionMethod, reciprocal_rank_fusion |
|
from pyserini.index import Document, IndexReader |
|
from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap |
|
from pyserini.search import JQuery, JQueryGenerator |
|
from pyserini.trectools import TrecRun |
|
from pyserini.util import download_prebuilt_index, get_sparse_indexes_info |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
JLuceneSearcher = autoclass('io.anserini.search.SimpleSearcher') |
|
JLuceneSearcherResult = autoclass('io.anserini.search.SimpleSearcher$Result') |
|
|
|
|
|
class LuceneSearcher: |
|
"""Wrapper class for ``SimpleSearcher`` in Anserini. |
|
|
|
Parameters |
|
---------- |
|
index_dir : str |
|
Path to Lucene index directory. |
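    prebuilt_index_name : str, optional
        Name of the prebuilt index, if the searcher was initialized from one.

    Examples
    --------
    A minimal sketch, assuming a Lucene index exists at the (illustrative) path ``indexes/sample``:

    >>> searcher = LuceneSearcher('indexes/sample')
    >>> hits = searcher.search('information retrieval', k=10)
    >>> for hit in hits:
    ...     print(f'{hit.docid} {hit.score:.5f}')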
|
""" |
|
|
|
def __init__(self, index_dir: str, prebuilt_index_name=None): |
|
self.index_dir = index_dir |
|
self.object = JLuceneSearcher(index_dir) |
|
self.num_docs = self.object.get_total_num_docs() |
|
|
|
self.prebuilt_index_name = prebuilt_index_name |
|
|
|
@classmethod |
|
def from_prebuilt_index(cls, prebuilt_index_name: str, verbose=False): |
|
"""Build a searcher from a pre-built index; download the index if necessary. |
|
|
|
Parameters |
|
---------- |
|
prebuilt_index_name : str |
|
Prebuilt index name. |
|
verbose : bool |
|
Print status information. |
|
|
|
Returns |
|
------- |
|
LuceneSearcher |
|
Searcher built from the prebuilt index. |
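
        Examples
        --------
        A minimal sketch; ``msmarco-v1-passage`` is one of the names reported by ``list_prebuilt_indexes()``:

        >>> searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage')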
|
""" |
|
if verbose: |
|
print(f'Attempting to initialize pre-built index {prebuilt_index_name}.') |
|
|
|
try: |
|
index_dir = download_prebuilt_index(prebuilt_index_name, verbose=verbose) |
|
except ValueError as e: |
|
print(str(e)) |
|
            return None

        index_reader = IndexReader(index_dir)
|
|
|
index_reader.validate(prebuilt_index_name, verbose=verbose) |
|
|
|
if verbose: |
|
print(f'Initializing {prebuilt_index_name}...') |
|
|
|
return cls(index_dir, prebuilt_index_name=prebuilt_index_name) |
|
|
|
@staticmethod |
|
def list_prebuilt_indexes(): |
|
"""Display information about available prebuilt indexes.""" |
|
get_sparse_indexes_info() |
|
|
|
def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None, |
|
fields=dict(), strip_segment_id=False, remove_dups=False) -> List[JLuceneSearcherResult]: |
|
"""Search the collection. |
|
|
|
Parameters |
|
---------- |
|
q : Union[str, JQuery] |
|
            Query string or a ``JQuery`` object.
|
k : int |
|
Number of hits to return. |
|
query_generator : JQueryGenerator |
|
            Generator to build queries. Defaults to ``None``, which uses Anserini's default.
|
fields : dict |
|
Optional map of fields to search with associated boosts. |
|
strip_segment_id : bool |
|
            Strip the ``.XXXXX`` suffix used to denote different segments of a document.
|
remove_dups : bool |
|
Remove duplicate docids when writing final run output. |
|
|
|
Returns |
|
------- |
|
List[JLuceneSearcherResult] |
|
List of search results. |
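
        Examples
        --------
        A minimal sketch of a plain query and a field-boosted query; the field names are illustrative and
        must exist in the index:

        >>> hits = searcher.search('what is a lobster roll', k=10)
        >>> hits = searcher.search('what is a lobster roll', k=10, fields={'title': 2.0, 'contents': 1.0})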
|
""" |
|
|
|
jfields = JHashMap() |
|
for (field, boost) in fields.items(): |
|
jfields.put(field, JFloat(boost)) |
|
|
|
hits = None |
|
if query_generator: |
|
if not fields: |
|
hits = self.object.search(query_generator, q, k) |
|
else: |
|
                hits = self.object.search_fields(query_generator, q, jfields, k)
|
        elif isinstance(q, JQuery):
            # RM3 requires a query string to estimate its feedback model; a raw Lucene query
            # does not provide one, so we reject the combination explicitly.
            if self.is_using_rm3():
|
raise NotImplementedError('RM3 incompatible with search using a Lucene query.') |
|
if fields: |
|
raise NotImplementedError('Cannot specify fields to search when using a Lucene query.') |
|
hits = self.object.search(q, k) |
|
else: |
|
if not fields: |
|
hits = self.object.search(q, k) |
|
else: |
|
hits = self.object.search_fields(q, jfields, k) |
|
|
|
        docids = set()
        filtered_hits = []

        for hit in hits:
            if strip_segment_id:
                hit.docid = hit.docid.split('.')[0]

            if remove_dups:
                if hit.docid in docids:
                    continue
                docids.add(hit.docid)

            filtered_hits.append(hit)

        return filtered_hits
|
|
|
def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1, |
|
                     query_generator: JQueryGenerator = None, fields=dict()) -> Dict[str, List[JLuceneSearcherResult]]:
|
"""Search the collection concurrently for multiple queries, using multiple threads. |
|
|
|
Parameters |
|
---------- |
|
queries : List[str] |
|
List of query strings. |
|
qids : List[str] |
|
List of corresponding query ids. |
|
k : int |
|
Number of hits to return. |
|
threads : int |
|
Maximum number of threads to use. |
|
query_generator : JQueryGenerator |
|
            Generator to build queries. Defaults to ``None``, which uses Anserini's default.
|
fields : dict |
|
Optional map of fields to search with associated boosts. |
|
|
|
Returns |
|
------- |
|
Dict[str, List[JLuceneSearcherResult]] |
|
Dictionary holding the search results, with the query ids as keys and the corresponding lists of search |
|
results as the values. |
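
        Examples
        --------
        A minimal sketch; the query ids are arbitrary strings chosen by the caller:

        >>> results = searcher.batch_search(['quick brown fox', 'lazy dog'], ['q1', 'q2'], k=10, threads=4)
        >>> hits = results['q1']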
|
""" |
|
query_strings = JArrayList() |
|
qid_strings = JArrayList() |
|
for query in queries: |
|
query_strings.add(query) |
|
|
|
for qid in qids: |
|
qid_strings.add(qid) |
|
|
|
jfields = JHashMap() |
|
for (field, boost) in fields.items(): |
|
jfields.put(field, JFloat(boost)) |
|
|
|
if query_generator: |
|
if not fields: |
|
results = self.object.batch_search(query_generator, query_strings, qid_strings, int(k), int(threads)) |
|
else: |
|
results = self.object.batch_search_fields(query_generator, query_strings, qid_strings, int(k), int(threads), jfields) |
|
else: |
|
if not fields: |
|
results = self.object.batch_search(query_strings, qid_strings, int(k), int(threads)) |
|
else: |
|
results = self.object.batch_search_fields(query_strings, qid_strings, int(k), int(threads), jfields) |
|
return {r.getKey(): r.getValue() for r in results.entrySet().toArray()} |
|
|
|
    def get_feedback_terms(self, q: str) -> Optional[Dict[str, float]]:
|
"""Returns feedback terms and their weights. |
|
|
|
Parameters |
|
---------- |
|
q : str |
|
            Query string.
|
|
|
Returns |
|
------- |
|
        Optional[Dict[str, float]]
            Feedback terms and their weights, or ``None`` if no feedback terms are available.
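
        Examples
        --------
        A minimal sketch, assuming RM3 has first been enabled on a compatible index:

        >>> searcher.set_rm3()
        >>> terms = searcher.get_feedback_terms('black bear attacks')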
|
""" |
|
|
|
terms_map = self.object.get_feedback_terms(q) |
|
if terms_map: |
|
return {r.getKey(): r.getValue() for r in terms_map.entrySet().toArray()} |
|
else: |
|
return None |
|
|
|
def set_analyzer(self, analyzer): |
|
"""Set the Java ``Analyzer`` to use. |
|
|
|
Parameters |
|
---------- |
|
analyzer : JAnalyzer |
|
Java ``Analyzer`` object. |
|
""" |
|
self.object.set_analyzer(analyzer) |
|
|
|
def set_language(self, language): |
|
"""Set language of LuceneSearcher""" |
|
self.object.set_language(language) |
|
|
|
def set_rm3(self, fb_terms=10, fb_docs=10, original_query_weight=float(0.5), debug=False, filter_terms=True): |
|
"""Configure RM3 pseudo-relevance feedback. |
|
|
|
Parameters |
|
---------- |
|
fb_terms : int |
|
RM3 parameter for number of expansion terms. |
|
fb_docs : int |
|
RM3 parameter for number of expansion documents. |
|
original_query_weight : float |
|
RM3 parameter for weight to assign to the original query. |
|
debug : bool |
|
Print the original and expanded queries as debug output. |
|
filter_terms: bool |
|
Whether to remove non-English terms. |
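
        Examples
        --------
        A minimal sketch using explicit values for the default parameters, assuming the index stores
        document vectors (or is one of the supported prebuilt indexes):

        >>> searcher.set_rm3(fb_terms=10, fb_docs=10, original_query_weight=0.5)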
|
""" |
|
if self.object.reader.getTermVectors(0): |
|
self.object.set_rm3(None, fb_terms, fb_docs, original_query_weight, debug, filter_terms) |
|
elif self.prebuilt_index_name in ['msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented']: |
|
self.object.set_rm3('JsonCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms) |
|
elif self.prebuilt_index_name in ['msmarco-v2-passage', 'msmarco-v2-passage-augmented']: |
|
self.object.set_rm3('MsMarcoV2PassageCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms) |
|
elif self.prebuilt_index_name in ['msmarco-v2-doc', 'msmarco-v2-doc-segmented']: |
|
self.object.set_rm3('MsMarcoV2DocCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms) |
|
else: |
|
raise TypeError("RM3 is not supported for indexes without document vectors.") |
|
|
|
def unset_rm3(self): |
|
"""Disable RM3 pseudo-relevance feedback.""" |
|
self.object.unset_rm3() |
|
|
|
def is_using_rm3(self) -> bool: |
|
"""Check if RM3 pseudo-relevance feedback is being performed.""" |
|
return self.object.use_rm3() |
|
|
|
def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, bottom_fb_docs=10, |
|
alpha=1, beta=0.75, gamma=0, debug=False, use_negative=False): |
|
"""Configure Rocchio pseudo-relevance feedback. |
|
|
|
Parameters |
|
---------- |
|
top_fb_terms : int |
|
Rocchio parameter for number of relevant expansion terms. |
|
top_fb_docs : int |
|
Rocchio parameter for number of relevant expansion documents. |
|
bottom_fb_terms : int |
|
Rocchio parameter for number of non-relevant expansion terms. |
|
bottom_fb_docs : int |
|
Rocchio parameter for number of non-relevant expansion documents. |
|
alpha : float |
|
Rocchio parameter for weight to assign to the original query. |
|
beta: float |
|
Rocchio parameter for weight to assign to the relevant document vector. |
|
gamma: float |
|
            Rocchio parameter for weight to assign to the non-relevant document vector.
|
debug : bool |
|
Print the original and expanded queries as debug output. |
|
use_negative : bool |
|
Rocchio parameter to use negative labels. |
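
        Examples
        --------
        A minimal sketch using explicit values for the default weights, assuming the index stores
        document vectors (or is one of the supported prebuilt indexes):

        >>> searcher.set_rocchio(alpha=1.0, beta=0.75, gamma=0.0)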
|
""" |
|
if self.object.reader.getTermVectors(0): |
|
self.object.set_rocchio(None, top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs, |
|
alpha, beta, gamma, debug, use_negative) |
|
elif self.prebuilt_index_name in ['msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented']: |
|
self.object.set_rocchio('JsonCollection', top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs, |
|
                                    alpha, beta, gamma, debug, use_negative)
        elif self.prebuilt_index_name in ['msmarco-v2-passage', 'msmarco-v2-passage-augmented']:
            self.object.set_rocchio('MsMarcoV2PassageCollection', top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
                                    alpha, beta, gamma, debug, use_negative)
        elif self.prebuilt_index_name in ['msmarco-v2-doc', 'msmarco-v2-doc-segmented']:
            self.object.set_rocchio('MsMarcoV2DocCollection', top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
                                    alpha, beta, gamma, debug, use_negative)
        else:
|
raise TypeError("Rocchio is not supported for indexes without document vectors.") |
|
|
|
def unset_rocchio(self): |
|
"""Disable Rocchio pseudo-relevance feedback.""" |
|
self.object.unset_rocchio() |
|
|
|
def is_using_rocchio(self) -> bool: |
|
"""Check if Rocchio pseudo-relevance feedback is being performed.""" |
|
return self.object.use_rocchio() |
|
|
|
def set_qld(self, mu=float(1000)): |
|
"""Configure query likelihood with Dirichlet smoothing as the scoring function. |
|
|
|
Parameters |
|
---------- |
|
mu : float |
|
Dirichlet smoothing parameter mu. |
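
        Examples
        --------
        A minimal sketch using the default smoothing parameter:

        >>> searcher.set_qld(mu=1000)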
|
""" |
|
self.object.set_qld(float(mu)) |
|
|
|
def set_bm25(self, k1=float(0.9), b=float(0.4)): |
|
"""Configure BM25 as the scoring function. |
|
|
|
Parameters |
|
---------- |
|
k1 : float |
|
BM25 k1 parameter. |
|
b : float |
|
BM25 b parameter. |
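
        Examples
        --------
        A minimal sketch using the default parameter values:

        >>> searcher.set_bm25(k1=0.9, b=0.4)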
|
""" |
|
self.object.set_bm25(float(k1), float(b)) |
|
|
|
def get_similarity(self): |
|
"""Return the Lucene ``Similarity`` used as the scoring function.""" |
|
return self.object.get_similarity() |
|
|
|
def doc(self, docid: Union[str, int]) -> Optional[Document]: |
|
"""Return the :class:`Document` corresponding to ``docid``. The ``docid`` is overloaded: if it is of type |
|
``str``, it is treated as an external collection ``docid``; if it is of type ``int``, it is treated as an |
|
internal Lucene ``docid``. Method returns ``None`` if the ``docid`` does not exist in the index. |
|
|
|
Parameters |
|
---------- |
|
docid : Union[str, int] |
|
Overloaded ``docid``: either an external collection ``docid`` (``str``) or an internal Lucene ``docid`` |
|
(``int``). |
|
|
|
Returns |
|
------- |
|
Document |
|
:class:`Document` corresponding to the ``docid``. |
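
        Examples
        --------
        A minimal sketch of both lookup forms; the external ``docid`` is illustrative:

        >>> doc = searcher.doc('doc1')  # external collection docid
        >>> doc = searcher.doc(0)       # internal Lucene docid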
|
""" |
|
lucene_document = self.object.doc(docid) |
|
if lucene_document is None: |
|
return None |
|
return Document(lucene_document) |
|
|
|
def batch_doc(self, docids: List[str], threads: int) -> Dict[str, Document]: |
|
"""Concurrently fetching documents for multiple document ids. |
|
Return dictionary that maps ``docid`` to :class:`Document`. Returned dictionary does not |
|
contain ``docid`` if a corresponding :class:`Document` does not exist in the index. |
|
|
|
Parameters |
|
---------- |
|
docids : List[str] |
|
            List of external collection ``docid``s.
|
threads : int |
|
Maximum number of threads to use. |
|
|
|
Returns |
|
------- |
|
Dict[str, Document] |
|
Dictionary mapping the ``docid`` to the corresponding :class:`Document`. |
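
        Examples
        --------
        A minimal sketch; the docids are illustrative:

        >>> docs = searcher.batch_doc(['doc1', 'doc2'], threads=4)
        >>> contents = docs['doc1'].contents()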
|
""" |
|
docid_strings = JArrayList() |
|
for docid in docids: |
|
docid_strings.add(docid) |
|
|
|
results = self.object.batch_get_docs(docid_strings, threads) |
|
batch_document = {r.getKey(): Document(r.getValue()) |
|
for r in results.entrySet().toArray()} |
|
return batch_document |
|
|
|
def doc_by_field(self, field: str, q: str) -> Optional[Document]: |
|
"""Return the :class:`Document` based on a ``field`` with ``id``. For example, this method can be used to fetch |
|
document based on alternative primary keys that have been indexed, such as an article's DOI. Method returns |
|
``None`` if no such document exists. |
|
|
|
Parameters |
|
---------- |
|
field : str |
|
Field to look up. |
|
q : str |
|
Unique id of document. |
|
|
|
Returns |
|
------- |
|
Document |
|
            :class:`Document` whose ``field`` matches ``q``.
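
        Examples
        --------
        A minimal sketch, assuming a ``doi`` field was indexed (both field and value are illustrative):

        >>> doc = searcher.doc_by_field('doi', '10.1234/example.0001')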
|
""" |
|
lucene_document = self.object.doc_by_field(field, q) |
|
if lucene_document is None: |
|
return None |
|
return Document(lucene_document) |
|
|
|
def close(self): |
|
"""Close the searcher.""" |
|
self.object.close() |
|
|
|
|
|
class LuceneSimilarities: |
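    """Factory for commonly used Lucene ``Similarity`` implementations."""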
|
@staticmethod |
|
def bm25(k1=0.9, b=0.4): |
|
return autoclass('org.apache.lucene.search.similarities.BM25Similarity')(k1, b) |
|
|
|
@staticmethod |
|
def qld(mu=1000): |
|
return autoclass('org.apache.lucene.search.similarities.LMDirichletSimilarity')(mu) |
|
|
|
|
|
class LuceneFusionSearcher: |
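    """Searcher that fuses results from multiple Lucene indexes into a single ranked list.

    A minimal usage sketch; the index paths are illustrative:

    >>> searcher = LuceneFusionSearcher(['indexes/a', 'indexes/b'], method=FusionMethod.RRF)
    >>> hits = searcher.search('information retrieval', k=10)
    """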
|
def __init__(self, index_dirs: List[str], method: FusionMethod): |
|
self.method = method |
|
self.searchers = [LuceneSearcher(index_dir) for index_dir in index_dirs] |
|
|
|
def get_searchers(self) -> List[LuceneSearcher]: |
|
return self.searchers |
|
|
|
def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None, strip_segment_id=False, remove_dups=False) -> List[JLuceneSearcherResult]: |
|
trec_runs, docid_to_search_result = list(), dict() |
|
|
|
for searcher in self.searchers: |
|
docid_score_pair = list() |
|
hits = searcher.search(q, k=k, query_generator=query_generator, |
|
strip_segment_id=strip_segment_id, remove_dups=remove_dups) |
|
|
|
for hit in hits: |
|
docid_to_search_result[hit.docid] = hit |
|
docid_score_pair.append((hit.docid, hit.score)) |
|
|
|
run = TrecRun.from_search_results(docid_score_pair) |
|
trec_runs.append(run) |
|
|
|
if self.method == FusionMethod.RRF: |
|
fused_run = reciprocal_rank_fusion(trec_runs, rrf_k=60, depth=1000, k=1000) |
|
else: |
|
raise NotImplementedError() |
|
|
|
return self.convert_to_search_result(fused_run, docid_to_search_result) |
|
|
|
@staticmethod |
|
def convert_to_search_result(run: TrecRun, docid_to_search_result: Dict[str, JLuceneSearcherResult]) -> List[JLuceneSearcherResult]: |
|
search_results = [] |
|
|
|
for _, _, docid, _, score, _ in run.to_numpy(): |
|
search_result = docid_to_search_result[docid] |
|
search_result.score = score |
|
search_results.append(search_result) |
|
|
|
return search_results |
|
|