"""Streamlit chat front-end for Agent Zeta, a research-paper assistant.

Every slash command is implemented by a ``*_wrapper`` function that returns a
``(response, format_name)`` tuple. ``format_name`` selects an entry in
``formating_functions`` (for rendering) and ``jsonify_functions`` (for export).
Wrappers also append ``(human_text, response, format_name)`` to
``st.session_state.messages`` so the conversation can be replayed on rerun.
"""
import base64
import json
import os

import pandas as pd
import streamlit as st
from langchain.callbacks import get_openai_callback
from langchain_openai import ChatOpenAI

from autoqa_chain import auto_qa_chain
from chain_of_density import chain_of_density_chain
from chat_chains import (
    ai_response_format,
    format_docs,
    parse_context_and_question,
    parse_model_response,
    qa_chain,
)
from command_center import CommandCenter
from custom_exceptions import InvalidArgumentError, InvalidCommandError
from embed_documents import create_retriever
from insights_bullet_chain import insights_bullet_chain
from insights_mind_map_chain import insights_mind_map_chain
from openai_configuration import openai_parser
from process_documents import num_tokens, process_documents
from summary_chain import summary_chain
from synopsis_chain import synopsis_chain
from tldr_chain import tldr_chain

st.set_page_config(layout="wide")

welcome_message = """
Hi I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible. Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help navigate, discuss, and analyze content with you.

Here's a quick guide to getting started with me:

| Command | Description |
|---------|-------------|
| `/configure --key --model ` | Configure the OpenAI API key and model for our conversation. |
| `/add-papers ` | Upload and process documents for our conversation. |
| `/library` | View an index of processed documents to easily navigate your research. |
| `/view-snip ` | View the content of a specific snippet. |
| `/session-expense` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
| `/export` | Download conversation data for your records or further analysis. |
| `/auto-insight ` | Automatically generate questions and answers for the paper. |
| `/condense-summary ` | Generate increasingly concise, entity-dense summaries of the paper. |
| `/insight-bullets ` | Extract and summarize key insights, methods, results, and conclusions. |
| `/insight-mind-map ` | Create a structured outline of the key insights in Markdown format. |
| `/paper-synopsis ` | Generate a synopsis of the paper. |
| `/deep-dive [] ` | Query me with a specific context. |
| `/summarise-section [] ` | Summarize a specific section of the paper. |
| `/tldr ` | Generate a tldr summary of the paper. |

Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!

Use `/help-me` at any point of time to view this guide again.
"""

# Hints shown when a command is used before its prerequisites exist.
_NO_DOCUMENTS_HINT = "Please upload documents first using /add-papers"
_NO_MODEL_HINT = "Please configure the model first using /configure"


def _require_state(key, hint):
    """Return ``st.session_state[key]``, or raise InvalidArgumentError(hint).

    ``boot`` only reports InvalidArgumentError / InvalidCommandError to the
    user, so a missing prerequisite is surfaced as one of those instead of an
    uncaught AttributeError.
    """
    if key not in st.session_state:
        raise InvalidArgumentError(hint)
    return st.session_state[key]


def _make_llm():
    """Build the deterministic (temperature 0) chat model configured via /configure."""
    return ChatOpenAI(model=_require_state("model", _NO_MODEL_HINT), temperature=0)


def _lookup_snippets(snippet_ids):
    """Return the stored documents for ``snippet_ids``.

    Raises:
        InvalidArgumentError: if no documents were uploaded or an id cannot be
            used to index the stored documents.
    """
    documents = _require_state("documents", _NO_DOCUMENTS_HINT)
    try:
        return [documents[snippet_id] for snippet_id in snippet_ids]
    except (KeyError, IndexError, TypeError) as exc:
        raise InvalidArgumentError(f"Invalid snippet id(s): {snippet_ids}") from exc


def _record_cost(stats):
    """Append the token usage and cost captured by an OpenAI callback."""
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )


def _run_paper_chain(chain_factory, inputs, command):
    """Run a single-input ``{"paper": ...}`` chain over the selected snippets.

    Args:
        chain_factory: callable taking the LLM and returning an object with
            ``invoke``.
        inputs: snippet ids whose ``page_content`` is joined into the paper text.
        command: text recorded as the human side of the conversation entry.

    Returns:
        ``(chain_output, "identity")``.

    Raises:
        InvalidArgumentError: if ``inputs`` is empty or cannot be resolved.
    """
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    paper = "\n\n".join(doc.page_content for doc in _lookup_snippets(inputs))
    llm = _make_llm()
    with get_openai_callback() as cb:
        result = chain_factory(llm).invoke({"paper": paper})
    st.session_state.messages.append((command, result, "identity"))
    _record_cost(cb)
    return (result, "identity")


def _jsonify_dataframe(value):
    """Convert pandas objects to JSON-friendly structures; pass others through."""
    if isinstance(value, pd.DataFrame):
        return value.to_dict(orient="records")
    if isinstance(value, pd.Series):
        # Series.to_dict has no ``orient`` argument.
        return value.to_dict()
    return value


# Maps a response's format name to the function that prepares it for st.write.
formating_functions = {
    "identity": lambda x: x,
    "dataframe": lambda x: x,
    "reponse_with_citations": lambda x: ai_response_format(
        x["answer"], x["citations"]
    ),
}

# Maps a response's format name to the function that prepares it for json.dumps.
jsonify_functions = {
    "identity": lambda x: x,
    "dataframe": _jsonify_dataframe,
    "reponse_with_citations": lambda x: x,
}


def process_documents_wrapper(inputs):
    """Handle ``/add-papers``: process the urls and store retriever, index, documents.

    Raises:
        InvalidArgumentError: if no urls were given.
    """
    if not inputs:
        raise InvalidArgumentError("Please provide document urls")
    snippets, documents = process_documents(inputs)
    st.session_state.retriever = create_retriever(snippets)
    st.session_state.source_doc_urls = inputs
    # One row per snippet: [chunk id, header, token count].
    st.session_state.index = [
        [
            snip.metadata["chunk_id"],
            snip.metadata["header"],
            num_tokens(snip.page_content),
        ]
        for snip in snippets
    ]
    response = f"Uploaded and processed documents {inputs}"
    st.session_state.messages.append((f"/add-papers {inputs}", response, "identity"))
    st.session_state.documents = documents
    return (response, "identity")


def index_documents_wrapper(inputs=None):
    """Handle ``/library``: return the snippet index as a DataFrame.

    Returns an empty table (with headers) when nothing has been uploaded yet.
    """
    rows = st.session_state.index if "index" in st.session_state else []
    response = pd.DataFrame(rows, columns=["id", "reference", "tokens"])
    st.session_state.messages.append(("/library", response, "dataframe"))
    return (response, "dataframe")


def view_document_wrapper(inputs):
    """Handle ``/view-snip``: return the page content of one stored snippet.

    Raises:
        InvalidArgumentError: if the snippet cannot be found.
    """
    response = _lookup_snippets([inputs])[0].page_content
    st.session_state.messages.append((f"/view-snip {inputs}", response, "identity"))
    return (response, "identity")


def calculate_cost_wrapper(inputs=None):
    """Handle ``/session-expense``: per-call usage plus a ``total`` row.

    Returns the string ``"No cost incurred yet"`` when no call has been
    recorded.
    """
    costing = st.session_state.costing if "costing" in st.session_state else []
    if costing:
        stats_df = pd.DataFrame(costing)
        stats_df.loc["total"] = stats_df.sum()
        response = stats_df
    else:
        response = "No cost incurred yet"
    st.session_state.messages.append(("/session-expense", response, "dataframe"))
    return (response, "dataframe")


def download_conversation_wrapper(inputs=None):
    """Handle ``/export``: return an HTML link that downloads the session as JSON.

    The JSON contains the source urls, the snippet index, the conversation so
    far (converted with ``jsonify_functions``), the per-call costing and the
    summed totals. It is embedded base64-encoded in a ``data:`` URI, so the
    link works without any server-side file.
    """
    state = st.session_state
    costing = state.costing if "costing" in state else []
    total_cost = (
        {key: sum(entry[key] for entry in costing) for key in costing[0]}
        if costing
        else {}
    )
    conversation_data = json.dumps(
        {
            "document_urls": (
                state.source_doc_urls if "source_doc_urls" in state else []
            ),
            "document_snippets": state.index if "index" in state else [],
            "conversation": [
                {"human": message[0], "ai": jsonify_functions[message[2]](message[1])}
                for message in state.messages
            ],
            "costing": costing,
            "total_cost": total_cost,
        }
    )
    encoded = base64.b64encode(conversation_data.encode()).decode()
    state.messages.append(("/export", "Conversation data downloaded", "identity"))
    # The chat renders responses with unsafe_allow_html=True, so the anchor is
    # shown as a clickable download link.
    link = (
        f'<a href="data:application/json;base64,{encoded}" '
        f'download="conversation_data.json">Download Conversation</a>'
    )
    return (link, "identity")


def query_llm(inputs, relevant_docs):
    """Answer ``inputs`` from ``relevant_docs`` and record the answer with citations.

    Besides the citations parsed from the model response, a final
    ``"other sources"`` citation lists the chunk ids of every document that
    was supplied as context.

    Returns:
        ``({"answer": ..., "citations": [...]}, "reponse_with_citations")``.
    """
    llm = _make_llm()
    with get_openai_callback() as cb:
        raw_response = (
            qa_chain(llm)
            .invoke({"context": format_docs(relevant_docs), "question": inputs})
            .content
        )
    parsed = parse_model_response(raw_response)
    answer = parsed["answer"]
    citations = parsed["citations"]
    # Chunk ids are sorted as strings, matching how they are displayed.
    chunk_ids = sorted(str(doc.metadata["chunk_id"]) for doc in relevant_docs)
    citations.append(
        {
            "source_id": " ".join(f"[{chunk_id}]" for chunk_id in chunk_ids),
            "quote": "other sources",
        }
    )
    payload = {"answer": answer, "citations": citations}
    st.session_state.messages.append((inputs, payload, "reponse_with_citations"))
    _record_cost(cb)
    return (payload, "reponse_with_citations")


def rag_llm_wrapper(inputs):
    """Default handler: retrieve documents relevant to ``inputs`` and answer from them.

    Raises:
        InvalidArgumentError: if no documents have been uploaded yet.
    """
    retriever = _require_state("retriever", _NO_DOCUMENTS_HINT)
    return query_llm(inputs, retriever.get_relevant_documents(inputs))


def query_llm_wrapper(inputs):
    """Handle ``/deep-dive``: answer a question using only the named snippets."""
    context, question = parse_context_and_question(inputs)
    return query_llm(question, _lookup_snippets(context))


def summarise_wrapper(inputs):
    """Handle ``/summarise-section``: summarise a named section of the snippets."""
    context, query = parse_context_and_question(inputs)
    # NOTE(review): unlike the other chains, this one receives the document
    # objects rather than their joined page_content - confirm summary_chain
    # expects that.
    document = _lookup_snippets(context)
    llm = _make_llm()
    with get_openai_callback() as cb:
        summary = summary_chain(llm).invoke({"section_name": query, "paper": document})
    st.session_state.messages.append(
        (f"/summarise-section {query}", summary, "identity")
    )
    _record_cost(cb)
    return (summary, "identity")


def chain_of_density_wrapper(inputs):
    """Handle ``/condense-summary`` for the given snippet ids."""
    return _run_paper_chain(chain_of_density_chain, inputs, "/condense-summary")


def synopsis_wrapper(inputs):
    """Handle ``/paper-synopsis`` for the given snippet ids."""
    return _run_paper_chain(synopsis_chain, inputs, "/paper-synopsis")


def tldr_wrapper(inputs):
    """Handle ``/tldr`` for the given snippet ids."""
    return _run_paper_chain(tldr_chain, inputs, "/tldr")


def insights_bullet_wrapper(inputs):
    """Handle ``/insight-bullets`` for the given snippet ids."""
    return _run_paper_chain(insights_bullet_chain, inputs, "/insight-bullets")


def insights_mind_map_wrapper(inputs):
    """Handle ``/insight-mind-map`` for the given snippet ids."""
    return _run_paper_chain(insights_mind_map_chain, inputs, "/insight-mind-map")


def auto_qa_chain_wrapper(inputs):
    """Handle ``/auto-insight``: generate questions per section and answer each.

    The questions come from ``auto_qa_chain``; every question is then answered
    with ``qa_chain`` over documents fetched from the retriever. The result is
    a fenced block with ``# section`` / ``## question`` / ``* answer`` lines.

    Raises:
        InvalidArgumentError: if ``inputs`` is empty or prerequisites are missing.
    """
    if not inputs:
        raise InvalidArgumentError("Please provide snippet ids")
    paper = "\n\n".join(doc.page_content for doc in _lookup_snippets(inputs))
    llm = _make_llm()
    retriever = _require_state("retriever", _NO_DOCUMENTS_HINT)
    parts = []
    # Every LLM call, including the per-question answers, runs inside the
    # callback scope so the recorded cost covers the whole command.
    with get_openai_callback() as cb:
        auto_qa_response = auto_qa_chain(llm).invoke({"paper": paper})
        answer_chain = qa_chain(llm)
        for section in auto_qa_response:
            parts.append(f"# {section['section_name']}\n")
            for question in section["questions"]:
                raw_response = answer_chain.invoke(
                    {
                        "context": format_docs(
                            retriever.get_relevant_documents(question)
                        ),
                        "question": question,
                    }
                ).content
                answer = parse_model_response(raw_response)["answer"]
                parts.append(f"## {question}\n")
                parts.append(f"* {answer}\n")
    formatted_response = "```\n" + "".join(parts) + "\n```"
    st.session_state.messages.append(
        (f"/auto-insight {inputs}", formatted_response, "identity")
    )
    _record_cost(cb)
    return (formatted_response, "identity")


def boot(command_center, formating_functions):
    """Render the chat page and execute the user's next input, if any.

    Args:
        command_center: object whose ``execute_command(query)`` returns
            ``(response, format_name)``.
        formating_functions: mapping of format name to render-preparation
            function.
    """
    st.write("# Agent Zeta")
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    st.chat_message("ai").write(welcome_message, unsafe_allow_html=True)
    # Replay the stored conversation; Streamlit reruns the script on each input.
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(
            formating_functions[message[2]](message[1]), unsafe_allow_html=True
        )
    if query := st.chat_input():
        try:
            st.chat_message("human").write(query)
            response, format_fn_name = command_center.execute_command(query)
            st.chat_message("ai").write(
                formating_functions[format_fn_name](response), unsafe_allow_html=True
            )
        except (InvalidArgumentError, InvalidCommandError) as e:
            st.error(e)


def configure_openai_wrapper(inputs):
    """Handle ``/configure``: store the OpenAI API key and model name.

    The confirmation text deliberately does not echo the parsed arguments,
    because they contain the API key.

    Raises:
        InvalidArgumentError: if the arguments cannot be parsed or no key is given.
    """
    try:
        args = openai_parser.parse_args(inputs.split())
    except SystemExit as exc:
        # Presumably openai_parser is an argparse parser, which exits the
        # process on bad arguments; report it to the user instead.
        raise InvalidArgumentError(
            "Invalid arguments for /configure; expected --key and --model"
        ) from exc
    if not args.key:
        raise InvalidArgumentError("Please provide the OpenAI API key with --key")
    os.environ["OPENAI_API_KEY"] = args.key
    st.session_state.model = args.model
    response = "Configurations Saved"
    st.session_state.messages.append(("/configure", response, "identity"))
    return (response, "identity")


def main():
    """Register all commands and start the chat page."""
    # Each entry is (command, input type, handler).
    all_commands = [
        ("/configure", str, configure_openai_wrapper),
        ("/add-papers", list, process_documents_wrapper),
        ("/library", None, index_documents_wrapper),
        ("/view-snip", str, view_document_wrapper),
        ("/session-expense", None, calculate_cost_wrapper),
        ("/export", None, download_conversation_wrapper),
        ("/help-me", None, lambda x: (welcome_message, "identity")),
        ("/auto-insight", list, auto_qa_chain_wrapper),
        ("/deep-dive", str, query_llm_wrapper),
        ("/condense-summary", list, chain_of_density_wrapper),
        ("/insight-bullets", list, insights_bullet_wrapper),
        ("/insight-mind-map", list, insights_mind_map_wrapper),
        ("/paper-synopsis", list, synopsis_wrapper),
        ("/summarise-section", str, summarise_wrapper),
        ("/tldr", list, tldr_wrapper),
    ]
    command_center = CommandCenter(
        default_input_type=str,
        default_function=rag_llm_wrapper,
        all_commands=all_commands,
    )
    boot(command_center, formating_functions)


if __name__ == "__main__":
    main()