Upload 8 files
Browse files- app.py +119 -83
- chain_of_density.py +42 -0
- chat_chains.py +33 -27
- command_center.py +6 -0
- custom_exceptions.py +6 -0
- process_documents.py +20 -9
app.py
CHANGED
@@ -8,32 +8,20 @@ import json
|
|
8 |
from langchain.callbacks import get_openai_callback
|
9 |
from langchain_openai import ChatOpenAI
|
10 |
import base64
|
11 |
-
from chat_chains import
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
st.set_page_config(layout="wide")
|
16 |
os.environ["OPENAI_API_KEY"] = "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"
|
17 |
|
18 |
-
format_citations = lambda citations: "\n\n".join(
|
19 |
-
[f"{citation['quote']} ... [{citation['source_id']}]" for citation in citations]
|
20 |
-
)
|
21 |
-
|
22 |
-
|
23 |
-
def session_state_2_llm_chat_history(session_state):
|
24 |
-
chat_history = []
|
25 |
-
for ss in session_state:
|
26 |
-
if not ss[0].startswith("/"):
|
27 |
-
chat_history.append(HumanMessage(content=ss[0]))
|
28 |
-
chat_history.append(AIMessage(content=ss[1]))
|
29 |
-
return chat_history
|
30 |
-
|
31 |
-
|
32 |
-
ai_message_format = lambda message, references: (
|
33 |
-
f"{message}\n\n---\n\n{format_citations(references)}"
|
34 |
-
if references != ""
|
35 |
-
else message
|
36 |
-
)
|
37 |
|
38 |
welcome_message = """
|
39 |
Hi I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible. Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help navigate, discuss, and analyze content with you.
|
@@ -42,17 +30,20 @@ Here's a quick guide to getting started with me:
|
|
42 |
|
43 |
| Command | Description |
|
44 |
|---------|-------------|
|
45 |
-
| `/
|
46 |
-
| `/
|
47 |
-
| `/
|
48 |
-
| `/
|
49 |
-
| `/auto
|
|
|
|
|
|
|
50 |
|
51 |
<br>
|
52 |
|
53 |
Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!
|
54 |
|
55 |
-
Use `/
|
56 |
"""
|
57 |
|
58 |
|
@@ -64,28 +55,26 @@ def process_documents_wrapper(inputs):
|
|
64 |
[snip.metadata["chunk_id"], snip.metadata["header"]] for snip in snippets
|
65 |
]
|
66 |
response = f"Uploaded and processed documents {inputs}"
|
67 |
-
st.session_state.messages.append((f"/
|
68 |
st.session_state.documents = documents
|
69 |
-
return response
|
70 |
|
71 |
|
72 |
def index_documents_wrapper(inputs=None):
|
73 |
-
response = pd.DataFrame(
|
74 |
-
|
75 |
-
|
76 |
-
st.session_state.messages.append(("/index", response, ""))
|
77 |
-
return response
|
78 |
|
79 |
|
80 |
def calculate_cost_wrapper(inputs=None):
|
81 |
try:
|
82 |
stats_df = pd.DataFrame(st.session_state.costing)
|
83 |
stats_df.loc["total"] = stats_df.sum()
|
84 |
-
response = stats_df
|
85 |
except ValueError:
|
86 |
response = "No cost incurred yet"
|
87 |
-
st.session_state.messages.append(("/
|
88 |
-
return response
|
89 |
|
90 |
|
91 |
def download_conversation_wrapper(inputs=None):
|
@@ -100,7 +89,7 @@ def download_conversation_wrapper(inputs=None):
|
|
100 |
st.session_state.index if "index" in st.session_state else []
|
101 |
),
|
102 |
"conversation": [
|
103 |
-
{"human": message[0], "ai": message[
|
104 |
for message in st.session_state.messages
|
105 |
],
|
106 |
"costing": (
|
@@ -117,25 +106,22 @@ def download_conversation_wrapper(inputs=None):
|
|
117 |
}
|
118 |
)
|
119 |
conversation_data = base64.b64encode(conversation_data.encode()).decode()
|
120 |
-
st.session_state.messages.append(
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
|
124 |
-
def
|
125 |
-
retriever = st.session_state.retriever
|
126 |
-
qa_chain = rag_chain(
|
127 |
-
retriever, ChatOpenAI(model="gpt-4-0125-preview", temperature=0)
|
128 |
-
)
|
129 |
-
relevant_docs = retriever.get_relevant_documents(inputs)
|
130 |
with get_openai_callback() as cb:
|
131 |
-
response =
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
),
|
137 |
-
}
|
138 |
-
).content
|
139 |
stats = cb
|
140 |
response = parse_model_response(response)
|
141 |
answer = response["answer"]
|
@@ -147,7 +133,6 @@ def query_llm_wrapper(inputs):
|
|
147 |
f"[{ref}]"
|
148 |
for ref in sorted(
|
149 |
[ref.metadata["chunk_id"] for ref in relevant_docs],
|
150 |
-
key=lambda x: int(x.split("_")[1]),
|
151 |
)
|
152 |
]
|
153 |
),
|
@@ -155,7 +140,41 @@ def query_llm_wrapper(inputs):
|
|
155 |
}
|
156 |
)
|
157 |
|
158 |
-
st.session_state.messages.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
st.session_state.costing.append(
|
160 |
{
|
161 |
"prompt tokens": stats.prompt_tokens,
|
@@ -163,11 +182,13 @@ def query_llm_wrapper(inputs):
|
|
163 |
"cost": stats.total_cost,
|
164 |
}
|
165 |
)
|
166 |
-
return
|
167 |
|
168 |
|
169 |
def auto_qa_chain_wrapper(inputs):
|
170 |
-
|
|
|
|
|
171 |
llm = ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
|
172 |
auto_qa_conversation = []
|
173 |
with get_openai_callback() as cb:
|
@@ -176,15 +197,15 @@ def auto_qa_chain_wrapper(inputs):
|
|
176 |
"questions"
|
177 |
]
|
178 |
auto_qa_conversation = [
|
179 |
-
(f'/auto {qa["question"]}', qa["answer"], "")
|
180 |
for qa in auto_qa_response_parsed
|
181 |
]
|
182 |
stats = cb
|
183 |
st.session_state.messages.append(
|
184 |
-
(f"/auto {inputs}", "Auto Convervation Generated", "")
|
185 |
)
|
186 |
for qa in auto_qa_conversation:
|
187 |
-
st.session_state.messages.append((qa[0], qa[1], ""))
|
188 |
|
189 |
st.session_state.costing.append(
|
190 |
{
|
@@ -193,12 +214,16 @@ def auto_qa_chain_wrapper(inputs):
|
|
193 |
"cost": stats.total_cost,
|
194 |
}
|
195 |
)
|
196 |
-
return
|
197 |
-
|
|
|
|
|
|
|
|
|
198 |
)
|
199 |
|
200 |
|
201 |
-
def boot(command_center):
|
202 |
st.write("# Agent Zeta")
|
203 |
if "costing" not in st.session_state:
|
204 |
st.session_state.costing = []
|
@@ -208,34 +233,45 @@ def boot(command_center):
|
|
208 |
for message in st.session_state.messages:
|
209 |
st.chat_message("human").write(message[0])
|
210 |
st.chat_message("ai").write(
|
211 |
-
|
212 |
)
|
213 |
if query := st.chat_input():
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
pass
|
218 |
-
elif type(response) == tuple:
|
219 |
-
result, references = response
|
220 |
st.chat_message("ai").write(
|
221 |
-
|
222 |
)
|
223 |
-
|
224 |
-
st.
|
225 |
|
226 |
|
227 |
if __name__ == "__main__":
|
228 |
all_commands = [
|
229 |
-
("/
|
230 |
-
("/
|
231 |
-
("/
|
232 |
-
("/
|
233 |
-
("/
|
234 |
-
("/auto",
|
|
|
|
|
235 |
]
|
236 |
command_center = CommandCenter(
|
237 |
default_input_type=str,
|
238 |
-
default_function=
|
239 |
all_commands=all_commands,
|
240 |
)
|
241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from langchain.callbacks import get_openai_callback
|
9 |
from langchain_openai import ChatOpenAI
|
10 |
import base64
|
11 |
+
from chat_chains import (
|
12 |
+
parse_model_response,
|
13 |
+
qa_chain,
|
14 |
+
format_docs,
|
15 |
+
parse_context_and_question,
|
16 |
+
ai_response_format,
|
17 |
+
)
|
18 |
+
from autoqa_chains import auto_qa_chain, auto_qa_output_parser
|
19 |
+
from chain_of_density import chain_of_density_chain
|
20 |
+
from custom_exceptions import InvalidArgumentError, InvalidCommandError
|
21 |
|
22 |
st.set_page_config(layout="wide")
|
23 |
os.environ["OPENAI_API_KEY"] = "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"
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
welcome_message = """
|
27 |
Hi I'm Agent Zeta, your AI assistant, dedicated to making your journey through machine learning research papers as insightful and interactive as possible. Whether you're diving into the latest studies or brushing up on foundational papers, I'm here to help navigate, discuss, and analyze content with you.
|
|
|
30 |
|
31 |
| Command | Description |
|
32 |
|---------|-------------|
|
33 |
+
| `/add-papers <list of urls>` | Upload and process documents for our conversation. |
|
34 |
+
| `/library` | View an index of processed documents to easily navigate your research. |
|
35 |
+
| `/session-expense` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
|
36 |
+
| `/export` | Download conversation data for your records or further analysis. |
|
37 |
+
| `/auto-insight <document id>` | Automatically generate questions and answers for a document. |
|
38 |
+
| `/deep-dive [<list of document ids>] <query>` | Query the AI with a specific document context. |
|
39 |
+
| `/condense-summary <document id>` | Generate increasingly concise, entity-dense summaries of a document. |
|
40 |
+
|
41 |
|
42 |
<br>
|
43 |
|
44 |
Feel free to use these commands to enhance your research experience. Let's embark on this exciting journey of discovery together!
|
45 |
|
46 |
+
Use `/help-me` at any point of time to view this guide again.
|
47 |
"""
|
48 |
|
49 |
|
|
|
55 |
[snip.metadata["chunk_id"], snip.metadata["header"]] for snip in snippets
|
56 |
]
|
57 |
response = f"Uploaded and processed documents {inputs}"
|
58 |
+
st.session_state.messages.append((f"/add-papers {inputs}", response, "identity"))
|
59 |
st.session_state.documents = documents
|
60 |
+
return (response, "identity")
|
61 |
|
62 |
|
63 |
def index_documents_wrapper(inputs=None):
|
64 |
+
response = pd.DataFrame(st.session_state.index, columns=["id", "reference"])
|
65 |
+
st.session_state.messages.append(("/library", response, "dataframe"))
|
66 |
+
return (response, "dataframe")
|
|
|
|
|
67 |
|
68 |
|
69 |
def calculate_cost_wrapper(inputs=None):
|
70 |
try:
|
71 |
stats_df = pd.DataFrame(st.session_state.costing)
|
72 |
stats_df.loc["total"] = stats_df.sum()
|
73 |
+
response = stats_df
|
74 |
except ValueError:
|
75 |
response = "No cost incurred yet"
|
76 |
+
st.session_state.messages.append(("/session-expense", response, "dataframe"))
|
77 |
+
return (response, "dataframe")
|
78 |
|
79 |
|
80 |
def download_conversation_wrapper(inputs=None):
|
|
|
89 |
st.session_state.index if "index" in st.session_state else []
|
90 |
),
|
91 |
"conversation": [
|
92 |
+
{"human": message[0], "ai": jsonify_functions[message[2]](message[1])}
|
93 |
for message in st.session_state.messages
|
94 |
],
|
95 |
"costing": (
|
|
|
106 |
}
|
107 |
)
|
108 |
conversation_data = base64.b64encode(conversation_data.encode()).decode()
|
109 |
+
st.session_state.messages.append(
|
110 |
+
("/export", "Conversation data downloaded", "identity")
|
111 |
+
)
|
112 |
+
return (
|
113 |
+
f'<a href="data:text/csv;base64,{conversation_data}" download="conversation_data.json">Download Conversation</a>',
|
114 |
+
"identity",
|
115 |
+
)
|
116 |
|
117 |
|
118 |
+
def query_llm(inputs, relevant_docs):
|
|
|
|
|
|
|
|
|
|
|
119 |
with get_openai_callback() as cb:
|
120 |
+
response = (
|
121 |
+
qa_chain(ChatOpenAI(model="gpt-4-0125-preview", temperature=0))
|
122 |
+
.invoke({"context": format_docs(relevant_docs), "question": inputs})
|
123 |
+
.content
|
124 |
+
)
|
|
|
|
|
|
|
125 |
stats = cb
|
126 |
response = parse_model_response(response)
|
127 |
answer = response["answer"]
|
|
|
133 |
f"[{ref}]"
|
134 |
for ref in sorted(
|
135 |
[ref.metadata["chunk_id"] for ref in relevant_docs],
|
|
|
136 |
)
|
137 |
]
|
138 |
),
|
|
|
140 |
}
|
141 |
)
|
142 |
|
143 |
+
st.session_state.messages.append(
|
144 |
+
(inputs, {"answer": answer, "citations": citations}, "reponse_with_citations")
|
145 |
+
)
|
146 |
+
st.session_state.costing.append(
|
147 |
+
{
|
148 |
+
"prompt tokens": stats.prompt_tokens,
|
149 |
+
"completion tokens": stats.completion_tokens,
|
150 |
+
"cost": stats.total_cost,
|
151 |
+
}
|
152 |
+
)
|
153 |
+
return ({"answer": answer, "citations": citations}, "reponse_with_citations")
|
154 |
+
|
155 |
+
|
156 |
+
def rag_llm_wrapper(inputs):
|
157 |
+
retriever = st.session_state.retriever
|
158 |
+
relevant_docs = retriever.get_relevant_documents(inputs)
|
159 |
+
return query_llm(inputs, relevant_docs)
|
160 |
+
|
161 |
+
|
162 |
+
def query_llm_wrapper(inputs):
|
163 |
+
context, question = parse_context_and_question(inputs)
|
164 |
+
relevant_docs = [st.session_state.documents[c] for c in context]
|
165 |
+
print(context, question)
|
166 |
+
return query_llm(question, relevant_docs)
|
167 |
+
|
168 |
+
|
169 |
+
def chain_of_density_wrapper(inputs):
|
170 |
+
if inputs == "":
|
171 |
+
raise InvalidArgumentError("Please provide a document id")
|
172 |
+
document = st.session_state.documents[inputs].page_content
|
173 |
+
llm = ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
|
174 |
+
with get_openai_callback() as cb:
|
175 |
+
summary = chain_of_density_chain(llm).invoke({"paper": document})
|
176 |
+
stats = cb
|
177 |
+
st.session_state.messages.append(("/condense-summary", summary, "identity"))
|
178 |
st.session_state.costing.append(
|
179 |
{
|
180 |
"prompt tokens": stats.prompt_tokens,
|
|
|
182 |
"cost": stats.total_cost,
|
183 |
}
|
184 |
)
|
185 |
+
return (summary, "identity")
|
186 |
|
187 |
|
188 |
def auto_qa_chain_wrapper(inputs):
|
189 |
+
if inputs == "":
|
190 |
+
raise InvalidArgumentError("Please provide a document id")
|
191 |
+
document = st.session_state.documents[inputs].page_content
|
192 |
llm = ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
|
193 |
auto_qa_conversation = []
|
194 |
with get_openai_callback() as cb:
|
|
|
197 |
"questions"
|
198 |
]
|
199 |
auto_qa_conversation = [
|
200 |
+
(f'/auto {qa["question"]}', qa["answer"], "identity")
|
201 |
for qa in auto_qa_response_parsed
|
202 |
]
|
203 |
stats = cb
|
204 |
st.session_state.messages.append(
|
205 |
+
(f"/auto-insight {inputs}", "Auto Convervation Generated", "identity")
|
206 |
)
|
207 |
for qa in auto_qa_conversation:
|
208 |
+
st.session_state.messages.append((qa[0], qa[1], "identity"))
|
209 |
|
210 |
st.session_state.costing.append(
|
211 |
{
|
|
|
214 |
"cost": stats.total_cost,
|
215 |
}
|
216 |
)
|
217 |
+
return (
|
218 |
+
"\n\n".join(
|
219 |
+
f"Q: {qa['question']}\n\nA: {qa['answer']}"
|
220 |
+
for qa in auto_qa_response_parsed
|
221 |
+
),
|
222 |
+
"identity",
|
223 |
)
|
224 |
|
225 |
|
226 |
+
def boot(command_center, formating_functions):
|
227 |
st.write("# Agent Zeta")
|
228 |
if "costing" not in st.session_state:
|
229 |
st.session_state.costing = []
|
|
|
233 |
for message in st.session_state.messages:
|
234 |
st.chat_message("human").write(message[0])
|
235 |
st.chat_message("ai").write(
|
236 |
+
formating_functions[message[2]](message[1]), unsafe_allow_html=True
|
237 |
)
|
238 |
if query := st.chat_input():
|
239 |
+
try:
|
240 |
+
st.chat_message("human").write(query)
|
241 |
+
response, format_fn_name = command_center.execute_command(query)
|
|
|
|
|
|
|
242 |
st.chat_message("ai").write(
|
243 |
+
formating_functions[format_fn_name](response), unsafe_allow_html=True
|
244 |
)
|
245 |
+
except (InvalidArgumentError, InvalidCommandError) as e:
|
246 |
+
st.error(e)
|
247 |
|
248 |
|
249 |
if __name__ == "__main__":
|
250 |
all_commands = [
|
251 |
+
("/add-papers", list, process_documents_wrapper),
|
252 |
+
("/library", None, index_documents_wrapper),
|
253 |
+
("/session-expense", None, calculate_cost_wrapper),
|
254 |
+
("/export", None, download_conversation_wrapper),
|
255 |
+
("/help-me", None, lambda x: (welcome_message, "identity")),
|
256 |
+
("/auto-insight", str, auto_qa_chain_wrapper),
|
257 |
+
("/deep-dive", str, query_llm_wrapper),
|
258 |
+
("/condense-summary", str, chain_of_density_wrapper),
|
259 |
]
|
260 |
command_center = CommandCenter(
|
261 |
default_input_type=str,
|
262 |
+
default_function=rag_llm_wrapper,
|
263 |
all_commands=all_commands,
|
264 |
)
|
265 |
+
formating_functions = {
|
266 |
+
"identity": lambda x: x,
|
267 |
+
"dataframe": lambda x: x,
|
268 |
+
"reponse_with_citations": lambda x: ai_response_format(
|
269 |
+
x["answer"], x["citations"]
|
270 |
+
),
|
271 |
+
}
|
272 |
+
jsonify_functions = {
|
273 |
+
"identity": lambda x: x,
|
274 |
+
"dataframe": lambda x: x.to_dict(orient="records"),
|
275 |
+
"reponse_with_citations": lambda x: x,
|
276 |
+
}
|
277 |
+
boot(command_center, formating_functions)
|
chain_of_density.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.output_parsers import JsonOutputParser
|
2 |
+
from langchain_core.prompts import PromptTemplate
|
3 |
+
|
4 |
+
chain_of_density_prompt_template = """
|
5 |
+
Research Paper: {paper}
|
6 |
+
|
7 |
+
You will generate increasingly concise, entity-dense summaries of the above research paper.
|
8 |
+
|
9 |
+
Repeat the following 2 steps 10 times.
|
10 |
+
|
11 |
+
Step 1. Identify 1-3 informative Entities ('; ' delimited) from the research paper that are missing from the previously generated summary. These entities should be key components such as research questions, methodologies, findings, theoretical contributions, or implications.
|
12 |
+
Step 2. Write a new, denser summary of identical length which covers every entity and detail from the previous summary plus the Missing Entities.
|
13 |
+
|
14 |
+
A Missing Entity is:
|
15 |
+
- Relevant: critical to understanding the paper’s contribution.
|
16 |
+
- Specific: descriptive yet concise (5 words or fewer).
|
17 |
+
- Novel: not included in the previous summary.
|
18 |
+
- Faithful: accurately represented in the research paper.
|
19 |
+
- Anywhere: can be found anywhere in the research paper.
|
20 |
+
|
21 |
+
Guidelines:
|
22 |
+
- The first summary should be long (4-5 sentences, ~100 words) yet focus on general information about the research paper, including its broad topic and objectives, without going into detail.
|
23 |
+
- Avoid using verbose language and fillers (e.g., 'This research paper discusses') to reach the word count.
|
24 |
+
- Strive for efficiency in word use: rewrite the previous summary to improve readability and make space for additional entities.
|
25 |
+
- Employ strategies such as fusion (combining entities), compression (shortening descriptions), and removal of uninformative phrases to make space for new entities.
|
26 |
+
- The summaries should evolve to be highly dense and concise yet remain self-contained, meaning they can be understood without reading the full paper.
|
27 |
+
- Missing entities should be integrated seamlessly into the new summary.
|
28 |
+
- Never omit entities from previous summaries. If space is a challenge, incorporate fewer new entities but maintain the same word count.
|
29 |
+
|
30 |
+
Remember, use the exact same number of words for each summary.
|
31 |
+
|
32 |
+
The JSON output should be a list (length 10) of dictionaries. Each dictionary must have two keys: 'missing_entities', listing the 1-3 entities added in each round; and 'denser_summary', presenting the new summary that integrates these entities without increasing the length.
|
33 |
+
"""
|
34 |
+
|
35 |
+
chain_of_density_output_parser = JsonOutputParser()
|
36 |
+
chain_of_density_prompt = PromptTemplate(
|
37 |
+
template=chain_of_density_prompt_template,
|
38 |
+
input_variables=["paper"],
|
39 |
+
)
|
40 |
+
chain_of_density_chain = (
|
41 |
+
lambda model: chain_of_density_prompt | model | chain_of_density_output_parser
|
42 |
+
)
|
chat_chains.py
CHANGED
@@ -1,22 +1,8 @@
|
|
1 |
-
from langchain_core.prompts import ChatPromptTemplate
|
2 |
-
from langchain_core.output_parsers import StrOutputParser
|
3 |
from langchain_core.runnables import RunnablePassthrough
|
4 |
import xml.etree.ElementTree as ET
|
5 |
import re
|
6 |
|
7 |
-
contextualize_q_system_prompt = """Given a chat history and the latest user question \
|
8 |
-
which might reference context in the chat history, formulate a standalone question \
|
9 |
-
which can be understood without the chat history. Do NOT answer the question, \
|
10 |
-
just reformulate it if needed and otherwise return it as is."""
|
11 |
-
contextualize_q_prompt = ChatPromptTemplate.from_messages(
|
12 |
-
[
|
13 |
-
("system", contextualize_q_system_prompt),
|
14 |
-
MessagesPlaceholder(variable_name="chat_history"),
|
15 |
-
("human", "{question}"),
|
16 |
-
]
|
17 |
-
)
|
18 |
-
contextualize_q_chain = lambda llm: contextualize_q_prompt | llm | StrOutputParser()
|
19 |
-
|
20 |
qa_system_prompt = """As Zeta, your mission is to assist users in navigating the vast sea of machine learning research with ease and insight. When responding to inquiries, adhere to the following guidelines to ensure the utmost accuracy and utility:
|
21 |
|
22 |
Contextual Understanding: When presented with a question, apply your understanding of machine learning concepts to interpret the context provided accurately. Utilize this context to guide your search for answers within the specified research papers.
|
@@ -46,7 +32,7 @@ By following these guidelines, you ensure that users receive valuable, accurate,
|
|
46 |
qa_prompt = ChatPromptTemplate.from_messages(
|
47 |
[
|
48 |
("system", qa_system_prompt),
|
49 |
-
MessagesPlaceholder(variable_name="chat_history"),
|
50 |
("human", "{question}"),
|
51 |
]
|
52 |
)
|
@@ -54,21 +40,19 @@ qa_prompt = ChatPromptTemplate.from_messages(
|
|
54 |
|
55 |
def format_docs(docs):
|
56 |
return "\n\n".join(
|
57 |
-
f"{doc.metadata['chunk_id']}: {doc.page_content}"
|
|
|
58 |
)
|
59 |
|
60 |
|
61 |
-
def contextualized_question(input: dict):
|
62 |
-
if input.get("chat_history"):
|
63 |
-
return contextualize_q_chain
|
64 |
-
else:
|
65 |
-
return input["question"]
|
66 |
-
|
67 |
-
|
68 |
rag_chain = lambda retriever, llm: (
|
69 |
-
RunnablePassthrough
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
| qa_prompt
|
73 |
| llm
|
74 |
)
|
@@ -105,3 +89,25 @@ def parse_model_response(input_string):
|
|
105 |
parsed_data["answer"] = "".join(outside_text_parts)
|
106 |
|
107 |
return parsed_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.prompts import ChatPromptTemplate
|
|
|
2 |
from langchain_core.runnables import RunnablePassthrough
|
3 |
import xml.etree.ElementTree as ET
|
4 |
import re
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
qa_system_prompt = """As Zeta, your mission is to assist users in navigating the vast sea of machine learning research with ease and insight. When responding to inquiries, adhere to the following guidelines to ensure the utmost accuracy and utility:
|
7 |
|
8 |
Contextual Understanding: When presented with a question, apply your understanding of machine learning concepts to interpret the context provided accurately. Utilize this context to guide your search for answers within the specified research papers.
|
|
|
32 |
qa_prompt = ChatPromptTemplate.from_messages(
|
33 |
[
|
34 |
("system", qa_system_prompt),
|
35 |
+
# MessagesPlaceholder(variable_name="chat_history"),
|
36 |
("human", "{question}"),
|
37 |
]
|
38 |
)
|
|
|
40 |
|
41 |
def format_docs(docs):
|
42 |
return "\n\n".join(
|
43 |
+
f"{doc.metadata['chunk_id']}: {doc.page_content}" if type(doc) != str else doc
|
44 |
+
for doc in docs
|
45 |
)
|
46 |
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
rag_chain = lambda retriever, llm: (
|
49 |
+
{"context": retriever | format_docs, "question": RunnablePassthrough()}
|
50 |
+
| qa_prompt
|
51 |
+
| llm
|
52 |
+
)
|
53 |
+
|
54 |
+
qa_chain = lambda llm: (
|
55 |
+
{"context": RunnablePassthrough(), "question": RunnablePassthrough()}
|
56 |
| qa_prompt
|
57 |
| llm
|
58 |
)
|
|
|
89 |
parsed_data["answer"] = "".join(outside_text_parts)
|
90 |
|
91 |
return parsed_data
|
92 |
+
|
93 |
+
|
94 |
+
def parse_context_and_question(inputs):
|
95 |
+
pattern = r"\[(.*?)\]"
|
96 |
+
match = re.search(pattern, inputs)
|
97 |
+
if match:
|
98 |
+
context = match.group(1)
|
99 |
+
context = [c.strip() for c in context.split()]
|
100 |
+
question = inputs[: match.start()] + inputs[match.end() :]
|
101 |
+
return context, question
|
102 |
+
else:
|
103 |
+
return "", inputs
|
104 |
+
|
105 |
+
|
106 |
+
format_citations = lambda citations: "\n\n".join(
|
107 |
+
[f"{citation['quote']} ... [{citation['source_id']}]" for citation in citations]
|
108 |
+
)
|
109 |
+
ai_response_format = lambda message, references: (
|
110 |
+
f"{message}\n\n---\n\n{format_citations(references)}"
|
111 |
+
if references != ""
|
112 |
+
else message
|
113 |
+
)
|
command_center.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
class CommandCenter:
|
2 |
def __init__(self, default_input_type, default_function=None, all_commands=None):
|
3 |
self.commands = {}
|
@@ -20,6 +23,9 @@ class CommandCenter:
|
|
20 |
command = inputs[0]
|
21 |
argument = inputs[1:]
|
22 |
|
|
|
|
|
|
|
23 |
# type casting the arguments
|
24 |
if self.commands[command]["input_type"] == str:
|
25 |
argument = " ".join(argument)
|
|
|
1 |
+
from custom_exceptions import InvalidCommandError
|
2 |
+
|
3 |
+
|
4 |
class CommandCenter:
|
5 |
def __init__(self, default_input_type, default_function=None, all_commands=None):
|
6 |
self.commands = {}
|
|
|
23 |
command = inputs[0]
|
24 |
argument = inputs[1:]
|
25 |
|
26 |
+
if command not in self.commands:
|
27 |
+
raise InvalidCommandError("Invalid command")
|
28 |
+
|
29 |
# type casting the arguments
|
30 |
if self.commands[command]["input_type"] == str:
|
31 |
argument = " ".join(argument)
|
custom_exceptions.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class InvalidCommandError(Exception):
|
2 |
+
pass
|
3 |
+
|
4 |
+
|
5 |
+
class InvalidArgumentError(Exception):
|
6 |
+
pass
|
process_documents.py
CHANGED
@@ -10,14 +10,25 @@ deep_strip = lambda text: re.sub(r"\s+", " ", text or "").strip()
|
|
10 |
|
11 |
def process_documents(urls):
|
12 |
snippets = []
|
13 |
-
documents =
|
14 |
for source_id, url in enumerate(urls):
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
return snippets, documents
|
22 |
|
23 |
|
@@ -30,7 +41,7 @@ def process_web(url, source_id):
|
|
30 |
"header": data.metadata["title"],
|
31 |
"source_url": url,
|
32 |
"source_type": "web",
|
33 |
-
"chunk_id":
|
34 |
"source_id": source_id,
|
35 |
},
|
36 |
)
|
@@ -54,7 +65,7 @@ def process_pdf(url, source_id):
|
|
54 |
"header": " ".join(snip[1]["header_text"].split()[:10]),
|
55 |
"source_url": url,
|
56 |
"source_type": "pdf",
|
57 |
-
"chunk_id": f"{source_id}_{i}",
|
58 |
"source_id": source_id,
|
59 |
},
|
60 |
)
|
|
|
10 |
|
11 |
def process_documents(urls):
|
12 |
snippets = []
|
13 |
+
documents = {}
|
14 |
for source_id, url in enumerate(urls):
|
15 |
+
snippet = (
|
16 |
+
process_pdf(url, source_id)
|
17 |
+
if url.endswith(".pdf")
|
18 |
+
else process_web(url, source_id)
|
19 |
+
)
|
20 |
+
snippets.extend(snippet)
|
21 |
+
documents[str(source_id)] = Document(
|
22 |
+
page_content="\n".join([snip.page_content for snip in snippet]),
|
23 |
+
metadata={
|
24 |
+
"source_url": url,
|
25 |
+
"source_type": "pdf" if url.endswith(".pdf") else "web",
|
26 |
+
"source_id": source_id,
|
27 |
+
"chunk_id": source_id,
|
28 |
+
},
|
29 |
+
)
|
30 |
+
for snip in snippet:
|
31 |
+
documents[snip.metadata["chunk_id"]] = snip
|
32 |
return snippets, documents
|
33 |
|
34 |
|
|
|
41 |
"header": data.metadata["title"],
|
42 |
"source_url": url,
|
43 |
"source_type": "web",
|
44 |
+
"chunk_id": source_id,
|
45 |
"source_id": source_id,
|
46 |
},
|
47 |
)
|
|
|
65 |
"header": " ".join(snip[1]["header_text"].split()[:10]),
|
66 |
"source_url": url,
|
67 |
"source_type": "pdf",
|
68 |
+
"chunk_id": f"{source_id}_{i:02d}",
|
69 |
"source_id": source_id,
|
70 |
},
|
71 |
)
|