File size: 5,511 Bytes
a6c26b1 a84e3d2 a6c26b1 a84e3d2 a6c26b1 a84e3d2 a6c26b1 a84e3d2 a6c26b1 a84e3d2 a6c26b1 a84e3d2 a6c26b1 a84e3d2 a6c26b1 a84e3d2 a6c26b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import sys
import logging
import yaml
import gradio as gr
import time
# Directory containing this script; config and data paths below are resolved
# relative to it rather than to the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
from src.document_retrieval import DocumentRetrieval
from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials
from utils.parsing.sambaparse import parse_doc_universal
from utils.vectordb.vector_db import VectorDb

CONFIG_PATH = os.path.join(current_dir, 'config.yaml')
# Default on-disk location for a persisted vector store.
PERSIST_DIRECTORY = os.path.join(current_dir, "data/my-vector-db")

logging.basicConfig(level=logging.INFO)
# Report the resolved app directory through logging (configured just above)
# instead of printing it to stdout at import time.
logging.info("App directory: %s", current_dir)
logging.info("Gradio app is running")
class ChatState:
    """Mutable container for the retrieval pipeline and the chat session.

    One instance (``chat_state``) is created at module level and read/written
    by the Gradio callbacks in this module.
    """

    def __init__(self):
        # Pipeline objects. ``document_retrieval`` is attached right after the
        # module-level instance is built; ``vectorstore`` and ``conversation``
        # are filled in by process_documents().
        self.document_retrieval = None
        self.vectorstore = None
        self.conversation = None

        # Chat bookkeeping: (question, answer) pairs shown in the Chatbot, and
        # per-turn source listings (nothing in this module appends to it yet).
        self.chat_history = []
        self.sources_history = []

        # UI flags; input_disabled is cleared once a document is processed.
        self.show_sources = True
        self.input_disabled = True
# Module-level session state used by every Gradio callback below.
# NOTE(review): this is a single process-wide object, so all browser sessions
# share (and overwrite) the same vector store, QA chain and chat history;
# consider per-session gr.State if multi-user use is intended.
chat_state = ChatState()
chat_state.document_retrieval = DocumentRetrieval()
def handle_userinput(user_question):
    """Answer a chat question using the QA chain built from the processed document.

    Args:
        user_question: Text submitted from the message box.

    Returns:
        ``(chat_history, "")``: the updated list of ``(question, answer)`` pairs
        for the Chatbot, and an empty string that clears the message box.
        Failures are reported as the answer in the chat rather than returned
        as a bare string, because the Chatbot output expects a list of pairs.
    """
    # Ignore empty or whitespace-only submissions.
    if not user_question or not user_question.strip():
        return chat_state.chat_history, ""

    if chat_state.conversation is None:
        # No QA chain yet (no document processed); say so instead of failing
        # with an opaque AttributeError on ``None.invoke``.
        chat_state.chat_history.append(
            (user_question, "Please upload and process a document before asking questions.")
        )
        return chat_state.chat_history, ""

    try:
        start = time.perf_counter()
        response = chat_state.conversation.invoke({"question": user_question})
        answer = response["answer"]
        logging.info("Answered question in %.2f s", time.perf_counter() - start)
    except Exception as e:
        # Top-level UI boundary: log the traceback and surface the error in the chat.
        logging.exception("Failed to answer question")
        answer = f"An error occurred: {str(e)}"

    chat_state.chat_history.append((user_question, answer))
    return chat_state.chat_history, ""
def process_documents(files, save_location=None):
    """Parse an uploaded document and build the retrieval QA chain from it.

    On success, stores the vector store and QA chain on ``chat_state`` and
    enables chat input.

    Args:
        files: The uploaded file as delivered by the ``gr.File`` component
            (single-file mode), or ``None`` when nothing was uploaded.
        save_location: Optional value forwarded to ``create_vector_store`` as
            ``output_db``.

    Returns:
        A status message for the "Processing status" textbox.
    """
    if files is None:
        return "Please upload a PDF file first."
    try:
        _, _, text_chunks = parse_doc_universal(doc=files)
        # Full chunk dump only at debug level; it can be as large as the document.
        logging.debug("Parsed text chunks: %s", text_chunks)
        embeddings = chat_state.document_retrieval.load_embedding_model()
        # 'prod_mode' is optional in config.yaml (read with a False default at
        # module level), so do not index it directly here.
        collection_name = None if config.get('prod_mode', False) else 'ekr_default_collection'
        vectorstore = chat_state.document_retrieval.create_vector_store(
            text_chunks, embeddings, output_db=save_location, collection_name=collection_name
        )
        chat_state.vectorstore = vectorstore
        chat_state.document_retrieval.init_retriever(vectorstore)
        chat_state.conversation = chat_state.document_retrieval.get_qa_retrieval_chain()
        chat_state.input_disabled = False
        return "Complete! You can now ask questions."
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs, show a short message.
        logging.exception("Document processing failed")
        return f"An error occurred while processing: {str(e)}"
def reset_conversation():
    """Start a fresh chat by replacing the stored history with an empty list.

    Returns:
        ``(chat_history, "")``: the new, empty history for the Chatbot and an
        empty string for the message box.

    NOTE(review): only the UI history is cleared; any memory the QA chain in
    ``chat_state.conversation`` may keep is not touched here — confirm intent.
    """
    fresh_history = []
    chat_state.chat_history = fresh_history
    return fresh_history, ""
def show_selection(model):
    """Return a confirmation message naming the selected option."""
    return "You selected: {}".format(model)
# Read config file. YAML is read as UTF-8 explicitly (not the platform default
# encoding), and an empty file (safe_load -> None) is treated as an empty config.
with open(CONFIG_PATH, 'r', encoding='utf-8') as yaml_file:
    config = yaml.safe_load(yaml_file) or {}
prod_mode = config.get('prod_mode', False)
default_collection = 'ekr_default_collection'
# Load env variables
initialize_env_variables(prod_mode)
caution_text = """⚠️ Note: depending on the size of your document, this could take several minutes.
"""
with gr.Blocks() as demo:
    gr.Markdown("# Enterprise Knowledge Retriever", elem_id="title")
    gr.Markdown("Powered by LLama3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")
    # NOTE(review): this value is not passed to any callback, so the key a user
    # enters currently has no effect.
    api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")

    # Step 1: Add PDF file
    gr.Markdown("## 1️⃣ Upload PDF")
    # Gradio matches file extensions only when written with a leading dot.
    docs = gr.File(label="Add PDF file (single)", file_types=[".pdf"], file_count="single")

    # Step 2: Process PDF file
    gr.Markdown("## 2️⃣ Process document and create vector store")
    # NOTE(review): the selection is not read by any callback (ChromaDB is the only option).
    db_btn = gr.Radio(["ChromaDB"], label="Vector store type", value="ChromaDB", type="index", info="Choose your vector store")
    setup_output = gr.Textbox(label="Processing status", visible=True, value="None")
    process_btn = gr.Button("🔄 Process")
    gr.Markdown(caution_text)
    # NOTE(review): up to 10 concurrent runs all write to the single shared
    # chat_state, so simultaneous uploads overwrite each other's results.
    process_btn.click(process_documents, inputs=[docs], outputs=setup_output, concurrency_limit=10)

    # Step 3: Chat with your data
    gr.Markdown("## 3️⃣ Chat with your document")
    chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
    msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
    clear = gr.Button("Clear chat")
    # Hidden placeholder for source listings; no callback writes to it yet.
    sources_output = gr.Textbox(label="Sources", visible=False)
    # Both handlers return (chat_history, "") -> update the Chatbot and clear the message box.
    msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, msg])
    clear.click(reset_conversation, outputs=[chatbot, msg])

if __name__ == "__main__":
    demo.launch()
|