"""Agent Xi — a Streamlit chatbot over uploaded ArXiv documents.

Slash commands (/upload, /index, /cost, /download) are dispatched through
CommandCenter; any other input is answered by a conversational retrieval
chain. Conversation turns are stored in ``st.session_state.messages`` as
``(human, ai, references)`` triples; per-call token usage is accumulated in
``st.session_state.costing``.
"""

import json
import os

import pandas as pd
import streamlit as st
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI

from command_center import CommandCenter
from embed_documents import create_retriever
from process_documents import process_documents

st.set_page_config(layout="wide")

# SECURITY FIX: a live OpenAI API key was hardcoded here in the original file.
# Never commit secrets to source control — supply the key via the
# OPENAI_API_KEY environment variable (and rotate the leaked key).


def get_references(relevant_docs):
    """Render the sorted chunk ids of retrieved documents as '[id] [id] ...'."""
    return " ".join(
        f"[{ref}]"
        for ref in sorted(doc.metadata["chunk_id"] for doc in relevant_docs)
    )


def session_state_2_llm_chat_history(session_state):
    """Convert stored (human, ai, references) triples to (human, ai) pairs for
    the LLM chain, skipping slash-command turns."""
    return [ss[:2] for ss in session_state if not ss[0].startswith("/")]


def ai_message_format(message, references):
    """Append a reference footer to an AI message when references exist."""
    return f"{message}\n\n---\n\n{references}" if references != "" else message


def process_documents_wrapper(inputs):
    """Handle /upload: chunk the given document URLs, build a retriever, and
    record the snippet headers as the document index."""
    snippets = process_documents(inputs)
    st.session_state.retriever = create_retriever(snippets)
    st.session_state.source_doc_urls = inputs
    st.session_state.index = [snip.metadata["header"] for snip in snippets]
    response = f"Uploaded and processed documents {inputs}"
    st.session_state.messages.append((f"/upload {inputs}", response, ""))
    return response


def index_documents_wrapper(inputs=None):
    """Handle /index: show the headers of all processed snippets as a table."""
    response = pd.Series(st.session_state.index, name="references").to_markdown()
    st.session_state.messages.append(("/index", response, ""))
    return response


def calculate_cost_wrapper(inputs=None):
    """Handle /cost: tabulate per-query token usage and cost with a total row."""
    try:
        stats_df = pd.DataFrame(st.session_state.costing)
        stats_df.loc["total"] = stats_df.sum()
        response = stats_df.to_markdown()
    except ValueError:
        # An empty costing list yields an empty frame that cannot be summed.
        response = "No costing incurred yet"
    st.session_state.messages.append(("/cost", response, ""))
    return response


def download_conversation_wrapper(inputs=None):
    """Handle /download: offer the whole session (documents, chat, costs) as a
    JSON file via a sidebar download button."""
    conversation_data = json.dumps(
        {
            "document_urls": (
                st.session_state.source_doc_urls
                if "source_doc_urls" in st.session_state
                else []
            ),
            # BUG FIX: the original tested the nonexistent session key
            # "headers" and called .to_list() on what process_documents_wrapper
            # stores as a plain Python list, so snippets were always exported
            # as [] (or would have crashed). Check "index" and use the list.
            "document_snippets": (
                st.session_state.index if "index" in st.session_state else []
            ),
            "conversation": [
                {"human": message[0], "ai": message[1], "references": message[2]}
                for message in st.session_state.messages
            ],
            "costing": (
                st.session_state.costing if "costing" in st.session_state else []
            ),
            "total_cost": (
                {
                    k: sum(d[k] for d in st.session_state.costing)
                    for k in st.session_state.costing[0]
                }
                if "costing" in st.session_state
                and len(st.session_state.costing) > 0
                else {}
            ),
        }
    )
    st.sidebar.download_button(
        "Download Conversation",
        conversation_data,
        file_name="conversation_data.json",
        mime="application/json",
    )
    st.session_state.messages.append(("/download", "Conversation data downloaded", ""))


def query_llm_wrapper(inputs):
    """Default handler: answer a free-text question with a conversational
    retrieval chain; record the answer, its references, and token costs.

    Returns:
        (answer, references) — the chain's answer text and the formatted
        chunk-id references of the retrieved documents.
    """
    retriever = st.session_state.retriever
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
    )
    relevant_docs = retriever.get_relevant_documents(inputs)
    with get_openai_callback() as cb:
        result = qa_chain(
            {
                "question": inputs,
                "chat_history": session_state_2_llm_chat_history(
                    st.session_state.messages
                ),
            }
        )
        stats = cb
    result = result["answer"]
    references = get_references(relevant_docs)
    st.session_state.messages.append((inputs, result, references))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return result, references


def boot(command_center):
    """Render the chat UI: replay stored history, then dispatch any new input
    through the command center and display its response."""
    st.title("Agent Xi - An ArXiv Chatbot")
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(ai_message_format(message[1], message[2]))
    if query := st.chat_input():
        st.chat_message("human").write(query)
        response = command_center.execute_command(query)
        if response is None:
            pass  # the command rendered its own output (e.g. /download)
        elif isinstance(response, tuple):
            result, references = response
            st.chat_message("ai").write(ai_message_format(result, references))
        else:
            st.chat_message("ai").write(response)


if __name__ == "__main__":
    all_commands = [
        ("/upload", list, process_documents_wrapper, "Upload and process documents"),
        ("/index", None, index_documents_wrapper, "View index of processed documents"),
        ("/cost", None, calculate_cost_wrapper, "Calculate cost of conversation"),
        (
            "/download",
            None,
            download_conversation_wrapper,
            "Download conversation data",
        ),
    ]
    st.sidebar.title("Commands Menu")
    st.sidebar.write(
        pd.DataFrame(
            {
                "Command": [command[0] for command in all_commands],
                "Description": [command[3] for command in all_commands],
            }
        )
    )
    command_center = CommandCenter(
        default_input_type=str,
        default_function=query_llm_wrapper,
        all_commands=[command[:3] for command in all_commands],
    )
    boot(command_center)