|
|
|
import pickle |
|
from langchain_cohere import CohereRerank |
|
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever |
|
from langchain_community.vectorstores import FAISS |
|
from langchain.retrievers import EnsembleRetriever |
|
|
|
|
|
from .config import FAISS_DB_INDEX, BM25_INDEX |
|
|
|
|
|
def load_bm25_retriever():
    """Unpickle the BM25 retriever saved at ``BM25_INDEX`` and tag it for tracing.

    Returns:
        The deserialized retriever, wrapped with ``with_config`` so that its
        runs are labelled ``"BM25Retriever"``.

    Security:
        ``pickle.load`` can execute arbitrary code while deserializing.
        ``BM25_INDEX`` must point to a file this application produced itself,
        never to externally supplied data.
    """
    with open(BM25_INDEX, "rb") as index_file:
        retriever = pickle.load(index_file)
    return retriever.with_config(run_name="BM25Retriever")
|
|
|
|
|
def load_faiss_retriever(embeddings, *, k=10):
    """Load the on-disk FAISS index and expose it as an MMR retriever.

    Args:
        embeddings: Embedding model used to embed queries against the index.
            Presumably this must be the same model the index was built with
            (TODO: confirm against the indexing script).
        k: Number of documents returned per query. Defaults to 10, the
            previously hard-coded value.

    Returns:
        A retriever runnable that uses ``search_type="mmr"`` (maximal
        marginal relevance), labelled ``"FaissRetriever"`` in traces.
    """
    # LangChain requires an explicit opt-in here because loading the local
    # index involves pickle deserialization. This is only safe while
    # FAISS_DB_INDEX points to an index this project produced itself.
    faiss_db = FAISS.load_local(
        FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
    )
    faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": k})
    return faiss_retriever.with_config(run_name="FaissRetriever")
|
|
|
|
|
def load_retrievers(embeddings, *, weights=(0.7, 0.3), top_n=5):
    """Build the hybrid BM25 + FAISS retriever with Cohere reranking on top.

    The BM25 retriever and the FAISS retriever are combined by an
    ``EnsembleRetriever``. The combined candidates are then reranked with
    Cohere's ``rerank-multilingual-v3.0`` model, and the best ``top_n``
    documents are kept.

    Args:
        embeddings: Embedding model forwarded to ``load_faiss_retriever``.
        weights: Ensemble weights for ``(bm25, faiss)``, in that order.
            Defaults to ``(0.7, 0.3)``, which weights BM25 more heavily.
        top_n: Number of documents kept after reranking (default 5).

    Returns:
        A ``ContextualCompressionRetriever`` runnable labelled
        ``"ContextualCompressionRetriever"`` in traces.

    Raises:
        ValueError: If ``weights`` does not hold exactly one weight per
            retriever.
    """
    faiss_retriever = load_faiss_retriever(embeddings)
    bm25_retriever = load_bm25_retriever()

    retrievers = [bm25_retriever, faiss_retriever]
    # Copy into a fresh list: the default is an immutable tuple, and the
    # ensemble expects a list.
    weights = list(weights)
    if len(weights) != len(retrievers):
        raise ValueError(
            f"Expected {len(retrievers)} weights (bm25, faiss), got {len(weights)}"
        )

    # EnsembleRetriever has no ``search_type`` field, so the former
    # ``search_type="mmr"`` argument was silently ignored. MMR is configured
    # on the FAISS retriever in ``load_faiss_retriever``.
    ensemble_retriever = EnsembleRetriever(retrievers=retrievers, weights=weights)

    compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=top_n)
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    ).with_config(run_name="ContextualCompressionRetriever")
|
|