# Spaces: No application file
import torch | |
import numpy as np | |
from pyserini.search import QueryEncoder | |
from sentence_transformers import SentenceTransformer | |
class SentenceTransformerEncoder(QueryEncoder):
    """Pyserini query encoder backed by a sentence-transformers model.

    Each query string is embedded with ``SentenceTransformer.encode`` and then
    scaled to unit L2 norm, so inner-product search behaves like cosine
    similarity.
    """

    # NOTE(review): QueryEncoder.__init__ is never called here; confirm the
    # base class needs no initialisation for the searcher this is used with.
    def __init__(self, model_name: str, device: str = 'cpu'):
        """Load ``model_name`` onto ``device``.

        Args:
            model_name: Model name or path passed straight to
                ``SentenceTransformer``.
            device: Torch device string (e.g. ``'cpu'``); converted to a
                ``torch.device`` before being handed to the model.
        """
        self.device = torch.device(device)
        self.model = SentenceTransformer(model_name, device=self.device)

    def encode(self, query: str):
        """Embed ``query`` and scale the result to unit L2 norm.

        Args:
            query: The query text.

        Returns:
            The embedding from ``SentenceTransformer.encode`` (presumably a
            1-D numpy array for a single string; confirm against the searcher)
            divided by its L2 norm. An all-zero embedding is returned
            unchanged rather than divided by zero.
        """
        emb = self.model.encode(query)
        norm = np.linalg.norm(emb)
        # Guard the zero vector: dividing by 0 would turn every component
        # into NaN and silently corrupt downstream similarity scores.
        if norm > 0:
            emb = emb / norm
        return emb