Ritvik19 commited on
Commit
d2b4a56
1 Parent(s): 89588e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -17
app.py CHANGED
@@ -4,16 +4,16 @@ import re
4
  from pathlib import Path
5
  from statistics import median
6
 
 
7
  import streamlit as st
8
  from bs4 import BeautifulSoup
 
9
  from langchain.chains import ConversationalRetrievalChain
10
  from langchain.docstore.document import Document
11
  from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
12
- from langchain_openai import ChatOpenAI
13
  from langchain.retrievers.multi_query import MultiQueryRetriever
 
14
  from ragatouille import RAGPretrainedModel
15
- import pandas as pd
16
-
17
 
18
  st.set_page_config(layout="wide")
19
  os.environ["OPENAI_API_KEY"] = "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"
@@ -46,17 +46,21 @@ def query_llm(retriever, query):
46
  chain_type="stuff",
47
  )
48
  relevant_docs = retriever.get_relevant_documents(query)
49
- result = qa_chain({"question": query, "chat_history": st.session_state.messages})
 
 
 
 
50
  result = result["answer"]
51
  st.session_state.messages.append((query, result))
52
- return relevant_docs, result
53
 
54
 
55
  def input_fields():
56
  st.session_state.source_doc_urls = [
57
  url.strip()
58
  for url in st.sidebar.text_area(
59
- "Source Document URLs\n(New line separated)", height=200
60
  ).split("\n")
61
  ]
62
 
@@ -201,25 +205,34 @@ def boot():
201
  st.title("Xi Chatbot")
202
  st.sidebar.title("Input Documents")
203
  input_fields()
204
- col1, col2 = st.columns([4, 1])
205
  st.sidebar.button("Submit Documents", on_click=process_documents)
206
  if "headers" in st.session_state:
207
- col2.write("### References")
208
- col2.write(st.session_state.headers)
 
 
209
  if "messages" not in st.session_state:
210
  st.session_state.messages = []
211
  for message in st.session_state.messages:
212
- col1.chat_message("human").write(message[0])
213
- col1.chat_message("ai").write(message[1])
214
- if query := col1.chat_input():
215
- col1.chat_message("human").write(query)
216
- references, response = query_llm(st.session_state.retriever, query)
217
  sorted_references = sorted([ref.metadata["chunk_id"] for ref in references])
218
  references_str = " ".join([f"[{ref}]" for ref in sorted_references])
219
- col1.chat_message("ai").write(
220
- response + "\n\n---\nReferences:" + references_str
 
 
 
 
 
221
  )
 
 
 
222
 
223
 
224
  if __name__ == "__main__":
225
- boot()
 
4
  from pathlib import Path
5
  from statistics import median
6
 
7
+ import pandas as pd
8
  import streamlit as st
9
  from bs4 import BeautifulSoup
10
+ from langchain.callbacks import get_openai_callback
11
  from langchain.chains import ConversationalRetrievalChain
12
  from langchain.docstore.document import Document
13
  from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
 
14
  from langchain.retrievers.multi_query import MultiQueryRetriever
15
+ from langchain_openai import ChatOpenAI
16
  from ragatouille import RAGPretrainedModel
 
 
17
 
18
  st.set_page_config(layout="wide")
19
  os.environ["OPENAI_API_KEY"] = "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"
 
46
  chain_type="stuff",
47
  )
48
  relevant_docs = retriever.get_relevant_documents(query)
49
+ with get_openai_callback() as cb:
50
+ result = qa_chain(
51
+ {"question": query, "chat_history": st.session_state.messages}
52
+ )
53
+ stats = cb
54
  result = result["answer"]
55
  st.session_state.messages.append((query, result))
56
+ return relevant_docs, result, stats
57
 
58
 
59
  def input_fields():
60
  st.session_state.source_doc_urls = [
61
  url.strip()
62
  for url in st.sidebar.text_area(
63
+ "Source Document URLs\n(New line separated)", height=50
64
  ).split("\n")
65
  ]
66
 
 
205
  st.title("Xi Chatbot")
206
  st.sidebar.title("Input Documents")
207
  input_fields()
 
208
  st.sidebar.button("Submit Documents", on_click=process_documents)
209
  if "headers" in st.session_state:
210
+ st.sidebar.write("### References")
211
+ st.sidebar.write(st.session_state.headers)
212
+ if "costing" not in st.session_state:
213
+ st.session_state.costing = []
214
  if "messages" not in st.session_state:
215
  st.session_state.messages = []
216
  for message in st.session_state.messages:
217
+ st.chat_message("human").write(message[0])
218
+ st.chat_message("ai").write(message[1])
219
+ if query := st.chat_input():
220
+ st.chat_message("human").write(query)
221
+ references, response, stats = query_llm(st.session_state.retriever, query)
222
  sorted_references = sorted([ref.metadata["chunk_id"] for ref in references])
223
  references_str = " ".join([f"[{ref}]" for ref in sorted_references])
224
+ st.chat_message("ai").write(response + "\n\n---\nReferences:" + references_str)
225
+ st.session_state.costing.append(
226
+ {
227
+ "prompt tokens": stats.prompt_tokens,
228
+ "completion tokens": stats.completion_tokens,
229
+ "total cost": stats.total_cost,
230
+ }
231
  )
232
+ stats_df = pd.DataFrame(st.session_state.costing)
233
+ stats_df.loc["total"] = stats_df.sum()
234
+ st.sidebar.write(stats_df)
235
 
236
 
237
  if __name__ == "__main__":
238
+ boot()