petrojm committed on
Commit
5ab5b15
1 Parent(s): a84e3d2

changes to app.py and document_retrieval.py

Browse files
Files changed (2) hide show
  1. app.py +40 -40
  2. src/document_retrieval.py +11 -9
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
  import sys
3
- import logging
4
  import yaml
5
  import gradio as gr
6
- import time
7
 
8
  current_dir = os.path.dirname(os.path.abspath(__file__))
9
  print(current_dir)
@@ -16,61 +14,61 @@ from utils.vectordb.vector_db import VectorDb
16
  CONFIG_PATH = os.path.join(current_dir,'config.yaml')
17
  PERSIST_DIRECTORY = os.path.join(current_dir,f"data/my-vector-db") # changed to current_dir
18
 
19
- logging.basicConfig(level=logging.INFO)
20
- logging.info("Gradio app is running")
 
 
 
 
 
 
 
21
 
22
- class ChatState:
23
- def __init__(self):
24
- self.conversation = None
25
- self.chat_history = []
26
- self.show_sources = True
27
- self.sources_history = []
28
- self.vectorstore = None
29
- self.input_disabled = True
30
- self.document_retrieval = None
31
 
32
- chat_state = ChatState()
33
 
34
- chat_state.document_retrieval = DocumentRetrieval()
35
-
36
- def handle_userinput(user_question):
37
  if user_question:
38
  try:
39
- response_time = time.time()
40
- response = chat_state.conversation.invoke({"question": user_question})
41
- response_time = time.time() - response_time
42
- chat_state.chat_history.append((user_question, response["answer"]))
43
 
44
  #sources = set([f'{sd.metadata["filename"]}' for sd in response["source_documents"]])
45
  #sources_text = "\n".join([f"{i+1}. {source}" for i, source in enumerate(sources)])
46
  #state.sources_history.append(sources_text)
47
 
48
- return chat_state.chat_history, "" #, state.sources_history
49
  except Exception as e:
50
  return f"An error occurred: {str(e)}", "" #, state.sources_history
51
- return chat_state.chat_history, "" #, state.sources_history
 
 
52
 
53
- def process_documents(files, save_location=None):
54
  try:
55
  #for doc in files:
56
  _, _, text_chunks = parse_doc_universal(doc=files)
57
  print(text_chunks)
58
  #text_chunks = chat_state.document_retrieval.parse_doc(files)
59
- embeddings = chat_state.document_retrieval.load_embedding_model()
60
  collection_name = 'ekr_default_collection' if not config['prod_mode'] else None
61
- vectorstore = chat_state.document_retrieval.create_vector_store(text_chunks, embeddings, output_db=save_location, collection_name=collection_name)
62
- chat_state.vectorstore = vectorstore
63
- chat_state.document_retrieval.init_retriever(vectorstore)
64
- chat_state.conversation = chat_state.document_retrieval.get_qa_retrieval_chain()
65
- chat_state.input_disabled = False
66
- return "Complete! You can now ask questions."
67
  except Exception as e:
68
- return f"An error occurred while processing: {str(e)}"
69
 
70
  def reset_conversation():
71
- chat_state.chat_history = []
72
  #chat_state.sources_history = []
73
- return chat_state.chat_history, ""
74
 
75
  def show_selection(model):
76
  return f"You selected: {model}"
@@ -89,7 +87,8 @@ caution_text = """⚠️ Note: depending on the size of your document, this coul
89
  """
90
 
91
  with gr.Blocks() as demo:
92
- #gr.Markdown("# SambaNova Analyst Assistant") # title
 
93
  gr.Markdown("# Enterprise Knowledge Retriever",
94
  elem_id="title")
95
 
@@ -108,8 +107,8 @@ with gr.Blocks() as demo:
108
  process_btn = gr.Button("🔄 Process")
109
  gr.Markdown(caution_text)
110
 
111
-
112
- process_btn.click(process_documents, inputs=[docs], outputs=setup_output, concurrency_limit=10)
113
  #process_save_btn.click(process_documents, inputs=[file_upload, save_location], outputs=setup_output)
114
  #load_db_btn.click(load_existing_db, inputs=[db_path], outputs=setup_output)
115
 
@@ -117,13 +116,14 @@ with gr.Blocks() as demo:
117
  gr.Markdown("## 3️⃣ Chat with your document")
118
  chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
119
  msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
120
- clear = gr.Button("Clear chat")
121
  #show_sources = gr.Checkbox(label="Show sources", value=True)
122
  sources_output = gr.Textbox(label="Sources", visible=False)
123
 
 
124
  #msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, sources_output])
125
- msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, msg])
126
- clear.click(reset_conversation, outputs=[chatbot,msg])
127
  #show_sources.change(lambda x: gr.update(visible=x), show_sources, sources_output)
128
 
129
  if __name__ == "__main__":
 
1
  import os
2
  import sys
 
3
  import yaml
4
  import gradio as gr
 
5
 
6
  current_dir = os.path.dirname(os.path.abspath(__file__))
7
  print(current_dir)
 
14
  CONFIG_PATH = os.path.join(current_dir,'config.yaml')
15
  PERSIST_DIRECTORY = os.path.join(current_dir,f"data/my-vector-db") # changed to current_dir
16
 
17
+ #class ChatState:
18
+ # def __init__(self):
19
+ # self.conversation = None
20
+ # self.chat_history = []
21
+ # self.show_sources = True
22
+ # self.sources_history = []
23
+ # self.vectorstore = None
24
+ # self.input_disabled = True
25
+ # self.document_retrieval = None
26
 
27
+ chat_history = gr.State()
28
+ chat_history = []
29
+ vectorstore = gr.State()
30
+ document_retrieval = gr.State()
 
 
 
 
 
31
 
32
+ document_retrieval = DocumentRetrieval()
33
 
34
+ def handle_userinput(user_question, conversation):
 
 
35
  if user_question:
36
  try:
37
+ response = conversation.invoke({"question": user_question})
38
+ chat_history.append((user_question, response["answer"]))
 
 
39
 
40
  #sources = set([f'{sd.metadata["filename"]}' for sd in response["source_documents"]])
41
  #sources_text = "\n".join([f"{i+1}. {source}" for i, source in enumerate(sources)])
42
  #state.sources_history.append(sources_text)
43
 
44
+ return chat_history, "" #, state.sources_history
45
  except Exception as e:
46
  return f"An error occurred: {str(e)}", "" #, state.sources_history
47
+ else:
48
+ return "An error occurred", ""
49
+ #return chat_history, "" #, state.sources_history
50
 
51
+ def process_documents(files, conversation, save_location=None):
52
  try:
53
  #for doc in files:
54
  _, _, text_chunks = parse_doc_universal(doc=files)
55
  print(text_chunks)
56
  #text_chunks = chat_state.document_retrieval.parse_doc(files)
57
+ embeddings = document_retrieval.load_embedding_model()
58
  collection_name = 'ekr_default_collection' if not config['prod_mode'] else None
59
+ vectorstore = document_retrieval.create_vector_store(text_chunks, embeddings, output_db=save_location, collection_name=collection_name)
60
+ #vectorstore = vectorstore
61
+ document_retrieval.init_retriever(vectorstore)
62
+ conversation = document_retrieval.get_qa_retrieval_chain()
63
+ #input_disabled = False
64
+ return conversation, "Complete! You can now ask questions."
65
  except Exception as e:
66
+ return conversation, f"An error occurred while processing: {str(e)}"
67
 
68
  def reset_conversation():
69
+ chat_history = []
70
  #chat_state.sources_history = []
71
+ return chat_history, ""
72
 
73
  def show_selection(model):
74
  return f"You selected: {model}"
 
87
  """
88
 
89
  with gr.Blocks() as demo:
90
+ conversation = gr.State()
91
+
92
  gr.Markdown("# Enterprise Knowledge Retriever",
93
  elem_id="title")
94
 
 
107
  process_btn = gr.Button("🔄 Process")
108
  gr.Markdown(caution_text)
109
 
110
+ # Preprocessing events
111
+ process_btn.click(process_documents, inputs=[docs, conversation], outputs=[conversation, setup_output])
112
  #process_save_btn.click(process_documents, inputs=[file_upload, save_location], outputs=setup_output)
113
  #load_db_btn.click(load_existing_db, inputs=[db_path], outputs=setup_output)
114
 
 
116
  gr.Markdown("## 3️⃣ Chat with your document")
117
  chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
118
  msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
119
+ clear_btn = gr.Button("Clear chat")
120
  #show_sources = gr.Checkbox(label="Show sources", value=True)
121
  sources_output = gr.Textbox(label="Sources", visible=False)
122
 
123
+ # Chatbot events
124
  #msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, sources_output])
125
+ msg.submit(handle_userinput, inputs=[msg, conversation], outputs=[chatbot, msg])
126
+ clear_btn.click(reset_conversation, outputs=[chatbot,msg])
127
  #show_sources.change(lambda x: gr.update(visible=x), show_sources, sources_output)
128
 
129
  if __name__ == "__main__":
src/document_retrieval.py CHANGED
@@ -21,7 +21,7 @@ repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
21
  sys.path.append(kit_dir)
22
  sys.path.append(repo_dir)
23
 
24
- import streamlit as st
25
 
26
  from utils.model_wrappers.api_gateway import APIGateway
27
  from utils.vectordb.vector_db import VectorDb
@@ -30,7 +30,7 @@ from utils.visual.env_utils import get_wandb_key
30
  CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
31
  PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')
32
 
33
- load_dotenv(os.path.join(kit_dir, '.env'))
34
 
35
 
36
  from utils.parsing.sambaparse import parse_doc_universal
@@ -153,13 +153,15 @@ class DocumentRetrieval:
153
  return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
154
 
155
  def set_llm(self):
156
- if self.prod_mode:
157
- sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
158
- else:
159
- if 'SAMBANOVA_API_KEY' in st.session_state:
160
- sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
161
- else:
162
- sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
 
 
163
 
164
  llm = APIGateway.load_llm(
165
  type=self.api_info,
 
21
  sys.path.append(kit_dir)
22
  sys.path.append(repo_dir)
23
 
24
+ #import streamlit as st
25
 
26
  from utils.model_wrappers.api_gateway import APIGateway
27
  from utils.vectordb.vector_db import VectorDb
 
30
  CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
31
  PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')
32
 
33
+ #load_dotenv(os.path.join(kit_dir, '.env'))
34
 
35
 
36
  from utils.parsing.sambaparse import parse_doc_universal
 
153
  return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
154
 
155
  def set_llm(self):
156
+ #if self.prod_mode:
157
+ # sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
158
+ #else:
159
+ # if 'SAMBANOVA_API_KEY' in st.session_state:
160
+ # sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
161
+ # else:
162
+ # sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
163
+
164
+ sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
165
 
166
  llm = APIGateway.load_llm(
167
  type=self.api_info,