#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module provides Pyserini's Python search interface to Anserini. The main entry point is the ``LuceneSearcher``
class, which wraps the Java class with the same name in Anserini.
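
A minimal usage sketch (the index path below is a placeholder)::

    from pyserini.search.lucene import LuceneSearcher

    searcher = LuceneSearcher('indexes/sample_collection_jsonl')
    hits = searcher.search('what is a lobster roll?')
    for hit in hits:
        print(hit.docid, hit.score)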
"""
import logging
from typing import Dict, List, Optional, Union
from pyserini.fusion import FusionMethod, reciprocal_rank_fusion
from pyserini.index import Document, IndexReader
from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap
from pyserini.search import JQuery, JQueryGenerator
from pyserini.trectools import TrecRun
from pyserini.util import download_prebuilt_index, get_sparse_indexes_info
logger = logging.getLogger(__name__)
# Wrappers around Anserini classes
JLuceneSearcher = autoclass('io.anserini.search.SimpleSearcher')
JLuceneSearcherResult = autoclass('io.anserini.search.SimpleSearcher$Result')
class LuceneSearcher:
"""Wrapper class for ``SimpleSearcher`` in Anserini.
Parameters
----------
index_dir : str
Path to Lucene index directory.
"""
    def __init__(self, index_dir: str, prebuilt_index_name: Optional[str] = None):
self.index_dir = index_dir
self.object = JLuceneSearcher(index_dir)
self.num_docs = self.object.get_total_num_docs()
        # Keep track of whether this searcher was built from a known pre-built index.
self.prebuilt_index_name = prebuilt_index_name
@classmethod
def from_prebuilt_index(cls, prebuilt_index_name: str, verbose=False):
"""Build a searcher from a pre-built index; download the index if necessary.
Parameters
----------
prebuilt_index_name : str
Prebuilt index name.
verbose : bool
Print status information.
Returns
-------
LuceneSearcher
Searcher built from the prebuilt index.
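        Examples
        --------
        A minimal sketch; ``msmarco-v1-passage`` is one of the pre-built indexes listed by
        :func:`list_prebuilt_indexes`::

            searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage')
            hits = searcher.search('what is a lobster roll?')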
"""
if verbose:
print(f'Attempting to initialize pre-built index {prebuilt_index_name}.')
try:
index_dir = download_prebuilt_index(prebuilt_index_name, verbose=verbose)
except ValueError as e:
print(str(e))
return None
# Currently, the only way to validate stats is to create a separate IndexReader, because there is no method
# to obtain the underlying reader of a SimpleSearcher; see https://github.com/castorini/anserini/issues/2013
index_reader = IndexReader(index_dir)
        # This is janky, as we're creating a separate IndexReader for the sole purpose of validating index stats.
index_reader.validate(prebuilt_index_name, verbose=verbose)
if verbose:
print(f'Initializing {prebuilt_index_name}...')
return cls(index_dir, prebuilt_index_name=prebuilt_index_name)
@staticmethod
def list_prebuilt_indexes():
"""Display information about available prebuilt indexes."""
get_sparse_indexes_info()
def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None,
fields=dict(), strip_segment_id=False, remove_dups=False) -> List[JLuceneSearcherResult]:
"""Search the collection.
Parameters
----------
q : Union[str, JQuery]
            Query string or a ``JQuery`` object.
k : int
Number of hits to return.
query_generator : JQueryGenerator
            Generator to build queries. Set to ``None`` by default to use the Anserini default.
fields : dict
Optional map of fields to search with associated boosts.
strip_segment_id : bool
            Remove the ``.XXXXX`` suffix used to denote different segments from a document.
remove_dups : bool
Remove duplicate docids when writing final run output.
Returns
-------
List[JLuceneSearcherResult]
List of search results.
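        Examples
        --------
        A minimal sketch; the query string and ``k`` are illustrative::

            hits = searcher.search('what is a lobster roll?', k=100)
            for i, hit in enumerate(hits[:3]):
                print(f'{i + 1:2} {hit.docid} {hit.score:.4f}')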
"""
jfields = JHashMap()
for (field, boost) in fields.items():
jfields.put(field, JFloat(boost))
hits = None
if query_generator:
if not fields:
hits = self.object.search(query_generator, q, k)
else:
hits = self.object.searchFields(query_generator, q, jfields, k)
elif isinstance(q, JQuery):
# Note that RM3 requires the notion of a query (string) to estimate the appropriate models. If we're just
# given a Lucene query, it's unclear what the "query" is for this estimation. One possibility is to extract
# all the query terms from the Lucene query, although this might yield unexpected behavior from the user's
            # perspective. Until we think through what exactly the "right thing to do" is, we'll raise an exception
# here explicitly.
if self.is_using_rm3():
raise NotImplementedError('RM3 incompatible with search using a Lucene query.')
if fields:
raise NotImplementedError('Cannot specify fields to search when using a Lucene query.')
hits = self.object.search(q, k)
else:
if not fields:
hits = self.object.search(q, k)
else:
hits = self.object.search_fields(q, jfields, k)
docids = set()
filtered_hits = []
        for hit in hits:
            if strip_segment_id:
                hit.docid = hit.docid.split('.')[0]
            if hit.docid in docids:
                continue
            filtered_hits.append(hit)
            # Only track seen docids when deduplication is requested; otherwise every hit is kept.
            if remove_dups:
                docids.add(hit.docid)
return filtered_hits
def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1,
                     query_generator: JQueryGenerator = None, fields=dict()) -> Dict[str, List[JLuceneSearcherResult]]:
"""Search the collection concurrently for multiple queries, using multiple threads.
Parameters
----------
queries : List[str]
List of query strings.
qids : List[str]
List of corresponding query ids.
k : int
Number of hits to return.
threads : int
Maximum number of threads to use.
query_generator : JQueryGenerator
            Generator to build queries. Set to ``None`` by default to use the Anserini default.
fields : dict
Optional map of fields to search with associated boosts.
Returns
-------
Dict[str, List[JLuceneSearcherResult]]
Dictionary holding the search results, with the query ids as keys and the corresponding lists of search
results as the values.
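        Examples
        --------
        A minimal sketch; the queries and query ids are illustrative::

            results = searcher.batch_search(['what is a lobster roll?', 'what is a moose?'],
                                            qids=['q1', 'q2'], k=100, threads=4)
            hits = results['q1']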
"""
query_strings = JArrayList()
qid_strings = JArrayList()
for query in queries:
query_strings.add(query)
for qid in qids:
qid_strings.add(qid)
jfields = JHashMap()
for (field, boost) in fields.items():
jfields.put(field, JFloat(boost))
if query_generator:
if not fields:
results = self.object.batch_search(query_generator, query_strings, qid_strings, int(k), int(threads))
else:
results = self.object.batch_search_fields(query_generator, query_strings, qid_strings, int(k), int(threads), jfields)
else:
if not fields:
results = self.object.batch_search(query_strings, qid_strings, int(k), int(threads))
else:
results = self.object.batch_search_fields(query_strings, qid_strings, int(k), int(threads), jfields)
return {r.getKey(): r.getValue() for r in results.entrySet().toArray()}
    def get_feedback_terms(self, q: str) -> Optional[Dict[str, float]]:
        """Return feedback terms and their weights for a query.
Parameters
----------
q : str
            Query string.
Returns
-------
        Optional[Dict[str, float]]
            Feedback terms and their weights, or ``None`` if no feedback terms are available.
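        Examples
        --------
        A minimal sketch, assuming a feedback method such as RM3 has already been enabled::

            searcher.set_rm3()
            terms = searcher.get_feedback_terms('what is a lobster roll?')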
"""
terms_map = self.object.get_feedback_terms(q)
if terms_map:
return {r.getKey(): r.getValue() for r in terms_map.entrySet().toArray()}
else:
return None
def set_analyzer(self, analyzer):
"""Set the Java ``Analyzer`` to use.
Parameters
----------
analyzer : JAnalyzer
Java ``Analyzer`` object.
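        Examples
        --------
        A minimal sketch, assuming the ``get_lucene_analyzer`` helper from :mod:`pyserini.analysis`::

            from pyserini.analysis import get_lucene_analyzer
            searcher.set_analyzer(get_lucene_analyzer(stemmer='krovetz'))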
"""
self.object.set_analyzer(analyzer)
def set_language(self, language):
"""Set language of LuceneSearcher"""
self.object.set_language(language)
    def set_rm3(self, fb_terms=10, fb_docs=10, original_query_weight=0.5, debug=False, filter_terms=True):
"""Configure RM3 pseudo-relevance feedback.
Parameters
----------
fb_terms : int
RM3 parameter for number of expansion terms.
fb_docs : int
RM3 parameter for number of expansion documents.
original_query_weight : float
RM3 parameter for weight to assign to the original query.
debug : bool
Print the original and expanded queries as debug output.
        filter_terms : bool
Whether to remove non-English terms.
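        Examples
        --------
        A minimal sketch with the default parameters, on an index that stores document vectors::

            searcher.set_rm3(fb_terms=10, fb_docs=10, original_query_weight=0.5)
            hits = searcher.search('what is a lobster roll?')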
"""
if self.object.reader.getTermVectors(0):
self.object.set_rm3(None, fb_terms, fb_docs, original_query_weight, debug, filter_terms)
elif self.prebuilt_index_name in ['msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented']:
self.object.set_rm3('JsonCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
elif self.prebuilt_index_name in ['msmarco-v2-passage', 'msmarco-v2-passage-augmented']:
self.object.set_rm3('MsMarcoV2PassageCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
elif self.prebuilt_index_name in ['msmarco-v2-doc', 'msmarco-v2-doc-segmented']:
self.object.set_rm3('MsMarcoV2DocCollection', fb_terms, fb_docs, original_query_weight, debug, filter_terms)
else:
raise TypeError("RM3 is not supported for indexes without document vectors.")
def unset_rm3(self):
"""Disable RM3 pseudo-relevance feedback."""
self.object.unset_rm3()
def is_using_rm3(self) -> bool:
"""Check if RM3 pseudo-relevance feedback is being performed."""
return self.object.use_rm3()
def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, bottom_fb_docs=10,
alpha=1, beta=0.75, gamma=0, debug=False, use_negative=False):
"""Configure Rocchio pseudo-relevance feedback.
Parameters
----------
top_fb_terms : int
Rocchio parameter for number of relevant expansion terms.
top_fb_docs : int
Rocchio parameter for number of relevant expansion documents.
bottom_fb_terms : int
Rocchio parameter for number of non-relevant expansion terms.
bottom_fb_docs : int
Rocchio parameter for number of non-relevant expansion documents.
alpha : float
Rocchio parameter for weight to assign to the original query.
        beta : float
            Rocchio parameter for weight to assign to the relevant document vector.
        gamma : float
            Rocchio parameter for weight to assign to the non-relevant document vector.
debug : bool
Print the original and expanded queries as debug output.
use_negative : bool
Rocchio parameter to use negative labels.
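        Examples
        --------
        A minimal sketch with the default parameters, on an index that stores document vectors::

            searcher.set_rocchio(top_fb_terms=10, top_fb_docs=10, alpha=1.0, beta=0.75)
            hits = searcher.search('what is a lobster roll?')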
"""
if self.object.reader.getTermVectors(0):
self.object.set_rocchio(None, top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
alpha, beta, gamma, debug, use_negative)
elif self.prebuilt_index_name in ['msmarco-v1-passage', 'msmarco-v1-doc', 'msmarco-v1-doc-segmented']:
self.object.set_rocchio('JsonCollection', top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs,
alpha, beta, gamma, debug, use_negative)
# Note, we don't have any Pyserini 2CRs that use Rocchio for MS MARCO v2, so there's currently no
# corresponding code branch here. To avoid introducing bugs (without 2CR tests), we'll add when it's needed.
else:
raise TypeError("Rocchio is not supported for indexes without document vectors.")
def unset_rocchio(self):
"""Disable Rocchio pseudo-relevance feedback."""
self.object.unset_rocchio()
def is_using_rocchio(self) -> bool:
"""Check if Rocchio pseudo-relevance feedback is being performed."""
return self.object.use_rocchio()
    def set_qld(self, mu=1000.0):
"""Configure query likelihood with Dirichlet smoothing as the scoring function.
Parameters
----------
mu : float
Dirichlet smoothing parameter mu.
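        Examples
        --------
        A minimal sketch; the value of ``mu`` is illustrative::

            searcher.set_qld(mu=1000)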
"""
self.object.set_qld(float(mu))
    def set_bm25(self, k1=0.9, b=0.4):
"""Configure BM25 as the scoring function.
Parameters
----------
k1 : float
BM25 k1 parameter.
b : float
BM25 b parameter.
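        Examples
        --------
        A minimal sketch; the parameter values are illustrative::

            searcher.set_bm25(k1=0.82, b=0.68)
            hits = searcher.search('what is a lobster roll?')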
"""
self.object.set_bm25(float(k1), float(b))
def get_similarity(self):
"""Return the Lucene ``Similarity`` used as the scoring function."""
return self.object.get_similarity()
def doc(self, docid: Union[str, int]) -> Optional[Document]:
"""Return the :class:`Document` corresponding to ``docid``. The ``docid`` is overloaded: if it is of type
``str``, it is treated as an external collection ``docid``; if it is of type ``int``, it is treated as an
internal Lucene ``docid``. Method returns ``None`` if the ``docid`` does not exist in the index.
Parameters
----------
docid : Union[str, int]
Overloaded ``docid``: either an external collection ``docid`` (``str``) or an internal Lucene ``docid``
(``int``).
Returns
-------
Document
:class:`Document` corresponding to the ``docid``.
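        Examples
        --------
        A minimal sketch showing both overloads; the ids are illustrative::

            doc = searcher.doc('doc1')   # external collection docid
            doc = searcher.doc(0)        # internal Lucene docid
            if doc is not None:
                print(doc.raw())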
"""
lucene_document = self.object.doc(docid)
if lucene_document is None:
return None
return Document(lucene_document)
def batch_doc(self, docids: List[str], threads: int) -> Dict[str, Document]:
"""Concurrently fetching documents for multiple document ids.
Return dictionary that maps ``docid`` to :class:`Document`. Returned dictionary does not
contain ``docid`` if a corresponding :class:`Document` does not exist in the index.
Parameters
----------
        docids : List[str]
            List of external collection ``docid``s (``str``).
threads : int
Maximum number of threads to use.
Returns
-------
Dict[str, Document]
Dictionary mapping the ``docid`` to the corresponding :class:`Document`.
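        Examples
        --------
        A minimal sketch; the docids are illustrative::

            docs = searcher.batch_doc(['doc1', 'doc2'], threads=4)
            for docid, doc in docs.items():
                print(docid, doc.raw())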
"""
docid_strings = JArrayList()
for docid in docids:
docid_strings.add(docid)
results = self.object.batch_get_docs(docid_strings, threads)
batch_document = {r.getKey(): Document(r.getValue())
for r in results.entrySet().toArray()}
return batch_document
def doc_by_field(self, field: str, q: str) -> Optional[Document]:
"""Return the :class:`Document` based on a ``field`` with ``id``. For example, this method can be used to fetch
document based on alternative primary keys that have been indexed, such as an article's DOI. Method returns
``None`` if no such document exists.
Parameters
----------
field : str
Field to look up.
q : str
            Unique id of the document.
Returns
-------
Document
            :class:`Document` whose ``field`` matches ``q``.
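        Examples
        --------
        A minimal sketch, assuming the index has a ``doi`` field; the value is illustrative::

            doc = searcher.doc_by_field('doi', '10.1371/journal.pone.0008580')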
"""
lucene_document = self.object.doc_by_field(field, q)
if lucene_document is None:
return None
return Document(lucene_document)
def close(self):
"""Close the searcher."""
self.object.close()
class LuceneSimilarities:
    """Factory for Lucene ``Similarity`` objects used as scoring functions."""
@staticmethod
def bm25(k1=0.9, b=0.4):
return autoclass('org.apache.lucene.search.similarities.BM25Similarity')(k1, b)
@staticmethod
def qld(mu=1000):
return autoclass('org.apache.lucene.search.similarities.LMDirichletSimilarity')(mu)
class LuceneFusionSearcher:
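    """Wrapper class for searching multiple Lucene indexes and fusing their results.
    Parameters
    ----------
    index_dirs : List[str]
        Paths to the Lucene index directories to search over.
    method : FusionMethod
        Fusion method to apply; currently only ``FusionMethod.RRF`` is implemented.
    Examples
    --------
    A minimal usage sketch (the index paths are placeholders)::

        from pyserini.fusion import FusionMethod
        searcher = LuceneFusionSearcher(['indexes/index-a', 'indexes/index-b'],
                                        method=FusionMethod.RRF)
        hits = searcher.search('what is a lobster roll?')
    """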
def __init__(self, index_dirs: List[str], method: FusionMethod):
self.method = method
self.searchers = [LuceneSearcher(index_dir) for index_dir in index_dirs]
def get_searchers(self) -> List[LuceneSearcher]:
return self.searchers
    def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None,
               strip_segment_id=False, remove_dups=False) -> List[JLuceneSearcherResult]:
trec_runs, docid_to_search_result = list(), dict()
for searcher in self.searchers:
docid_score_pair = list()
hits = searcher.search(q, k=k, query_generator=query_generator,
strip_segment_id=strip_segment_id, remove_dups=remove_dups)
for hit in hits:
docid_to_search_result[hit.docid] = hit
docid_score_pair.append((hit.docid, hit.score))
run = TrecRun.from_search_results(docid_score_pair)
trec_runs.append(run)
if self.method == FusionMethod.RRF:
fused_run = reciprocal_rank_fusion(trec_runs, rrf_k=60, depth=1000, k=1000)
else:
            raise NotImplementedError(f'Fusion method {self.method} not supported.')
return self.convert_to_search_result(fused_run, docid_to_search_result)
@staticmethod
def convert_to_search_result(run: TrecRun, docid_to_search_result: Dict[str, JLuceneSearcherResult]) -> List[JLuceneSearcherResult]:
search_results = []
for _, _, docid, _, score, _ in run.to_numpy():
search_result = docid_to_search_result[docid]
search_result.score = score
search_results.append(search_result)
return search_results