import os
import shutil
import sys
from typing import Any, Dict, List, Optional

import torch
import yaml
from dotenv import load_dotenv
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.prompts import BasePromptTemplate, load_prompt
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.retrievers import BaseRetriever
from transformers import AutoModelForSequenceClassification, AutoTokenizer

current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, '..'))
repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
sys.path.append(kit_dir)
sys.path.append(repo_dir)
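
# Kit-local utilities (importable thanks to the sys.path entries above).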
from utils.model_wrappers.api_gateway import APIGateway
from utils.parsing.sambaparse import parse_doc_universal
from utils.vectordb.vector_db import VectorDb
from utils.visual.env_utils import get_wandb_key

CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')


class RetrievalQAChain(Chain):
    """Chain for question answering over retrieved documents, with optional reranking."""

    retriever: BaseRetriever
    rerank: bool = True
    llm: BaseLanguageModel
    qa_prompt: BasePromptTemplate
    final_k_retrieved_documents: int = 3

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return ['question']

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        return ['answer', 'source_documents']

    def _format_docs(self, docs: List[Document]) -> str:
        return '\n\n'.join(doc.page_content for doc in docs)

    def rerank_docs(self, query: str, docs: List[Document], final_k: int) -> List[Document]:
        """Rerank the retrieved documents with a cross-encoder and keep the top final_k."""
        # Load the BAAI/bge-reranker-large cross-encoder (cached locally by the HF hub after first use).
        tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
        reranker = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')

        # Cross-encoders score each (query, document) pair jointly.
        pairs = [[query, d.page_content] for d in docs]

        with torch.no_grad():
            inputs = tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512,
            )
            scores = reranker(**inputs, return_dict=True).logits.view(-1).float()

        # Sort the documents by descending relevance score and keep the top final_k.
        scores_list = scores.tolist()
        scores_sorted_idx = sorted(range(len(scores_list)), key=lambda k: scores_list[k], reverse=True)
        docs_sorted = [docs[k] for k in scores_sorted_idx]

        return docs_sorted[:final_k]
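
    # Performance note: rerank_docs reloads the tokenizer and reranker on every call;
    # a process-wide cache (e.g. loading them once at module import) would avoid the
    # repeated model loads when reranking is applied to many queries.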

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # Compose prompt -> LLM -> string parser into a runnable QA chain.
        qa_chain = self.qa_prompt | self.llm | StrOutputParser()
        response: Dict[str, Any] = {}
        # Retrieve candidates, optionally rerank them, then answer over the formatted context.
        documents = self.retriever.invoke(inputs['question'])
        if self.rerank:
            documents = self.rerank_docs(inputs['question'], documents, self.final_k_retrieved_documents)
        docs = self._format_docs(documents)
        response['answer'] = qa_chain.invoke({'question': inputs['question'], 'context': docs})
        response['source_documents'] = documents
        return response
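
# Hypothetical usage sketch (names are illustrative; in this kit the retriever, llm and
# qa_prompt are wired up by DocumentRetrieval.get_qa_retrieval_chain below):
#   chain = RetrievalQAChain(retriever=retriever, llm=llm, qa_prompt=qa_prompt)
#   result = chain.invoke({'question': 'What does the contract say about renewals?'})
#   answer, sources = result['answer'], result['source_documents']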


class DocumentRetrieval:
    def __init__(self) -> None:
        self.vectordb = VectorDb()
        # Unpack the config sections loaded from config.yaml.
        (
            self.api_info,
            self.llm_info,
            self.embedding_model_info,
            self.retrieval_info,
            self.prompts,
            self.prod_mode,
        ) = self.get_config_info()
        self.retriever = None
        self.llm = self.set_llm()

    def get_config_info(self):
        """Load the YAML config file and return its main sections."""
        with open(CONFIG_PATH, 'r') as yaml_file:
            config = yaml.safe_load(yaml_file)
        api_info = config['api']
        llm_info = config['llm']
        embedding_model_info = config['embedding_model']
        retrieval_info = config['retrieval']
        prompts = config['prompts']
        prod_mode = config['prod_mode']

        return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
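
    # The config.yaml read above is expected to contain roughly the following keys
    # (taken from the accesses in this module; the values here are illustrative only):
    #
    #   api: sncloud
    #   prod_mode: false
    #   llm:
    #     coe: true
    #     do_sample: false
    #     max_tokens_to_generate: 1200
    #     temperature: 0.0
    #     select_expert: Meta-Llama-3-70B-Instruct
    #   embedding_model:
    #     type: cpu
    #     batch_size: 1
    #     coe: false
    #     select_expert: e5-mistral-7b-instruct
    #   retrieval:
    #     rerank: false
    #     k_retrieved_documents: 15
    #     score_threshold: 0.2
    #     final_k_retrieved_documents: 5
    #   prompts:
    #     qa_prompt: prompts/qa_prompt.yaml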

    def set_llm(self):
        """Instantiate the LLM through the API gateway with the parameters from config.yaml."""
        sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')

        llm = APIGateway.load_llm(
            type=self.api_info,
            streaming=True,
            coe=self.llm_info['coe'],
            do_sample=self.llm_info['do_sample'],
            max_tokens_to_generate=self.llm_info['max_tokens_to_generate'],
            temperature=self.llm_info['temperature'],
            select_expert=self.llm_info['select_expert'],
            process_prompt=False,
            sambanova_api_key=sambanova_api_key,
        )
        return llm

    def parse_doc(self, docs: List, additional_metadata: Optional[Dict] = None) -> List[Document]:
        """
        Parse the uploaded documents and return a list of LangChain documents.

        Args:
            docs (List[UploadFile]): A list of uploaded files.
            additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed
                documents. Defaults to an empty dictionary.

        Returns:
            List[Document]: A list of LangChain documents.
        """
        if additional_metadata is None:
            additional_metadata = {}

        # Create the temporary folder for the uploaded files, or empty it if it already exists.
        temp_folder = os.path.join(kit_dir, 'data/tmp')
        if not os.path.exists(temp_folder):
            os.makedirs(temp_folder)
        else:
            for filename in os.listdir(temp_folder):
                file_path = os.path.join(temp_folder, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                except Exception as e:
                    print(f'Failed to delete {file_path}. Reason: {e}')

        # Save each uploaded file into the temporary folder.
        for doc_info in docs:
            file_name, file_obj = doc_info
            temp_file = os.path.join(temp_folder, file_name)
            with open(temp_file, 'wb') as f:
                f.write(file_obj.read())

        # Parse the whole folder into LangChain documents.
        _, _, langchain_docs = parse_doc_universal(doc=temp_folder, additional_metadata=additional_metadata)
        return langchain_docs

    def load_embedding_model(self):
        embeddings = APIGateway.load_embedding_model(
            type=self.embedding_model_info['type'],
            batch_size=self.embedding_model_info['batch_size'],
            coe=self.embedding_model_info['coe'],
            select_expert=self.embedding_model_info['select_expert'],
        )
        return embeddings

    def create_vector_store(self, text_chunks, embeddings, output_db=None, collection_name=None):
        print(f'Collection name is {collection_name}')
        vectorstore = self.vectordb.create_vector_store(
            text_chunks, embeddings, output_db=output_db, collection_name=collection_name, db_type='chroma'
        )
        return vectorstore

    def load_vdb(self, db_path, embeddings, collection_name=None):
        print(f'Loading collection name is {collection_name}')
        vectorstore = self.vectordb.load_vdb(db_path, embeddings, db_type='chroma', collection_name=collection_name)
        return vectorstore

    def init_retriever(self, vectorstore):
        # When reranking is enabled, over-retrieve (k_retrieved_documents) so the
        # reranker can narrow the candidates down to final_k_retrieved_documents.
        if self.retrieval_info['rerank']:
            k = self.retrieval_info['k_retrieved_documents']
        else:
            k = self.retrieval_info['final_k_retrieved_documents']
        self.retriever = vectorstore.as_retriever(
            search_type='similarity_score_threshold',
            search_kwargs={
                'score_threshold': self.retrieval_info['score_threshold'],
                'k': k,
            },
        )
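
    # Typical flow (illustrative): documents are parsed with parse_doc, embedded via
    # load_embedding_model, stored with create_vector_store (or reloaded with load_vdb),
    # wired to a retriever with init_retriever, and finally queried through the chain
    # returned by get_qa_retrieval_chain.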

    def get_qa_retrieval_chain(self):
        """
        Generate a qa_retrieval chain using a language model.

        This function uses a language model, specifically a SambaNova LLM, to generate a
        qa_retrieval chain based on the text chunks in the retriever's vector store.

        Returns:
            RetrievalQAChain: A chain ready for QA without memory.
        """
        retrievalQAChain = RetrievalQAChain(
            retriever=self.retriever,
            llm=self.llm,
            qa_prompt=load_prompt(os.path.join(kit_dir, self.prompts['qa_prompt'])),
            rerank=self.retrieval_info['rerank'],
            final_k_retrieved_documents=self.retrieval_info['final_k_retrieved_documents'],
        )
        return retrievalQAChain

    def get_conversational_qa_retrieval_chain(self):
        """
        Generate a conversational retrieval qa chain using a language model.

        This function uses a language model, specifically a SambaNova LLM, to generate a
        conversational_qa_retrieval chain based on the chat history and the relevant
        retrieved content from the text chunks in the retriever's vector store.

        Returns:
            RetrievalQA: A chain ready for QA with memory.
        """