z00mP's picture
fix rerank bug
d953944
raw
history blame contribute delete
473 Bytes
from sentence_transformers import CrossEncoder
def rerank_documents(ce_model_name, documents, query, top_k_rerank):
    """Rerank ``documents`` against ``query`` using a cross-encoder model.

    Each document is paired with the query, the pairs are scored by a
    ``CrossEncoder`` built from ``ce_model_name``, and the documents are
    returned in descending score order.

    Args:
        ce_model_name: Model identifier passed to ``CrossEncoder``.
        documents: Iterable of candidate documents (presumably text
            passages -- confirm against the caller).
        query: Query that every document is scored against.
        top_k_rerank: Maximum number of documents to return. Any value
            accepted by ``int()`` works (e.g. ``3``, ``3.0``, ``"3"``).

    Returns:
        A list with at most ``top_k_rerank`` documents, best score first.
        The list is empty when there are no documents or when
        ``top_k_rerank`` is zero or negative.

    Raises:
        ValueError, TypeError: If ``top_k_rerank`` cannot be converted by
            ``int()``.
    """
    top_k_rerank = int(top_k_rerank)
    # Materialise once so one-shot iterators can be both checked and scored.
    docs = list(documents)

    # Return before loading the model when there is nothing to rank or
    # nothing was requested. A negative limit would otherwise slice from the
    # end and return "all but the last k" documents.
    if top_k_rerank <= 0 or not docs:
        return []

    pairs = [(query, doc) for doc in docs]
    ce_model = CrossEncoder(ce_model_name, max_length=512)
    scores = ce_model.predict(pairs)

    # Sort on the score only; the sort is stable, so documents with equal
    # scores keep their original relative order.
    ranked = sorted(zip(scores, docs), key=lambda item: item[0], reverse=True)
    return [doc for _, doc in ranked[:top_k_rerank]]