#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
from abc import ABC, abstractmethod
from enum import Enum, unique
from typing import List

from pyserini.search import JLuceneSearcherResult
@unique
class OutputFormat(Enum):
    """Supported run-file output formats.

    The value of each member is the string accepted on the command line
    and passed to `get_output_writer`. `@unique` guarantees no two
    formats alias the same string (the `unique` import was previously
    unused — the decorator was evidently intended).
    """
    TREC = 'trec'
    MSMARCO = 'msmarco'
    KILT = 'kilt'
class OutputWriter(ABC):
    """Abstract base for writing ranked retrieval results to a file.

    Acts as a context manager that opens/closes the output file, and
    provides `hits_iterator` to walk hits with ranking and optional
    max-passage de-duplication. Subclasses implement `write` for a
    concrete file format.

    Parameters
    ----------
    file_path : str
        Path of the output run file; parent directories are created on enter.
    mode : str
        Mode passed to `open` (default 'w').
    max_hits : int
        Per-topic cap on emitted hits (ignored in max-passage mode).
    tag : str
        Run tag written by formats that include one (e.g. TREC).
    topics : dict
        Topic-id -> topic payload mapping, used by formats that need the
        original topic record (e.g. KILT).
    use_max_passage : bool
        If True, collapse passage-level hits to one hit per document.
    max_passage_delimiter : str
        Separator between docid and passage id (only used in max-passage mode).
    max_passage_hits : int
        Per-topic cap when max-passage mode is active.
    """

    def __init__(self, file_path: str, mode: str = 'w',
                 max_hits: int = 1000, tag: str = None, topics: dict = None,
                 use_max_passage: bool = False, max_passage_delimiter: str = None, max_passage_hits: int = 100):
        self.file_path = file_path
        self.mode = mode
        self.tag = tag
        self.topics = topics
        self.use_max_passage = use_max_passage
        # The delimiter is only meaningful when max-passage mode is on.
        self.max_passage_delimiter = max_passage_delimiter if use_max_passage else None
        # In max-passage mode the per-topic cap switches to max_passage_hits.
        self.max_hits = max_passage_hits if use_max_passage else max_hits
        self._file = None

    def __enter__(self):
        # Create parent directories lazily so callers can point at a
        # not-yet-existing run directory.
        dirname = os.path.dirname(self.file_path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        self._file = open(self.file_path, self.mode)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Guard against __enter__ never having run (or open() having failed).
        if self._file is not None:
            self._file.close()

    def hits_iterator(self, hits: List[JLuceneSearcherResult]):
        """Yield (docid, rank, score, hit) tuples, at most `self.max_hits`.

        In max-passage mode the passage suffix is stripped from the docid
        and only the first (highest-ranked) passage per document is kept.
        """
        unique_docs = set()
        rank = 1
        for hit in hits:
            if self.use_max_passage and self.max_passage_delimiter:
                # Collapse passage id 'docid<delim>passage' down to 'docid'.
                docid = hit.docid.split(self.max_passage_delimiter)[0]
            else:
                docid = hit.docid.strip()

            if self.use_max_passage:
                if docid in unique_docs:
                    continue
                unique_docs.add(docid)

            yield docid, rank, hit.score, hit
            rank = rank + 1
            if rank > self.max_hits:
                break

    @abstractmethod
    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        """Write the hits for one topic in the concrete output format.

        Previously a plain method raising NotImplementedError; the
        `abstractmethod` import was evidently intended for this.
        """
        raise NotImplementedError()
class TrecWriter(OutputWriter):
    """Emits hits in the standard six-column TREC run-file format."""

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        rows = (f'{topic} Q0 {docid} {rank} {score:.6f} {self.tag}\n'
                for docid, rank, score, _ in self.hits_iterator(hits))
        self._file.writelines(rows)
class MsMarcoWriter(OutputWriter):
    """Emits hits in the tab-separated MS MARCO format (no score column)."""

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        lines = [f'{topic}\t{docid}\t{rank}\n'
                 for docid, rank, _score, _hit in self.hits_iterator(hits)]
        self._file.write(''.join(lines))
class KiltWriter(OutputWriter):
    """Emits hits as KILT JSON lines, attaching provenance to the topic record."""

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        record = self.topics[topic]
        provenance = [{"wikipedia_id": docid}
                      for docid, _, _, _ in self.hits_iterator(hits)]
        record["output"] = [{"provenance": provenance}]
        self._file.write(json.dumps(record) + '\n')
def get_output_writer(file_path: str, output_format: OutputFormat, *args, **kwargs) -> OutputWriter:
    """Instantiate the writer class registered for `output_format`.

    Extra positional and keyword arguments are forwarded to the writer's
    constructor. Raises KeyError if the format has no registered writer.
    """
    writer_cls = {
        OutputFormat.TREC: TrecWriter,
        OutputFormat.MSMARCO: MsMarcoWriter,
        OutputFormat.KILT: KiltWriter,
    }[output_format]
    return writer_cls(file_path, *args, **kwargs)
def tie_breaker(hits):
    """Return hits ordered by descending score, ties broken by ascending docid.

    Implemented as two stable sorts: first by docid, then by score, which
    yields the same ordering as a single composite (-score, docid) key.
    """
    ordered = sorted(hits, key=lambda hit: hit.docid)
    ordered.sort(key=lambda hit: hit.score, reverse=True)
    return ordered