import numpy as np
import torch
# autocast enables mixed-precision inference; only imported when CUDA is present
if torch.cuda.is_available():
    from torch.cuda.amp import autocast
from transformers import BertModel, BertTokenizer, BertTokenizerFast

from pyserini.encode import DocumentEncoder, QueryEncoder
from onnxruntime import SessionOptions, InferenceSession
|
|
class TctColBertDocumentEncoder(DocumentEncoder):
    def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
        self.device = device
        self.onnx = False
        if model_name.endswith('onnx'):
            # Load an exported ONNX model; strip the '.onnx' suffix to locate
            # the matching tokenizer unless one is given explicitly.
            options = SessionOptions()
            self.session = InferenceSession(model_name, options)
            self.onnx = True
            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name[:-5])
        else:
            self.model = BertModel.from_pretrained(model_name)
            self.model.to(self.device)
            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name)
|
    def encode(self, texts, titles=None, fp16=False, max_length=512, **kwargs):
        # TCT-ColBERT expects a '[CLS] [D] ' prefix on every document, so the
        # markers are spliced into the text and the tokenizer's own special
        # tokens are disabled.
        if titles is not None:
            texts = [f'[CLS] [D] {title} {text}' for title, text in zip(titles, texts)]
        else:
            texts = ['[CLS] [D] ' + text for text in texts]
        inputs = self.tokenizer(
            texts,
            max_length=max_length,
            padding='longest',
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )
        if self.onnx:
            # onnxruntime takes numpy inputs; the tokenizer output tensors are
            # converted before the originals are moved to the target device.
            inputs_onnx = {name: np.atleast_2d(value) for name, value in inputs.items()}
            inputs.to(self.device)
            outputs, _ = self.session.run(None, inputs_onnx)
            outputs = torch.from_numpy(outputs).to(self.device)
            # Mean-pool everything after the first 4 tokens: '[CLS]' plus the
            # 3 word pieces that '[D]' tokenizes into.
            embeddings = self._mean_pooling(outputs[:, 4:, :], inputs['attention_mask'][:, 4:])
        else:
            inputs.to(self.device)
            # Inference only, so gradient tracking is disabled in both paths.
            if fp16:
                with autocast(), torch.no_grad():
                    outputs = self.model(**inputs)
            else:
                with torch.no_grad():
                    outputs = self.model(**inputs)
            embeddings = self._mean_pooling(outputs['last_hidden_state'][:, 4:, :], inputs['attention_mask'][:, 4:])
        return embeddings.detach().cpu().numpy()
|
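# A minimal usage sketch for the document encoder (illustrative, not part of
# the class). The checkpoint name is one published TCT-ColBERT checkpoint on
# the Hugging Face hub; substitute whichever checkpoint you actually use. For
# an ONNX export, pass a path ending in '.onnx', plus tokenizer_name if the
# tokenizer cannot be found by stripping that suffix:
#
#     encoder = TctColBertDocumentEncoder('castorini/tct_colbert-v2-hnp-msmarco',
#                                         device='cuda:0')
#     vectors = encoder.encode(texts=['passage one', 'passage two'],
#                              titles=['Title 1', 'Title 2'])
#     # vectors: np.ndarray, one row of size hidden_size per passage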
|
class TctColBertQueryEncoder(QueryEncoder):
    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
        self.device = device
        self.model = BertModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)
|
    def encode(self, query: str, **kwargs):
        # TCT-ColBERT encodes queries at a fixed length of 36 tokens: the
        # query is prefixed with '[CLS] [Q] ', padded out with [MASK] tokens,
        # and truncated back to max_length.
        max_length = 36
        inputs = self.tokenizer(
            '[CLS] [Q] ' + query + '[MASK]' * max_length,
            max_length=max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )
        inputs.to(self.device)
        # Inference only; no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state.detach().cpu().numpy()
        # Average over all tokens after the first 4 ('[CLS]' plus the 3 word
        # pieces of '[Q]') and flatten to a 1-D query vector.
        return np.average(embeddings[:, 4:, :], axis=-2).flatten()
|
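# A minimal usage sketch for the query encoder (illustrative; same hypothetical
# checkpoint as above). The result is a 1-D numpy vector that scores against
# document vectors by inner product:
#
#     q_encoder = TctColBertQueryEncoder('castorini/tct_colbert-v2-hnp-msmarco')
#     q_vec = q_encoder.encode('what is a lobster roll')
#     score = float(np.dot(q_vec, vectors[0]))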
|