"""Bulk question answering over a prebuilt vector database.

Reads questions from an Excel file, answers each one with a
retrieval-augmented QA chain, and records per-question latency and
throughput metrics alongside the answers.
"""

import argparse
import os
import sys
import time
from typing import Any, Dict, Optional

import pandas as pd
from langchain.prompts import load_prompt
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoTokenizer

# Resolve the kit and repo roots relative to this file so the local package
# can be imported regardless of the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))

sys.path.append(kit_dir)
sys.path.append(repo_dir)

# This import must come after the sys.path manipulation above.
from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval, RetrievalQAChain  # noqa: E402


class TimedRetrievalQAChain(RetrievalQAChain):
    """RetrievalQAChain that records wall-clock timestamps in its output.

    The timestamps ride along in the result dict so the caller can split
    latency into a retrieval/rerank phase and an LLM generation phase.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        qa_chain = self.qa_prompt | self.llm | StrOutputParser()
        response = {}
        start_time = time.time()
        # Preprocessing phase: retrieve and, optionally, rerank documents.
        documents = self.retriever.invoke(inputs["question"])
        if self.rerank:
            documents = self.rerank_docs(inputs["question"], documents, self.final_k_retrieved_documents)
        docs = self._format_docs(documents)
        end_preprocessing_time = time.time()
        # Generation phase: run the prompt | llm | parser pipeline.
        response["answer"] = qa_chain.invoke({"question": inputs["question"], "context": docs})
        end_llm_time = time.time()
        response["source_documents"] = documents
        response["start_time"] = start_time
        response["end_preprocessing_time"] = end_preprocessing_time
        response["end_llm_time"] = end_llm_time
        return response
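
# Usage sketch (argument values are illustrative; see process_bulk_QA below
# for the real wiring):
#   chain = TimedRetrievalQAChain(retriever=..., llm=..., qa_prompt=...,
#                                 rerank=False, final_k_retrieved_documents=5)
#   out = chain.invoke({"question": "..."})
#   generation_latency = out["end_llm_time"] - out["end_preprocessing_time"]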


def analyze_times(answer, start_time, end_preprocessing_time, end_llm_time, tokenizer):
    """Derive latency and throughput metrics from the chain's timestamps."""
    preprocessing_time = end_preprocessing_time - start_time
    llm_time = end_llm_time - end_preprocessing_time
    token_count = len(tokenizer.encode(answer))
    tokens_per_second = token_count / llm_time
    perf = {
        "preprocessing_time": preprocessing_time,
        "llm_time": llm_time,
        "token_count": token_count,
        "tokens_per_second": tokens_per_second,
    }
    return perf
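
# Worked example (illustrative numbers): if retrieval/reranking ends 0.8 s
# after start and the LLM then takes 4.0 s to produce a 256-token answer,
# this reports preprocessing_time=0.8, llm_time=4.0,
# and tokens_per_second = 256 / 4.0 = 64.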


def generate(qa_chain, question, tokenizer):
    """Run one question through the chain and return the answer, its
    deduplicated source filenames, and the timing metrics."""
    response = qa_chain.invoke({"question": question})
    answer = response.get("answer")
    sources = {sd.metadata["filename"] for sd in response["source_documents"]}
    times = analyze_times(
        answer,
        response.get("start_time"),
        response.get("end_preprocessing_time"),
        response.get("end_llm_time"),
        tokenizer,
    )
    return answer, sources, times
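
# Standalone usage sketch (assumes a chain and tokenizer built as in
# process_bulk_QA below):
#   answer, sources, times = generate(qa_chain, "What does the kit do?", tokenizer)
#   print(answer, sorted(sources), times["tokens_per_second"])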


def process_bulk_QA(vectordb_path, questions_file_path):
    """Answer every question in the Excel file against the stored vector db.

    Answers and per-question metrics are written to a sibling
    ``*_output.xlsx`` file, whose path is returned.
    """
    documentRetrieval = DocumentRetrieval()
    # The tokenizer is only used to count generated tokens for the throughput metric.
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")

    if not os.path.exists(vectordb_path):
        raise FileNotFoundError(f"vector db path {vectordb_path} does not exist")
    embeddings = documentRetrieval.load_embedding_model()
    vectorstore = documentRetrieval.load_vdb(vectordb_path, embeddings)
    print("Database loaded")
    documentRetrieval.init_retriever(vectorstore)
    print("Retriever initialized")

    qa_chain = TimedRetrievalQAChain(
        retriever=documentRetrieval.retriever,
        llm=documentRetrieval.llm,
        qa_prompt=load_prompt(os.path.join(kit_dir, documentRetrieval.prompts["qa_prompt"])),
        rerank=documentRetrieval.retrieval_info["rerank"],
        final_k_retrieved_documents=documentRetrieval.retrieval_info["final_k_retrieved_documents"],
    )

    if not os.path.exists(questions_file_path):
        raise FileNotFoundError(f"questions file path {questions_file_path} does not exist")
    df = pd.read_excel(questions_file_path)
    print(df)
    output_file_path = questions_file_path.replace(".xlsx", "_output.xlsx")
    if "Answer" not in df.columns:
        df["Answer"] = ""
        df["Sources"] = ""
        df["preprocessing_time"] = ""
        df["llm_time"] = ""
        df["token_count"] = ""
        df["tokens_per_second"] = ""
    for index, row in df.iterrows():
        # Skip rows that already carry an answer so an interrupted run can resume.
        if pd.isna(row["Answer"]) or str(row["Answer"]).strip() == "":
            try:
                print(f"Generating answer for row {index}")
                answer, sources, times = generate(qa_chain, row["Questions"], tokenizer)
                df.at[index, "Answer"] = answer
                df.at[index, "Sources"] = ", ".join(sources)
                df.at[index, "preprocessing_time"] = times.get("preprocessing_time")
                df.at[index, "llm_time"] = times.get("llm_time")
                df.at[index, "token_count"] = times.get("token_count")
                df.at[index, "tokens_per_second"] = times.get("tokens_per_second")
            except Exception as e:
                print(f"Error processing row {index}: {e}")
            # Save after every row so partial progress survives a crash.
            df.to_excel(output_file_path, index=False)
        else:
            print(f"Skipping row {index} because 'Answer' is already in the document")
    return output_file_path
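
# Illustrative input/output shape:
#   input  questions.xlsx:        Questions
#   output questions_output.xlsx: Questions | Answer | Sources |
#     preprocessing_time | llm_time | token_count | tokens_per_second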


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Use a vector db and an Excel file with questions in a 'Questions' column "
        "and generate answers for all the questions"
    )
    parser.add_argument("vectordb_path", type=str, help="vector db path with stored documents for RAG")
    parser.add_argument("questions_path", type=str, help=".xlsx file containing questions in a column named 'Questions'")
    args = parser.parse_args()

    out_file = process_bulk_QA(args.vectordb_path, args.questions_path)
    print(f"Finished, responses in: {out_file}")
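
# Example invocation (script name and paths are illustrative):
#   python bulk_qa.py ./data/my-vector-db ./data/questions.xlsx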