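"""Agent Zeta: a Streamlit chat interface for navigating, discussing, and
analysing machine learning research papers.

User input is routed through a CommandCenter: slash commands (documented in
``welcome_message`` below) map to the wrapper functions defined here, and any
other input falls through to retrieval-augmented question answering over the
uploaded papers. A typical session might look like:

    /configure --key <api key> --model <model>
    /add-papers <list of urls>
    /library
    /deep-dive [<list of snippet ids>] <query>
"""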
import streamlit as st
import os
import pandas as pd
from command_center import CommandCenter
from process_documents import process_documents, num_tokens
from embed_documents import create_retriever
import json
from langchain.callbacks import get_openai_callback
from langchain_openai import ChatOpenAI
import base64
from chat_chains import (
    parse_model_response,
    qa_chain,
    format_docs,
    parse_context_and_question,
    ai_response_format,
)
from autoqa_chain import auto_qa_chain
from chain_of_density import chain_of_density_chain
from insights_bullet_chain import insights_bullet_chain
from insights_mind_map_chain import insights_mind_map_chain
from synopsis_chain import synopsis_chain
from custom_exceptions import InvalidArgumentError, InvalidCommandError
from openai_configuration import openai_parser
from summary_chain import summary_chain
from tldr_chain import tldr_chain

st.set_page_config(layout="wide")


welcome_message = """
Hi, I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible.
Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help you navigate, discuss, and analyze content.

Here's a quick guide to getting started with me:

| Command | Description |
|---------|-------------|
| `/configure --key <api key> --model <model>` | Configure the OpenAI API key and model for our conversation. |
| `/add-papers <list of urls>` | Upload and process documents for our conversation. |
| `/library` | View an index of processed documents to easily navigate your research. |
| `/view-snip <snippet id>` | View the content of a specific snippet. |
| `/session-expense` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
| `/export` | Download conversation data for your records or further analysis. |
| `/auto-insight <list of snippet ids>` | Automatically generate questions and answers for the paper. |
| `/condense-summary <list of snippet ids>` | Generate increasingly concise, entity-dense summaries of the paper. |
| `/insight-bullets <list of snippet ids>` | Extract and summarize key insights, methods, results, and conclusions. |
| `/insight-mind-map <list of snippet ids>` | Create a structured outline of the key insights in Markdown format. |
| `/paper-synopsis <list of snippet ids>` | Generate a synopsis of the paper. |
| `/deep-dive [<list of snippet ids>] <query>` | Ask me a question against a specific set of snippets as context. |
| `/summarise-section [<list of snippet ids>] <section name>` | Summarize a specific section of the paper. |
| `/tldr [<list of snippet ids>] <query>` | Generate a TL;DR summary of the paper. |


<br>

Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!

Use `/help-me` at any time to view this guide again.
"""


def process_documents_wrapper(inputs):
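    """Fetch and chunk the papers at the given URLs, build a retriever over
    the snippets, and store a (chunk id, header, token count) index in
    session state."""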
    if inputs == []:
        raise InvalidArgumentError("Please provide document urls")
    snippets, documents = process_documents(inputs)
    st.session_state.retriever = create_retriever(snippets)
    st.session_state.source_doc_urls = inputs
    st.session_state.index = [
        [
            snip.metadata["chunk_id"],
            snip.metadata["header"],
            num_tokens(snip.page_content),
        ]
        for snip in snippets
    ]
    response = f"Uploaded and processed documents {inputs}"
    st.session_state.documents = documents
    return index_documents_wrapper(None, f"/add-papers {inputs}")


def index_documents_wrapper(inputs=None, arg="/library"):
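    """Show the snippet index as a DataFrame with id / reference / token
    columns; ``arg`` is the command string recorded in the chat history."""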
    response = pd.DataFrame(
        st.session_state.index, columns=["id", "reference", "tokens"]
    )
    st.session_state.messages.append((arg, response, "dataframe"))
    return (response, "dataframe")


def view_document_wrapper(inputs):
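    """Return the raw page content of the snippet with id ``inputs``."""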
    response = st.session_state.documents[inputs].page_content
    st.session_state.messages.append((f"/view-snip {inputs}", response, "identity"))
    return (response, "identity")


def calculate_cost_wrapper(inputs=None):
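    """Tabulate prompt/completion tokens and cost per call, with a totals
    row appended."""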
    try:
        stats_df = pd.DataFrame(st.session_state.costing)
        stats_df.loc["total"] = stats_df.sum()
        response = stats_df
    except ValueError:
        response = "No cost incurred yet"
    st.session_state.messages.append(("/session-expense", response, "dataframe"))
    return (response, "dataframe")


def download_conversation_wrapper(inputs=None):
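    """Serialise the session (document URLs, snippet index, messages, costs)
    to JSON and return it as a base64-encoded download link."""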
    conversation_data = json.dumps(
        {
            "document_urls": (
                st.session_state.source_doc_urls
                if "source_doc_urls" in st.session_state
                else []
            ),
            "document_snippets": (
                st.session_state.index if "index" in st.session_state else []
            ),
            "conversation": [
                {"human": message[0], "ai": jsonify_functions[message[2]](message[1])}
                for message in st.session_state.messages
            ],
            "costing": (
                st.session_state.costing if "costing" in st.session_state else []
            ),
            "total_cost": (
                {
                    k: sum(d[k] for d in st.session_state.costing)
                    for k in st.session_state.costing[0]
                }
                if "costing" in st.session_state and len(st.session_state.costing) > 0
                else {}
            ),
        }
    )
    conversation_data = base64.b64encode(conversation_data.encode()).decode()
    st.session_state.messages.append(
        ("/export", "Conversation data downloaded", "identity")
    )
    return (
        f'<a href="data:text/csv;base64,{conversation_data}" download="conversation_data.json">Download Conversation</a>',
        "identity",
    )


def query_llm(inputs, relevant_docs):
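    """Answer ``inputs`` with the QA chain over ``relevant_docs``, parse the
    answer and citations out of the model response, and record the exchange
    and token cost in session state."""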
    with get_openai_callback() as cb:
        response = (
            qa_chain(ChatOpenAI(model=st.session_state.model, temperature=0))
            .invoke({"context": format_docs(relevant_docs), "question": inputs})
            .content
        )
        stats = cb
    response = parse_model_response(response)
    answer = response["answer"]
    citations = response["citations"]
    citations.append(
        {
            "source_id": " ".join(
                [
                    f"[{ref}]"
                    for ref in sorted(
                        [str(ref.metadata["chunk_id"]) for ref in relevant_docs],
                    )
                ]
            ),
            "quote": "other sources",
        }
    )

    st.session_state.messages.append(
        (inputs, {"answer": answer, "citations": citations}, "reponse_with_citations")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return ({"answer": answer, "citations": citations}, "reponse_with_citations")


def rag_llm_wrapper(inputs):
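    """Default handler for free-text queries: retrieve relevant snippets and
    answer with the QA chain."""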
    retriever = st.session_state.retriever
    relevant_docs = retriever.get_relevant_documents(inputs)
    return query_llm(inputs, relevant_docs)


def query_llm_wrapper(inputs):
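    """Handle ``/deep-dive``: answer a question against an explicitly chosen
    set of snippets."""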
    context, question = parse_context_and_question(inputs)
    relevant_docs = [st.session_state.documents[c] for c in context]
    return query_llm(question, relevant_docs)


def summarise_wrapper(inputs):
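    """Handle ``/summarise-section``: summarise the named section using the
    chosen snippets as the paper text."""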
    context, query = parse_context_and_question(inputs)
    document = [st.session_state.documents[c] for c in context]
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = summary_chain(llm).invoke({"section_name": query, "paper": document})
        stats = cb
    st.session_state.messages.append(
        (f"/summarise-section {query}", summary, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def chain_of_density_wrapper(inputs):
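    """Handle ``/condense-summary``: produce increasingly concise,
    entity-dense summaries of the chosen snippets."""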
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = chain_of_density_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/condense-summary", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def synopsis_wrapper(inputs):
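    """Handle ``/paper-synopsis``: generate a synopsis from the chosen
    snippets."""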
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = synopsis_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/paper-synopsis", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def tldr_wrapper(inputs):
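    """Handle ``/tldr``: generate a TL;DR of the chosen snippets, using the
    free-text part of the command as the title."""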
    context, query = parse_context_and_question(inputs)
    document = "\n\n".join(
        [st.session_state.documents[c].page_content for c in context]
    )
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        summary = tldr_chain(llm).invoke({"title": query, "paper": document})
        stats = cb
    st.session_state.messages.append(("/tldr", summary, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (summary, "identity")


def insights_bullet_wrapper(inputs):
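    """Handle ``/insight-bullets``: extract key insights, methods, results,
    and conclusions as bullet points."""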
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        insights = insights_bullet_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/insight-bullets", insights, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (insights, "identity")


def insights_mind_map_wrapper(inputs):
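    """Handle ``/insight-mind-map``: produce a structured Markdown outline
    of the key insights."""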
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    with get_openai_callback() as cb:
        insights = insights_mind_map_chain(llm).invoke({"paper": document})
        stats = cb
    st.session_state.messages.append(("/insight-mind-map", insights, "identity"))
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (insights, "identity")


def auto_qa_chain_wrapper(inputs):
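    """Handle ``/auto-insight``: generate questions for each paper section,
    answer each one via retrieval and the QA chain, and return the formatted
    transcript."""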
    if inputs == []:
        raise InvalidArgumentError("Please provide snippet ids")
    document = "\n\n".join([st.session_state.documents[c].page_content for c in inputs])
    llm = ChatOpenAI(model=st.session_state.model, temperature=0)
    retriever = st.session_state.retriever
    formatted_response = ""
    with get_openai_callback() as cb:
        auto_qa_response = auto_qa_chain(llm).invoke({"paper": document})
        stats = cb
        for section in auto_qa_response:
            section_name = section["section_name"]
            formatted_response += f"# {section_name}\n"
            for question in section["questions"]:
                response = (
                    qa_chain(ChatOpenAI(model=st.session_state.model, temperature=0))
                    .invoke(
                        {
                            "context": format_docs(
                                retriever.get_relevant_documents(question)
                            ),
                            "question": question,
                        }
                    )
                    .content
                )
                answer = parse_model_response(response)["answer"]
                formatted_response += f"## {question}\n"
                formatted_response += f"* {answer}\n"
    formatted_response = "```\n" + formatted_response + "\n```"

    st.session_state.messages.append(
        (f"/auto-insight {inputs}", formatted_response, "identity")
    )
    st.session_state.costing.append(
        {
            "prompt tokens": stats.prompt_tokens,
            "completion tokens": stats.completion_tokens,
            "cost": stats.total_cost,
        }
    )
    return (
        formatted_response,
        "identity",
    )


def boot(command_center, formatting_functions):
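    """Render the chat history, then dispatch each new chat input through the
    command center and format its response for display."""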
    st.write("# Agent Zeta")
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    st.chat_message("ai").write(welcome_message, unsafe_allow_html=True)
    for message in st.session_state.messages:
        st.chat_message("human").write(message[0])
        st.chat_message("ai").write(
            formatting_functions[message[2]](message[1]), unsafe_allow_html=True
        )
    if query := st.chat_input():
        try:
            st.chat_message("human").write(query)
            response, format_fn_name = command_center.execute_command(query)
            st.chat_message("ai").write(
                formatting_functions[format_fn_name](response), unsafe_allow_html=True
            )
        except (InvalidArgumentError, InvalidCommandError) as e:
            st.error(e)


def configure_openai_wrapper(inputs):
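    """Handle ``/configure``: set the OpenAI API key in the environment and
    the chosen model in session state."""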
    args = openai_parser.parse_args(inputs.split())
    os.environ["OPENAI_API_KEY"] = args.key
    st.session_state.model = args.model
    st.session_state.messages.append(("/configure", "Configurations Saved", "identity"))
    return (str(args), "identity")


if __name__ == "__main__":
    all_commands = [
        ("/configure", str, configure_openai_wrapper),
        ("/add-papers", list, process_documents_wrapper),
        ("/library", None, index_documents_wrapper),
        ("/view-snip", str, view_document_wrapper),
        ("/session-expense", None, calculate_cost_wrapper),
        ("/export", None, download_conversation_wrapper),
        ("/help-me", None, lambda x: (welcome_message, "identity")),
        ("/auto-insight", list, auto_qa_chain_wrapper),
        ("/deep-dive", str, query_llm_wrapper),
        ("/condense-summary", list, chain_of_density_wrapper),
        ("/insight-bullets", list, insights_bullet_wrapper),
        ("/insight-mind-map", list, insights_mind_map_wrapper),
        ("/paper-synopsis", list, synopsis_wrapper),
        ("/summarise-section", str, summarise_wrapper),
        ("/tldr", str, tldr_wrapper),
    ]
    command_center = CommandCenter(
        default_input_type=str,
        default_function=rag_llm_wrapper,
        all_commands=all_commands,
    )
    formatting_functions = {
        "identity": lambda x: x,
        "dataframe": lambda x: x,
        "reponse_with_citations": lambda x: ai_response_format(
            x["answer"], x["citations"]
        ),
    }
    jsonify_functions = {
        "identity": lambda x: x,
        "dataframe": lambda x: (
            x.to_dict(orient="records")
            if isinstance(x, pd.DataFrame) or isinstance(x, pd.Series)
            else x
        ),
        "reponse_with_citations": lambda x: x,
    }
    boot(command_center, formatting_functions)