|
import os |
|
import sys |
|
import logging |
|
import yaml |
|
import gradio as gr |
|
import time |
|
|
|
# Absolute directory that contains this script; the config and data paths
# defined further down are resolved relative to it.
current_dir = os.path.dirname(os.path.abspath(__file__))
print(current_dir)
|
|
|
from src.document_retrieval import DocumentRetrieval |
|
from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials |
|
from utils.parsing.sambaparse import parse_doc_universal |
|
from utils.vectordb.vector_db import VectorDb |
|
|
|
# YAML settings file expected alongside this script.
CONFIG_PATH = os.path.join(current_dir, 'config.yaml')
# NOTE(review): not referenced elsewhere in this view; presumably the on-disk
# location for a persisted vector store -- confirm against callers.
PERSIST_DIRECTORY = os.path.join(current_dir, "data/my-vector-db")

# Root logger at INFO so the messages emitted by this module are visible.
logging.basicConfig(level=logging.INFO)
logging.info("Gradio app is running")
|
|
|
class ChatState:
    """Mutable container for everything the chat callbacks share.

    All fields start out empty/unset; the callbacks in this module fill
    them in as documents are processed and questions are asked.
    """

    def __init__(self):
        # Retrieval pipeline pieces.
        self.document_retrieval = None
        self.vectorstore = None
        self.conversation = None

        # Transcript shown in the chatbot, stored as (question, answer) pairs.
        self.chat_history = []
        # NOTE(review): never written in the visible code; presumably meant
        # to hold the sources behind each answer.
        self.sources_history = []

        # UI-related flags.
        self.show_sources = True
        self.input_disabled = True
|
|
|
# Single module-level state object read and mutated by the callbacks below.
chat_state = ChatState()
# Create the retrieval helper up front so document processing can use it.
chat_state.document_retrieval = DocumentRetrieval()
|
|
|
def handle_userinput(user_question):
    """Answer a question with the active QA chain and update the transcript.

    Args:
        user_question: Text submitted from the chat textbox. Empty or
            ``None`` input is ignored.

    Returns:
        A ``(history, "")`` tuple: the list of ``(question, reply)`` pairs to
        show in the chatbot, and an empty string that clears the textbox.
        Failures are reported as a reply in the returned history and are
        not stored in ``chat_state.chat_history``.
    """
    if not user_question:
        return chat_state.chat_history, ""

    # No chain exists until a document has been processed; calling
    # ``invoke`` on ``None`` would only surface an opaque AttributeError.
    if chat_state.conversation is None:
        notice = "Please upload and process a document before asking questions."
        return chat_state.chat_history + [(user_question, notice)], ""

    try:
        started = time.time()
        response = chat_state.conversation.invoke({"question": user_question})
        elapsed = time.time() - started
        answer = response["answer"]
    except Exception as e:
        logging.exception("Failed to answer user question")
        # The first output feeds the chatbot, which takes a list of message
        # pairs, so the error is shown as a reply instead of a bare string.
        error_reply = f"An error occurred: {str(e)}"
        return chat_state.chat_history + [(user_question, error_reply)], ""

    logging.info("Answer generated in %.2f seconds", elapsed)
    chat_state.chat_history.append((user_question, answer))
    return chat_state.chat_history, ""
|
|
|
def process_documents(files, save_location=None):
    """Parse the uploaded document and build the retrieval chain over it.

    On success the shared ``chat_state`` holds the new vector store and QA
    chain, and its ``input_disabled`` flag is cleared.

    Args:
        files: Uploaded document from the file component; passed through to
            ``parse_doc_universal`` unchanged.
        save_location: Optional value forwarded to ``create_vector_store``
            as ``output_db``.

    Returns:
        A status message for the "Processing status" textbox. Failures are
        reported in the message instead of being raised.
    """
    # Clicking "Process" before uploading anything hands over an empty value.
    if not files:
        return "Please upload a document before processing."

    try:
        _, _, text_chunks = parse_doc_universal(doc=files)
        # Debug level only: the parsed chunks can be the whole document.
        logging.debug("Parsed text chunks: %s", text_chunks)

        retrieval = chat_state.document_retrieval
        embeddings = retrieval.load_embedding_model()
        # ``.get`` mirrors the module-level ``prod_mode`` lookup, so a config
        # without the key is treated as non-production instead of raising.
        in_prod = config.get('prod_mode', False)
        collection_name = None if in_prod else 'ekr_default_collection'
        vectorstore = retrieval.create_vector_store(
            text_chunks,
            embeddings,
            output_db=save_location,
            collection_name=collection_name,
        )
        chat_state.vectorstore = vectorstore
        retrieval.init_retriever(vectorstore)
        chat_state.conversation = retrieval.get_qa_retrieval_chain()
        chat_state.input_disabled = False
        return "Complete! You can now ask questions."
    except Exception as e:
        logging.exception("Document processing failed")
        return f"An error occurred while processing: {str(e)}"
|
|
|
def reset_conversation():
    """Discard the stored transcript.

    Returns:
        A ``(history, "")`` tuple: the new empty history list and an empty
        string, matching the ``[chatbot, msg]`` outputs of the clear button.
    """
    fresh_history = []
    chat_state.chat_history = fresh_history
    return fresh_history, ""
|
|
|
def show_selection(model):
    """Return a short confirmation message naming ``model``."""
    message = f"You selected: {model}"
    return message
|
|
|
|
|
# Load the application settings from the YAML file next to this script.
with open(CONFIG_PATH, 'r', encoding='utf-8') as yaml_file:
    # safe_load yields None for an empty file; fall back to an empty mapping
    # so the lookups below (and in the callbacks) keep working.
    config = yaml.safe_load(yaml_file) or {}

# A missing key means non-production mode.
prod_mode = config.get('prod_mode', False)
default_collection = 'ekr_default_collection'

initialize_env_variables(prod_mode)

# Markdown note rendered under the "Process" button.
caution_text = """⚠️ Note: depending on the size of your document, this could take several minutes.
"""
|
|
|
# Page layout: title, API key field, upload, processing controls, then chat.
with gr.Blocks() as demo:
    gr.Markdown("# Enterprise Knowledge Retriever", elem_id="title")
    gr.Markdown("Powered by LLama3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")

    # NOTE(review): this field is not passed to any callback in this block.
    api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")

    gr.Markdown("## 1️⃣ Upload PDF")
    # Extensions must carry the leading dot (".pdf") for the file-type filter
    # to match uploaded PDFs.
    docs = gr.File(label="Add PDF file (single)", file_types=[".pdf"], file_count="single")

    gr.Markdown("## 2️⃣ Process document and create vector store")
    # NOTE(review): the selection is not read by any callback in this block.
    db_btn = gr.Radio(["ChromaDB"], label="Vector store type", value="ChromaDB", type="index", info="Choose your vector store")
    setup_output = gr.Textbox(label="Processing status", visible=True, value="None")
    process_btn = gr.Button("🔄 Process")
    gr.Markdown(caution_text)

    # The status string returned by process_documents goes to the status box.
    process_btn.click(process_documents, inputs=[docs], outputs=setup_output, concurrency_limit=10)

    gr.Markdown("## 3️⃣ Chat with your document")
    chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
    msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
    clear = gr.Button("Clear chat")

    # Hidden, and not written by any callback in this block.
    sources_output = gr.Textbox(label="Sources", visible=False)

    # Both callbacks return (history, "") to refresh the chat and clear the input.
    msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, msg])
    clear.click(reset_conversation, outputs=[chatbot, msg])
|
|
|
|
|
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|