# Scraped from a Hugging Face file view: author petrojm,
# commit a6c26b1 ("add EKR files"), file size 5.99 kB.
import os
import sys
import argparse
import pandas as pd
import time
from typing import Any, Dict, Optional
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain.prompts import load_prompt
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoTokenizer
# Make the kit root (one level above this script's folder) and the repository
# root (two levels above) importable, so the absolute package import that
# follows works no matter which working directory the script is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
repo_dir = os.path.abspath(os.path.join(kit_dir, os.pardir))
sys.path.extend([kit_dir, repo_dir])
from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval, RetrievalQAChain
class TimedRetrievalQAChain(RetrievalQAChain):
    """RetrievalQAChain variant whose output also carries timing checkpoints.

    Besides ``answer`` and ``source_documents``, the returned dict contains
    three ``time.time()`` timestamps:

    * ``start_time`` - taken just before retrieval starts;
    * ``end_preprocessing_time`` - after retrieval, optional reranking and
      context formatting;
    * ``end_llm_time`` - after the prompt | llm | parser chain returned.
    """

    # Overrides the parent's _call so the chain output includes the timestamps.
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        answer_chain = self.qa_prompt | self.llm | StrOutputParser()
        question = inputs["question"]

        t_start = time.time()
        retrieved = self.retriever.invoke(question)
        if self.rerank:
            retrieved = self.rerank_docs(question, retrieved, self.final_k_retrieved_documents)
        context = self._format_docs(retrieved)
        t_preprocessed = time.time()

        answer = answer_chain.invoke({"question": question, "context": context})
        t_llm_done = time.time()

        return {
            "answer": answer,
            "source_documents": retrieved,
            "start_time": t_start,
            "end_preprocessing_time": t_preprocessed,
            "end_llm_time": t_llm_done,
        }
def analyze_times(answer, start_time, end_preprocessing_time, end_llm_time, tokenizer):
    """Turn the chain's timing checkpoints into performance metrics.

    Args:
        answer: Generated answer text; ``None`` is treated as an empty answer.
        start_time: Timestamp taken before retrieval started.
        end_preprocessing_time: Timestamp taken after retrieval/reranking and
            context formatting.
        end_llm_time: Timestamp taken after the LLM produced the answer.
        tokenizer: Hugging Face-style tokenizer whose ``encode`` accepts the
            ``add_special_tokens`` keyword.

    Returns:
        Dict with ``preprocessing_time`` and ``llm_time`` (differences of the
        given timestamps, i.e. seconds when they come from ``time.time()``),
        ``token_count`` (tokens in the answer text, excluding special tokens),
        and ``tokens_per_second`` (``token_count / llm_time``, or NaN when
        ``llm_time`` is not positive).
    """
    preprocessing_time = end_preprocessing_time - start_time
    llm_time = end_llm_time - end_preprocessing_time
    # Count only the generated text: the default add_special_tokens=True makes
    # e.g. the Llama tokenizer prepend a BOS token, inflating the count by one.
    token_count = len(tokenizer.encode(answer or "", add_special_tokens=False))
    # Wall-clock timers can have coarse resolution, so llm_time may be 0;
    # throughput is undefined then rather than a ZeroDivisionError.
    tokens_per_second = token_count / llm_time if llm_time > 0 else float("nan")
    return {
        "preprocessing_time": preprocessing_time,
        "llm_time": llm_time,
        "token_count": token_count,
        "tokens_per_second": tokens_per_second,
    }
def generate(qa_chain, question, tokenizer):
    """Answer a single question with a timed QA chain.

    Invokes ``qa_chain`` with ``{"question": question}`` and returns a tuple
    ``(answer, sources, times)``: the ``answer`` entry of the chain output
    (``None`` if absent), the set of ``metadata["filename"]`` values of the
    returned source documents (formatted as strings), and the metrics dict
    computed by ``analyze_times`` from the chain's timing checkpoints.
    """
    result = qa_chain.invoke({"question": question})
    answer = result.get('answer')
    sources = {f'{doc.metadata["filename"]}' for doc in result["source_documents"]}
    checkpoints = (
        result.get("start_time"),
        result.get("end_preprocessing_time"),
        result.get("end_llm_time"),
    )
    times = analyze_times(answer, *checkpoints, tokenizer)
    return answer, sources, times
def _build_timed_qa_chain(vectordb_path):
    """Load the vector store at ``vectordb_path`` and wrap its retriever in a TimedRetrievalQAChain."""
    documentRetrieval = DocumentRetrieval()
    embeddings = documentRetrieval.load_embedding_model()
    vectorstore = documentRetrieval.load_vdb(vectordb_path, embeddings)
    print("Database loaded")
    documentRetrieval.init_retriever(vectorstore)
    print("retriever initialized")
    return TimedRetrievalQAChain(
        retriever=documentRetrieval.retriever,
        llm=documentRetrieval.llm,
        qa_prompt=load_prompt(os.path.join(kit_dir, documentRetrieval.prompts["qa_prompt"])),
        rerank=documentRetrieval.retrieval_info["rerank"],
        final_k_retrieved_documents=documentRetrieval.retrieval_info["final_k_retrieved_documents"],
    )


def _needs_answer(value):
    """Return True when an 'Answer' cell is blank: None/NaN or an empty/whitespace string."""
    # Empty cells come back from read_excel as NaN (a float), so a plain
    # ``value.strip()`` would raise AttributeError when resuming a run.
    if pd.isna(value):
        return True
    return str(value).strip() == ''


def _output_path_for(questions_file_path):
    """Return '<questions file stem>_output.xlsx' next to the questions file."""
    # splitext (instead of str.replace('.xlsx', ...)) guarantees the output
    # never equals the input path, even for .xls/.xlsm inputs.
    root, _ext = os.path.splitext(questions_file_path)
    return f"{root}_output.xlsx"


def process_bulk_QA(vectordb_path, questions_file_path, tokenizer_name="NousResearch/Llama-2-7b-chat-hf"):
    """Answer every unanswered question in an Excel file using RAG over a vector DB.

    Reads ``questions_file_path`` (questions in a column named ``Questions``),
    generates an answer for each row whose ``Answer`` cell is blank, records
    the source filenames and timing metrics, and writes the result to
    ``<questions file stem>_output.xlsx``. The output is saved after every
    answered row (to limit data loss) and once more at the end. A failure on
    an individual row is printed and that row is left unanswered.

    Args:
        vectordb_path: Path to the stored vector database.
        questions_file_path: Path to the Excel file with the questions.
        tokenizer_name: Hugging Face tokenizer used to count answer tokens.

    Returns:
        Path of the written output Excel file.

    Raises:
        FileNotFoundError: If ``vectordb_path`` or ``questions_file_path``
            does not exist.
        ValueError: If the questions file has no ``Questions`` column.
    """
    # Validate inputs before loading any model so mistakes fail fast.
    if not os.path.exists(vectordb_path):
        raise FileNotFoundError(f"vector db path {vectordb_path} does not exist")
    if not os.path.exists(questions_file_path):
        raise FileNotFoundError(f"questions file path {questions_file_path} does not exist")

    df = pd.read_excel(questions_file_path)
    print(df)
    if 'Questions' not in df.columns:
        raise ValueError(f"questions file {questions_file_path} has no 'Questions' column")
    output_file_path = _output_path_for(questions_file_path)

    output_columns = ('Answer', 'Sources', 'preprocessing_time', 'llm_time', 'token_count', 'tokens_per_second')
    for column in output_columns:
        if column not in df.columns:
            df[column] = ''
        else:
            # Columns read back from a previous run may be all-NaN float64;
            # object dtype lets us store strings/sets without dtype errors.
            df[column] = df[column].astype(object)

    qa_chain = _build_timed_qa_chain(vectordb_path)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    for index, row in df.iterrows():
        if not _needs_answer(row['Answer']):
            print(f"Skipping row {index} because 'Answer' is already in the document")
            continue
        try:
            print(f"Generating answer for row {index}")
            answer, sources, times = generate(qa_chain, row['Questions'], tokenizer)
            df.at[index, 'Answer'] = answer
            df.at[index, 'Sources'] = sources
            df.at[index, 'preprocessing_time'] = times.get("preprocessing_time")
            df.at[index, 'llm_time'] = times.get("llm_time")
            df.at[index, 'token_count'] = times.get("token_count")
            df.at[index, 'tokens_per_second'] = times.get("tokens_per_second")
            # Save after each answered row to avoid data loss on long runs.
            df.to_excel(output_file_path, index=False)
        except Exception as e:
            # Best-effort per row: report and continue with the next question.
            print(f"Error processing row {index}: {e}")

    # Final save so the returned path exists even if no row was answered.
    df.to_excel(output_file_path, index=False)
    return output_file_path
def _parse_args(argv=None):
    """Parse the command-line arguments of the bulk QA script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``vectordb_path`` and ``questions_path``.
    """
    parser = argparse.ArgumentParser(
        description='use a vectordb and an excel file with questions in a column named Questions '
                    'and generate answers for all the questions'
    )
    parser.add_argument('vectordb_path', type=str, help='vector db path with stored documents for RAG')
    parser.add_argument('questions_path', type=str, help='xlsx file containing questions in a column named Questions')
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Parse the arguments, then answer every question in bulk.
    args = _parse_args()
    out_file = process_bulk_QA(args.vectordb_path, args.questions_path)
    print(f"Finished, responses in: {out_file}")