"""Gradio app: upload a PDF, build a vector store from it, and chat with it."""

import logging
import os
import sys
import uuid

import gradio as gr
import yaml

from src.document_retrieval import DocumentRetrieval
from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials
from utils.parsing.sambaparse import parse_doc_universal
from utils.vectordb.vector_db import VectorDb

logger = logging.getLogger(__name__)

# All paths are resolved relative to this file so the app does not depend on
# the working directory it is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
CONFIG_PATH = os.path.join(current_dir, 'config.yaml')
PERSIST_DIRECTORY = os.path.join(current_dir, "data/my-vector-db")


def handle_userinput(user_question, conversation_chain, history):
    """Answer a chat message and append the exchange to the chat history.

    Args:
        user_question: Text submitted in the message box.
        conversation_chain: Object exposing ``invoke({"question": ...})`` whose
            result is indexed with ``"answer"``; ``None`` until a document has
            been processed.
        history: Current chatbot history (sequence of ``(user, bot)`` pairs),
            or ``None``.

    Returns:
        A ``(history, "")`` tuple: the updated history and an empty string
        that clears the message box.
    """
    # Copy so the incoming state object is never mutated in place.
    history = list(history or [])

    if not user_question:
        return history, ""

    if conversation_chain is None:
        # Without this guard the user would see a raw
        # "'NoneType' object has no attribute 'invoke'" message.
        history.append(
            (user_question, "Please upload and process a document before asking questions.")
        )
        return history, ""

    try:
        response = conversation_chain.invoke({"question": user_question})
        answer = response["answer"]
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing.
        answer = f"An error occurred: {str(e)}"

    history.append((user_question, answer))
    return history, ""


def process_documents(files, collection_name, document_retrieval, vectorstore, conversation_chain, save_location=None):
    """Parse the uploaded file and build a fresh retrieval chain for it.

    The incoming ``collection_name``, ``document_retrieval``, ``vectorstore``
    and ``conversation_chain`` are the previous session state. They are
    returned untouched when processing fails, so a failed run never leaves the
    session with a mix of old and new objects.

    Args:
        files: Value of the upload component, or ``None`` if nothing was
            uploaded.
        collection_name: Previous collection name (state).
        document_retrieval: Previous ``DocumentRetrieval`` instance (state).
        vectorstore: Previous vector store (state).
        conversation_chain: Previous conversation chain (state).
        save_location: Passed through as ``output_db`` when the vector store
            is created.

    Returns:
        ``(conversation_chain, vectorstore, document_retrieval,
        collection_name, status_message)``.
    """
    if files is None:
        return (
            conversation_chain,
            vectorstore,
            document_retrieval,
            collection_name,
            "Please upload a PDF file before processing.",
        )

    try:
        # Build everything in fresh locals; only hand them back on success.
        retrieval = DocumentRetrieval()
        _, _, text_chunks = parse_doc_universal(doc=files)
        logger.info("Parsed %d text chunks", len(text_chunks))

        embeddings = retrieval.load_embedding_model()
        # A unique name per run keeps each upload in its own collection.
        new_collection_name = f"collection_{uuid.uuid4()}"
        new_vectorstore = retrieval.create_vector_store(
            text_chunks,
            embeddings,
            output_db=save_location,
            collection_name=new_collection_name,
        )
        retrieval.init_retriever(new_vectorstore)
        new_chain = retrieval.get_qa_retrieval_chain()
    except Exception as e:
        # UI boundary: report the failure and keep the previous state intact.
        logger.exception("Document processing failed")
        return (
            conversation_chain,
            vectorstore,
            document_retrieval,
            collection_name,
            f"An error occurred while processing: {str(e)}",
        )

    return (
        new_chain,
        new_vectorstore,
        retrieval,
        new_collection_name,
        "Complete! You can now ask questions.",
    )


# Read config file. An empty YAML file loads as None, hence the fallback.
with open(CONFIG_PATH, 'r', encoding='utf-8') as yaml_file:
    config = yaml.safe_load(yaml_file) or {}

prod_mode = config.get('prod_mode', False)

# Load env variables
initialize_env_variables(prod_mode)

caution_text = """⚠️ Note: depending on the size of your document, this could take several minutes. """

with gr.Blocks() as demo:
    # Per-session state shared between the processing and chat callbacks.
    vectorstore = gr.State()
    conversation_chain = gr.State()
    document_retrieval = gr.State()
    collection_name = gr.State()

    gr.Markdown("# Enterprise Knowledge Retriever", elem_id="title")
    gr.Markdown("Powered by LLama3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")

    # Step 1: Add PDF file
    gr.Markdown("## 1️⃣ Upload PDF")
    # File extensions need a leading dot; bare words name type categories.
    docs = gr.File(label="Add PDF file (single)", file_types=[".pdf"], file_count="single")

    # Step 2: Process PDF file
    gr.Markdown("## 2️⃣ Process document and create vector store")
    db_btn = gr.Radio(
        ["ChromaDB"],
        label="Vector store type",
        value="ChromaDB",
        type="index",
        info="Choose your vector store",
    )
    setup_output = gr.Textbox(label="Processing status", visible=True, value="None")
    process_btn = gr.Button("🔄 Process")
    gr.Markdown(caution_text)

    # Preprocessing events
    process_btn.click(
        process_documents,
        inputs=[docs, collection_name, document_retrieval, vectorstore, conversation_chain],
        outputs=[conversation_chain, vectorstore, document_retrieval, collection_name, setup_output],
        concurrency_limit=20,
    )

    # Step 3: Chat with your data
    gr.Markdown("## 3️⃣ Chat with your document")
    chatbot = gr.Chatbot(
        label="Chatbot",
        show_label=True,
        show_share_button=False,
        show_copy_button=True,
        likeable=True,
    )
    msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
    clear_btn = gr.Button("Clear chat")
    sources_output = gr.Textbox(label="Sources", visible=False)

    # Chatbot events
    msg.submit(
        handle_userinput,
        inputs=[msg, conversation_chain, chatbot],
        outputs=[chatbot, msg],
        queue=False,
    )
    clear_btn.click(lambda: [None, ""], inputs=None, outputs=[chatbot, msg], queue=False)

if __name__ == "__main__":
    demo.launch()