Update app.py
Browse files
app.py
CHANGED
@@ -4,16 +4,16 @@ import re
|
|
4 |
from pathlib import Path
|
5 |
from statistics import median
|
6 |
|
|
|
7 |
import streamlit as st
|
8 |
from bs4 import BeautifulSoup
|
|
|
9 |
from langchain.chains import ConversationalRetrievalChain
|
10 |
from langchain.docstore.document import Document
|
11 |
from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
|
12 |
-
from langchain_openai import ChatOpenAI
|
13 |
from langchain.retrievers.multi_query import MultiQueryRetriever
|
|
|
14 |
from ragatouille import RAGPretrainedModel
|
15 |
-
import pandas as pd
|
16 |
-
|
17 |
|
18 |
st.set_page_config(layout="wide")
|
19 |
os.environ["OPENAI_API_KEY"] = "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"
|
@@ -46,17 +46,21 @@ def query_llm(retriever, query):
|
|
46 |
chain_type="stuff",
|
47 |
)
|
48 |
relevant_docs = retriever.get_relevant_documents(query)
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
result = result["answer"]
|
51 |
st.session_state.messages.append((query, result))
|
52 |
-
return relevant_docs, result
|
53 |
|
54 |
|
55 |
def input_fields():
|
56 |
st.session_state.source_doc_urls = [
|
57 |
url.strip()
|
58 |
for url in st.sidebar.text_area(
|
59 |
-
"Source Document URLs\n(New line separated)", height=
|
60 |
).split("\n")
|
61 |
]
|
62 |
|
@@ -201,25 +205,34 @@ def boot():
|
|
201 |
st.title("Xi Chatbot")
|
202 |
st.sidebar.title("Input Documents")
|
203 |
input_fields()
|
204 |
-
col1, col2 = st.columns([4, 1])
|
205 |
st.sidebar.button("Submit Documents", on_click=process_documents)
|
206 |
if "headers" in st.session_state:
|
207 |
-
|
208 |
-
|
|
|
|
|
209 |
if "messages" not in st.session_state:
|
210 |
st.session_state.messages = []
|
211 |
for message in st.session_state.messages:
|
212 |
-
|
213 |
-
|
214 |
-
if query :=
|
215 |
-
|
216 |
-
references, response = query_llm(st.session_state.retriever, query)
|
217 |
sorted_references = sorted([ref.metadata["chunk_id"] for ref in references])
|
218 |
references_str = " ".join([f"[{ref}]" for ref in sorted_references])
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
221 |
)
|
|
|
|
|
|
|
222 |
|
223 |
|
224 |
if __name__ == "__main__":
|
225 |
-
boot()
|
|
|
4 |
from pathlib import Path
|
5 |
from statistics import median
|
6 |
|
7 |
+
import pandas as pd
|
8 |
import streamlit as st
|
9 |
from bs4 import BeautifulSoup
|
10 |
+
from langchain.callbacks import get_openai_callback
|
11 |
from langchain.chains import ConversationalRetrievalChain
|
12 |
from langchain.docstore.document import Document
|
13 |
from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
|
|
|
14 |
from langchain.retrievers.multi_query import MultiQueryRetriever
|
15 |
+
from langchain_openai import ChatOpenAI
|
16 |
from ragatouille import RAGPretrainedModel
|
|
|
|
|
17 |
|
18 |
st.set_page_config(layout="wide")
|
19 |
os.environ["OPENAI_API_KEY"] = "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"
|
|
|
46 |
chain_type="stuff",
|
47 |
)
|
48 |
relevant_docs = retriever.get_relevant_documents(query)
|
49 |
+
with get_openai_callback() as cb:
|
50 |
+
result = qa_chain(
|
51 |
+
{"question": query, "chat_history": st.session_state.messages}
|
52 |
+
)
|
53 |
+
stats = cb
|
54 |
result = result["answer"]
|
55 |
st.session_state.messages.append((query, result))
|
56 |
+
return relevant_docs, result, stats
|
57 |
|
58 |
|
59 |
def input_fields():
|
60 |
st.session_state.source_doc_urls = [
|
61 |
url.strip()
|
62 |
for url in st.sidebar.text_area(
|
63 |
+
"Source Document URLs\n(New line separated)", height=50
|
64 |
).split("\n")
|
65 |
]
|
66 |
|
|
|
205 |
st.title("Xi Chatbot")
|
206 |
st.sidebar.title("Input Documents")
|
207 |
input_fields()
|
|
|
208 |
st.sidebar.button("Submit Documents", on_click=process_documents)
|
209 |
if "headers" in st.session_state:
|
210 |
+
st.sidebar.write("### References")
|
211 |
+
st.sidebar.write(st.session_state.headers)
|
212 |
+
if "costing" not in st.session_state:
|
213 |
+
st.session_state.costing = []
|
214 |
if "messages" not in st.session_state:
|
215 |
st.session_state.messages = []
|
216 |
for message in st.session_state.messages:
|
217 |
+
st.chat_message("human").write(message[0])
|
218 |
+
st.chat_message("ai").write(message[1])
|
219 |
+
if query := st.chat_input():
|
220 |
+
st.chat_message("human").write(query)
|
221 |
+
references, response, stats = query_llm(st.session_state.retriever, query)
|
222 |
sorted_references = sorted([ref.metadata["chunk_id"] for ref in references])
|
223 |
references_str = " ".join([f"[{ref}]" for ref in sorted_references])
|
224 |
+
st.chat_message("ai").write(response + "\n\n---\nReferences:" + references_str)
|
225 |
+
st.session_state.costing.append(
|
226 |
+
{
|
227 |
+
"prompt tokens": stats.prompt_tokens,
|
228 |
+
"completion tokens": stats.completion_tokens,
|
229 |
+
"total cost": stats.total_cost,
|
230 |
+
}
|
231 |
)
|
232 |
+
stats_df = pd.DataFrame(st.session_state.costing)
|
233 |
+
stats_df.loc["total"] = stats_df.sum()
|
234 |
+
st.sidebar.write(stats_df)
|
235 |
|
236 |
|
237 |
if __name__ == "__main__":
|
238 |
+
boot()
|