File size: 4,063 Bytes
a6c26b1
 
 
 
7801fa3
a6c26b1
 
 
 
3b758aa
a6c26b1
 
3b758aa
a6c26b1
 
e443083
3b758aa
a6c26b1
e443083
 
a6c26b1
e443083
a6c26b1
7801fa3
 
 
5ab5b15
e443083
a6c26b1
3b758aa
a6c26b1
aa94ed8
a6c26b1
7801fa3
883864f
5ab5b15
7801fa3
 
5ab5b15
 
3b758aa
 
a6c26b1
3b758aa
a6c26b1
 
 
 
 
aa94ed8
3b758aa
aa94ed8
3b758aa
5ab5b15
a84e3d2
a6c26b1
 
a84e3d2
a6c26b1
e18dfac
a6c26b1
 
a84e3d2
 
a6c26b1
 
a84e3d2
 
 
a6c26b1
 
a84e3d2
5ab5b15
7801fa3
a6c26b1
 
a84e3d2
a6c26b1
a84e3d2
5ab5b15
a6c26b1
 
5ab5b15
3b758aa
aa94ed8
a6c26b1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import sys
import yaml
import gradio as gr
import uuid

# Absolute path of the directory that holds this script (head of the split path).
current_dir = os.path.split(os.path.abspath(__file__))[0]

from src.document_retrieval import DocumentRetrieval
from utils.parsing.sambaparse import parse_doc_universal # added
from utils.vectordb.vector_db import VectorDb

def handle_userinput(user_question, conversation_chain, history):
    """Answer one chat message and append the exchange to the chat history.

    Args:
        user_question: Text typed by the user. Empty input is ignored.
        conversation_chain: Chain produced by ``process_documents``; it must
            expose ``invoke({"question": ...})`` returning a mapping with an
            ``"answer"`` key. ``None`` until a document has been processed.
        history: Chat history as a list of ``(user, bot)`` pairs, or ``None``
            (the "Clear chat" button writes ``None`` into the chatbot).

    Returns:
        ``(history, "")``: a new history list (the caller's list is never
        mutated) and an empty string that clears the message textbox.
    """
    # Normalise None -> [] so concatenation below cannot raise TypeError.
    history = list(history or [])
    if not user_question:
        return history, ""

    if conversation_chain is None:
        # No document processed yet: give actionable guidance instead of an
        # opaque "'NoneType' object has no attribute 'invoke'" message.
        answer = "Please upload and process a document before asking questions."
        return history + [(user_question, answer)], ""

    try:
        response = conversation_chain.invoke({"question": user_question})
        answer = response["answer"]
    except Exception as e:  # UI boundary: report the failure in the chat instead of crashing the handler
        answer = f"An error occurred: {e}"
    return history + [(user_question, answer)], ""

def process_documents(files, collection_name, document_retrieval, vectorstore, conversation_chain, save_location=None):
    """Parse an uploaded document, index it, and build a QA retrieval chain.

    A fresh ``DocumentRetrieval`` and a uniquely named collection
    (``collection_<uuid4>``) are created on every call. New state is returned
    only when every step succeeds. On any failure the caller's existing
    ``conversation_chain``, ``vectorstore``, ``document_retrieval`` and
    ``collection_name`` are returned unchanged, so the session never ends up
    mixing a new retriever with an old vector store.

    Args:
        files: The uploaded file from the ``gr.File`` component, or ``None``
            when nothing was uploaded.
        collection_name: Current session collection name (returned unchanged
            on failure).
        document_retrieval: Current session ``DocumentRetrieval`` (returned
            unchanged on failure).
        vectorstore: Current session vector store (returned unchanged on failure).
        conversation_chain: Current session QA chain (returned unchanged on failure).
        save_location: Passed to ``create_vector_store`` as ``output_db``.

    Returns:
        ``(conversation_chain, vectorstore, document_retrieval,
        collection_name, status_message)``, in the order the Gradio
        ``outputs`` list expects.
    """
    def _keep_state(message):
        # Leave the session untouched and only report a status message.
        return conversation_chain, vectorstore, document_retrieval, collection_name, message

    if not files:
        return _keep_state("Please upload a PDF file before processing.")

    try:
        new_retrieval = DocumentRetrieval()
        _, _, text_chunks = parse_doc_universal(doc=files)
        if not text_chunks:
            return _keep_state("No text could be extracted from the document.")
        embeddings = new_retrieval.load_embedding_model()
        new_collection_name = f"collection_{uuid.uuid4()}"
        new_vectorstore = new_retrieval.create_vector_store(
            text_chunks,
            embeddings,
            output_db=save_location,
            collection_name=new_collection_name,
        )
        new_retrieval.init_retriever(new_vectorstore)
        new_chain = new_retrieval.get_qa_retrieval_chain()
    except Exception as e:  # UI boundary: report instead of crashing the event handler
        return _keep_state(f"An error occurred while processing: {e}")

    return new_chain, new_vectorstore, new_retrieval, new_collection_name, "Complete! You can now ask questions."

# Warning rendered beneath the Process button: indexing a large document is slow.
caution_text = (
    "⚠️ Note: depending on the size of your document, "
    "this could take several minutes.\n"
)

# Gradio UI. Components appear on the page in the order they are created below,
# and each event's `inputs`/`outputs` lists are matched to handler arguments and
# return values by position.
with gr.Blocks() as demo:
    # Per-session state: written by process_documents, and conversation_chain is
    # read by handle_userinput.
    vectorstore = gr.State()
    conversation_chain = gr.State()
    document_retrieval = gr.State()
    collection_name=gr.State()

    gr.Markdown("# Enterprise Knowledge Retriever", 
            elem_id="title")
    
    gr.Markdown("Powered by LLama3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")

    # NOTE(review): api_key is not passed to any event handler in this file, so a
    # value typed here is not used by the visible code. Confirm whether it should
    # be wired into process_documents / DocumentRetrieval.
    api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")

    # Step 1: Add PDF file
    gr.Markdown("## 1️⃣ Upload PDF")
    docs = gr.File(label="Add PDF file (single)", file_types=["pdf"], file_count="single")
        
    # Step 2: Process PDF file
    gr.Markdown(("## 2️⃣ Process document and create vector store"))
    # NOTE(review): db_btn's selection is not read by any handler here. ChromaDB
    # is the only option offered.
    db_btn = gr.Radio(["ChromaDB"], label="Vector store type", value = "ChromaDB", type="index", info="Choose your vector store")
    setup_output = gr.Textbox(label="Processing status", visible=True, value="None") 
    process_btn = gr.Button("🔄 Process")
    gr.Markdown(caution_text)
      
    # Preprocessing events
    # Feeds the uploaded file plus the current session state to process_documents
    # and writes back (chain, vector store, retriever, collection name, status).
    # This order matches the handler's return tuple.
    process_btn.click(process_documents, inputs=[docs, collection_name, document_retrieval, vectorstore, conversation_chain], outputs=[conversation_chain, vectorstore, document_retrieval, collection_name, setup_output], concurrency_limit=20)
    
    # Step 3: Chat with your data
    gr.Markdown("## 3️⃣ Chat with your document")
    chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
    msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
    clear_btn = gr.Button("Clear chat")
    # NOTE(review): sources_output is hidden and never updated by any event in this file.
    sources_output = gr.Textbox(label="Sources", visible=False)

    # Chatbot events
    # Enter in the textbox -> handle_userinput(question, chain, history), which
    # returns (new history, "") so the textbox is cleared after sending.
    msg.submit(handle_userinput, inputs=[msg, conversation_chain, chatbot], outputs=[chatbot, msg], queue=False)
    # Clear resets the chatbot to None and empties the textbox.
    clear_btn.click(lambda: [None, ""], inputs=None, outputs=[chatbot, msg], queue=False)

# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()