import base64
import json
import os

import pandas as pd
import streamlit as st
from langchain.callbacks import get_openai_callback
from langchain_openai import ChatOpenAI

from autoqa_chain import auto_qa_chain
from chain_of_density import chain_of_density_chain
from chat_chains import (
    parse_model_response,
    qa_chain,
    format_docs,
    parse_context_and_question,
    ai_response_format,
)
from command_center import CommandCenter
from custom_exceptions import InvalidArgumentError, InvalidCommandError
from embed_documents import create_retriever
from insights_bullet_chain import insights_bullet_chain
from insights_mind_map_chain import insights_mind_map_chain
from openai_configuration import openai_parser
from process_documents import process_documents, num_tokens
from summary_chain import summary_chain
from synopsis_chain import synopsis_chain
from tldr_chain import tldr_chain

st.set_page_config(layout="wide")

welcome_message = """
Hi, I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible.
Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help you navigate, discuss, and analyze content.

Here's a quick guide to getting started with me:

| Command | Description |
|---------|-------------|
| `/configure --key <api key> --model <model>` | Configure the OpenAI API key and model for our conversation. |
| `/add-papers <list of urls>` | Upload and process documents for our conversation. |
| `/library` | View an index of processed documents to easily navigate your research. |
| `/view-snip <snippet id>` | View the content of a specific snippet. |
| `/session-expense` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
| `/export` | Download conversation data for your records or further analysis. |
| `/auto-insight <list of snippet ids>` | Automatically generate questions and answers for the paper. |
| `/condense-summary <list of snippet ids>` | Generate increasingly concise, entity-dense summaries of the paper. |
| `/insight-bullets <list of snippet ids>` | Extract and summarize key insights, methods, results, and conclusions. |
| `/insight-mind-map <list of snippet ids>` | Create a structured outline of the key insights in Markdown format. |
| `/paper-synopsis <list of snippet ids>` | Generate a synopsis of the paper. |
| `/deep-dive [<list of snippet ids>] <query>` | Query me with a specific context. |
| `/summarise-section [<list of snippet ids>] <section name>` | Summarize a specific section of the paper. |
| `/tldr [<list of snippet ids>] <query>` | Generate a TL;DR summary of the paper. |

<br>

Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!

Use `/help-me` at any point to view this guide again.
"""
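
# Every command handler below follows the same contract: it returns a
# `(response, format_key)` tuple and appends `(command, response, format_key)`
# to `st.session_state.messages`. The format key selects how the response is
# rendered (`formatting_functions`) and how it is serialized on export
# (`jsonify_functions`); both mappings live in the `__main__` block.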


def process_documents_wrapper(inputs):
    """Process the given document URLs into snippets and set up the retriever and index."""
    if inputs == []:
        raise InvalidArgumentError("Please provide document urls")
    snippets, documents = process_documents(inputs)
    st.session_state.retriever = create_retriever(snippets)
    st.session_state.source_doc_urls = inputs
    st.session_state.index = [
        [
            snip.metadata["chunk_id"],
            snip.metadata["header"],
            num_tokens(snip.page_content),
        ]
        for snip in snippets
    ]
    st.session_state.documents = documents
    return index_documents_wrapper(None, f"/add-papers {inputs}")


def index_documents_wrapper(inputs=None, arg="/library"):
    """Return the snippet index as a dataframe of (id, reference, tokens)."""
    response = pd.DataFrame(
        st.session_state.index, columns=["id", "reference", "tokens"]
    )
    st.session_state.messages.append((arg, response, "dataframe"))
    return (response, "dataframe")


def view_document_wrapper(inputs):
    """Show the raw content of a single snippet by its id."""
    response = st.session_state.documents[inputs].page_content
    st.session_state.messages.append((f"/view-snip {inputs}", response, "identity"))
    return (response, "identity")


def calculate_cost_wrapper(inputs=None):
    """Tabulate per-call token usage and cost, with a running total row."""
    try:
        stats_df = pd.DataFrame(st.session_state.costing)
        stats_df.loc["total"] = stats_df.sum()
        response = stats_df
    except ValueError:
        response = "No cost incurred yet"
    st.session_state.messages.append(("/session-expense", response, "dataframe"))
    return (response, "dataframe")


def download_conversation_wrapper(inputs=None):
    """Serialize the session (sources, index, messages, costs) into a JSON download link."""
    conversation_data = json.dumps(
        {
            "document_urls": (
                st.session_state.source_doc_urls
                if "source_doc_urls" in st.session_state
                else []
            ),
            "document_snippets": (
                st.session_state.index if "index" in st.session_state else []
            ),
            "conversation": [
                {"human": message[0], "ai": jsonify_functions[message[2]](message[1])}
                for message in st.session_state.messages
            ],
            "costing": (
                st.session_state.costing if "costing" in st.session_state else []
            ),
            "total_cost": (
                {
                    k: sum(d[k] for d in st.session_state.costing)
                    for k in st.session_state.costing[0]
                }
                if "costing" in st.session_state and len(st.session_state.costing) > 0
                else {}
            ),
        }
    )
    conversation_data = base64.b64encode(conversation_data.encode()).decode()
    st.session_state.messages.append(
        ("/export", "Conversation data downloaded", "identity")
    )
    return (
        f'<a href="data:application/json;base64,{conversation_data}" download="conversation_data.json">Download Conversation</a>',
        "identity",
    )
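
# The exported link embeds the JSON payload base64-encoded; it can be recovered
# with, e.g.:
#   json.loads(base64.b64decode(encoded_payload).decode())
# where `encoded_payload` (an illustrative name) is the base64 string from the
# data URL built above.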


def query_llm(inputs, relevant_docs):
    """Answer a question against the given snippets and record usage stats."""
    with get_openai_callback() as cb:
        response = (
            qa_chain(ChatOpenAI(model=st.session_state.model, temperature=0))
            .invoke({"context": format_docs(relevant_docs), "question": inputs})
            .content
        )
        stats = cb
    response = parse_model_response(response)
    answer = response["answer"]
    citations = response["citations"]
    # Append a catch-all citation pointing at every retrieved snippet.
    citations.append(
        {
            "source_id": " ".join(
                f"[{chunk_id}]"
                for chunk_id in sorted(
                    str(doc.metadata["chunk_id"]) for doc in relevant_docs
                )
            ),
            "quote": "other sources",
        }
    )

    st.session_state.messages.append(
        (inputs, {"answer": answer, "citations": citations}, "response_with_citations")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return ({"answer": answer, "citations": citations}, "response_with_citations")
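
# `parse_model_response` is expected to yield a dict shaped like
#   {"answer": "<text>", "citations": [{"source_id": "...", "quote": "..."}, ...]}
# (inferred from its use above; see chat_chains for the authoritative format).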


def rag_llm_wrapper(inputs):
    """Retrieve snippets relevant to the query, then answer with citations."""
    retriever = st.session_state.retriever
    relevant_docs = retriever.get_relevant_documents(inputs)
    return query_llm(inputs, relevant_docs)


def query_llm_wrapper(inputs):
    """Answer a question using only the explicitly referenced snippets."""
    context, question = parse_context_and_question(inputs)
    relevant_docs = [st.session_state.documents[c] for c in context]
    return query_llm(question, relevant_docs)


def summarise_wrapper(inputs):
    """Summarize a named section of the paper from the referenced snippets."""
    context, query = parse_context_and_question(inputs)
    document = [st.session_state.documents[c] for c in context]
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = summary_chain(llm).invoke({"section_name": query, "paper": document})
        stats = cb
    st.session_state.messages.append(
        (f"/summarise-section {query}", summary, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def chain_of_density_wrapper(inputs):
    """Generate increasingly concise, entity-dense summaries of the selected snippets."""
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = chain_of_density_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/condense-summary", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def synopsis_wrapper(inputs):
    """Generate a synopsis of the selected snippets."""
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = synopsis_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/paper-synopsis", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def tldr_wrapper(inputs):
    """Generate a TL;DR summary of the referenced snippets."""
    context, query = parse_context_and_question(inputs)
    document = "\n\n".join(
        [st.session_state.documents[c].page_content for c in context]
    )
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = tldr_chain(llm).invoke({"title": query, "paper": document})
        stats = cb
    st.session_state.messages.append(("/tldr", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def insights_bullet_wrapper(inputs):
    """Extract key insights, methods, results, and conclusions as bullets."""
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        insights = insights_bullet_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/insight-bullets", insights, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (insights, "identity")


def insights_mind_map_wrapper(inputs):
    """Build a structured Markdown outline of the key insights."""
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        insights = insights_mind_map_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/insight-mind-map", insights, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (insights, "identity")


def auto_qa_chain_wrapper(inputs):
    """Auto-generate questions for the selected snippets and answer each one."""
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    retriever = st.session_state.retriever
    formatted_response = ""
    # Keep the QA loop inside the callback so per-question calls are costed too.
    with get_openai_callback() as cb:
        auto_qa_response = auto_qa_chain(llm).invoke({"paper": document})
        for section in auto_qa_response:
            section_name = section["section_name"]
            formatted_response += f"# {section_name}\n"
            for question in section["questions"]:
                response = (
                    qa_chain(llm)
                    .invoke(
                        {
                            "context": format_docs(
                                retriever.get_relevant_documents(question)
                            ),
                            "question": question,
                        }
                    )
                    .content
                )
                answer = parse_model_response(response)["answer"]
                formatted_response += f"## {question}\n"
                formatted_response += f"* {answer}\n"
        stats = cb
    formatted_response = "```\n" + formatted_response + "\n```"

    st.session_state.messages.append(
        (f"/auto-insight {inputs}", formatted_response, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (formatted_response, "identity")
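
# `auto_qa_chain` is assumed to return a list of sections shaped like
#   [{"section_name": "<name>", "questions": ["<q1>", ...]}, ...]
# (inferred from the loop above; see autoqa_chain for the authoritative schema).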


def boot(command_center, formatting_functions):
    """Render the chat UI, replay the message history, and dispatch commands."""
    st.write("# Agent Zeta")
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    st.chat_message("ai").write(welcome_message, unsafe_allow_html=True)
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(
            formatting_functions[message[2]](message[1]), unsafe_allow_html=True
        )
    if query := st.chat_input():
        try:
            st.chat_message("human").write(query)
            response, format_fn_name = command_center.execute_command(query)
            st.chat_message("ai").write(
                formatting_functions[format_fn_name](response), unsafe_allow_html=True
            )
        except (InvalidArgumentError, InvalidCommandError) as e:
            st.error(e)


def configure_openai_wrapper(inputs):
    """Parse `--key`/`--model` arguments, set the API key, and store the model choice."""
    args = openai_parser.parse_args(inputs.split())
    os.environ["OPENAI_API_KEY"] = args.key
    st.session_state.model = args.model
    st.session_state.messages.append(("/configure", "Configurations Saved", "identity"))
    return (str(args), "identity")


if __name__ == "__main__":
    all_commands = [
        ("/configure", str, configure_openai_wrapper),
        ("/add-papers", list, process_documents_wrapper),
        ("/library", None, index_documents_wrapper),
        ("/view-snip", str, view_document_wrapper),
        ("/session-expense", None, calculate_cost_wrapper),
        ("/export", None, download_conversation_wrapper),
        ("/help-me", None, lambda x: (welcome_message, "identity")),
        ("/auto-insight", list, auto_qa_chain_wrapper),
        ("/deep-dive", str, query_llm_wrapper),
        ("/condense-summary", list, chain_of_density_wrapper),
        ("/insight-bullets", list, insights_bullet_wrapper),
        ("/insight-mind-map", list, insights_mind_map_wrapper),
        ("/paper-synopsis", list, synopsis_wrapper),
        ("/summarise-section", str, summarise_wrapper),
        ("/tldr", str, tldr_wrapper),
    ]
    # Free-form input (no leading slash) falls through to RAG question answering.
    command_center = CommandCenter(
        default_input_type=str,
        default_function=rag_llm_wrapper,
        all_commands=all_commands,
    )
    # How each response type is rendered in the chat window.
    formatting_functions = {
        "identity": lambda x: x,
        "dataframe": lambda x: x,
        "response_with_citations": lambda x: ai_response_format(
            x["answer"], x["citations"]
        ),
    }
    # How each response type is serialized for the /export download.
    jsonify_functions = {
        "identity": lambda x: x,
        "dataframe": lambda x: (
            x.to_dict(orient="records")
            if isinstance(x, (pd.DataFrame, pd.Series))
            else x
        ),
        "response_with_citations": lambda x: x,
    }
    boot(command_center, formatting_functions)
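
# To run the app locally (assuming Streamlit is installed; the filename is
# illustrative, substitute this module's actual name):
#   streamlit run app.py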