# Provenance: Hugging Face file view; uploaded by ArthurChen189, commit 62977bb ("upload pyserini"), 2.55 kB.
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer

from pyserini.encode import DocumentEncoder, QueryEncoder


class DprDocumentEncoder(DocumentEncoder):
    """Document encoder backed by a Hugging Face ``DPRContextEncoder``.

    Tokenizes passages (optionally paired with their titles), runs them
    through the DPR context encoder, and returns the pooled output as a
    NumPy array.
    """

    def __init__(self, model_name, tokenizer_name=None, device='cuda:0'):
        """Load the context encoder and its tokenizer.

        Args:
            model_name: Name or path handed to
                ``DPRContextEncoder.from_pretrained``.
            tokenizer_name: Name or path handed to
                ``DPRContextEncoderTokenizer.from_pretrained``; when falsy,
                ``model_name`` is used instead.
            device: Device the model is moved to and on which inputs are
                placed before the forward pass.
        """
        self.device = device
        self.model = DPRContextEncoder.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, texts, titles=None, max_length=256, **kwargs):
        """Encode a batch of passages into dense vectors.

        Args:
            texts: Passage texts to encode.
            titles: Optional passage titles. When truthy, each title is
                tokenized as the first segment and the matching text as the
                second segment (``text_pair``).
            max_length: Maximum tokenized length; longer inputs are truncated.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The encoder's ``pooler_output`` moved to CPU and converted to a
            NumPy array (presumably one row per input text -- confirm
            against the model configuration).
        """
        tokenizer_options = {
            'max_length': max_length,
            'padding': 'longest',
            'truncation': True,
            'add_special_tokens': True,
            'return_tensors': 'pt',
        }
        if titles:
            inputs = self.tokenizer(titles, text_pair=texts, **tokenizer_options)
        else:
            inputs = self.tokenizer(texts, **tokenizer_options)
        inputs.to(self.device)
        # Inference only: skip autograd bookkeeping so no gradient graph is
        # built and kept alive for the duration of the forward pass.
        with torch.no_grad():
            # Only the token ids are forwarded to the model; the other
            # tokenizer outputs are intentionally left unused, as before.
            outputs = self.model(inputs["input_ids"])
        return outputs.pooler_output.detach().cpu().numpy()
class DprQueryEncoder(QueryEncoder):
    """Query encoder backed by a Hugging Face ``DPRQuestionEncoder``.

    Tokenizes a single query, runs it through the DPR question encoder, and
    returns the pooled output as a flat NumPy vector.
    """

    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
        """Load the question encoder and its tokenizer.

        Args:
            model_name: Name or path handed to
                ``DPRQuestionEncoder.from_pretrained``.
            tokenizer_name: Name or path handed to
                ``DPRQuestionEncoderTokenizer.from_pretrained``; when falsy,
                ``model_name`` is used instead.
            device: Device the model is moved to and on which inputs are
                placed before the forward pass.
        """
        self.device = device
        self.model = DPRQuestionEncoder.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, query: str, **kwargs):
        """Encode one query into a dense vector.

        Args:
            query: The query text.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The encoder's ``pooler_output`` moved to CPU, converted to a
            NumPy array, and flattened to one dimension.
        """
        inputs = self.tokenizer(query, return_tensors='pt')
        inputs.to(self.device)
        # Inference only: skip autograd bookkeeping so no gradient graph is
        # built and kept alive for the duration of the forward pass.
        with torch.no_grad():
            # Only the token ids are forwarded to the model, as before.
            outputs = self.model(inputs["input_ids"])
        embeddings = outputs.pooler_output.detach().cpu().numpy()
        return embeddings.flatten()