Ritvik19 commited on
Commit
6ae5e8b
1 Parent(s): 827774a
Files changed (3) hide show
  1. app.py +39 -3
  2. autoqa_chains.py +54 -0
  3. chat_chains.py +107 -0
app.py CHANGED
@@ -8,8 +8,9 @@ import json
8
  from langchain.callbacks import get_openai_callback
9
  from langchain_openai import ChatOpenAI
10
  import base64
11
- from chains import rag_chain, parse_model_response
12
  from langchain_core.messages import AIMessage, HumanMessage
 
13
 
14
  st.set_page_config(layout="wide")
15
  os.environ["OPENAI_API_KEY"] = "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"
@@ -41,10 +42,11 @@ Here's a quick guide to getting started with me:
41
 
42
  | Command | Description |
43
  |---------|-------------|
44
- | `/upload` | Upload and process documents for our conversation. |
45
  | `/index` | View an index of processed documents to easily navigate your research. |
46
  | `/cost` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
47
  | `/download` | Download conversation data for your records or further analysis. |
 
48
 
49
  <br>
50
 
@@ -55,7 +57,7 @@ Use `/man` at any point of time to view this guide again.
55
 
56
 
57
  def process_documents_wrapper(inputs):
58
- snippets = process_documents(inputs)
59
  st.session_state.retriever = create_retriever(snippets)
60
  st.session_state.source_doc_urls = inputs
61
  st.session_state.index = [
@@ -63,6 +65,7 @@ def process_documents_wrapper(inputs):
63
  ]
64
  response = f"Uploaded and processed documents {inputs}"
65
  st.session_state.messages.append((f"/upload {inputs}", response, ""))
 
66
  return response
67
 
68
 
@@ -163,6 +166,38 @@ def query_llm_wrapper(inputs):
163
  return answer, citations
164
 
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  def boot(command_center):
167
  st.write("# Agent Zeta")
168
  if "costing" not in st.session_state:
@@ -196,6 +231,7 @@ if __name__ == "__main__":
196
  ("/cost", None, calculate_cost_wrapper),
197
  ("/download", None, download_conversation_wrapper),
198
  ("/man", None, lambda x: welcome_message),
 
199
  ]
200
  command_center = CommandCenter(
201
  default_input_type=str,
 
8
  from langchain.callbacks import get_openai_callback
9
  from langchain_openai import ChatOpenAI
10
  import base64
11
+ from chat_chains import rag_chain, parse_model_response
12
  from langchain_core.messages import AIMessage, HumanMessage
13
+ from autoqa_chains import auto_qa_chain, followup_qa_chain, auto_qa_output_parser
14
 
15
  st.set_page_config(layout="wide")
16
  os.environ["OPENAI_API_KEY"] = "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"
 
42
 
43
  | Command | Description |
44
  |---------|-------------|
45
+ | `/upload` <list of urls> | Upload and process documents for our conversation. |
46
  | `/index` | View an index of processed documents to easily navigate your research. |
47
  | `/cost` | Calculate the cost of our conversation, ensuring transparency in resource usage. |
48
  | `/download` | Download conversation data for your records or further analysis. |
49
+ | `/auto` <document id> | Automatically generate questions and answers for a document. |
50
 
51
  <br>
52
 
 
57
 
58
 
59
  def process_documents_wrapper(inputs):
60
+ snippets, documents = process_documents(inputs)
61
  st.session_state.retriever = create_retriever(snippets)
62
  st.session_state.source_doc_urls = inputs
63
  st.session_state.index = [
 
65
  ]
66
  response = f"Uploaded and processed documents {inputs}"
67
  st.session_state.messages.append((f"/upload {inputs}", response, ""))
68
+ st.session_state.documents = documents
69
  return response
70
 
71
 
 
166
  return answer, citations
167
 
168
 
169
+ def auto_qa_chain_wrapper(inputs):
170
+ document = st.session_state.documents[inputs]
171
+ llm = ChatOpenAI(model="gpt-4-turbo-preview", temperature=0)
172
+ auto_qa_conversation = []
173
+ with get_openai_callback() as cb:
174
+ auto_qa_response = auto_qa_chain(llm).invoke({"paper": document})
175
+ auto_qa_response_parsed = auto_qa_output_parser.invoke(auto_qa_response)[
176
+ "questions"
177
+ ]
178
+ auto_qa_conversation = [
179
+ (f'/auto {qa["question"]}', qa["answer"], "")
180
+ for qa in auto_qa_response_parsed
181
+ ]
182
+ stats = cb
183
+ st.session_state.messages.append(
184
+ (f"/auto {inputs}", "Auto Convervation Generated", "")
185
+ )
186
+ for qa in auto_qa_conversation:
187
+ st.session_state.messages.append((qa[0], qa[1], ""))
188
+
189
+ st.session_state.costing.append(
190
+ {
191
+ "prompt tokens": stats.prompt_tokens,
192
+ "completion tokens": stats.completion_tokens,
193
+ "cost": stats.total_cost,
194
+ }
195
+ )
196
+ return "\n\n".join(
197
+ f"Q: {qa['question']}\n\nA: {qa['answer']}" for qa in auto_qa_response_parsed
198
+ )
199
+
200
+
201
  def boot(command_center):
202
  st.write("# Agent Zeta")
203
  if "costing" not in st.session_state:
 
231
  ("/cost", None, calculate_cost_wrapper),
232
  ("/download", None, download_conversation_wrapper),
233
  ("/man", None, lambda x: welcome_message),
234
+ ("/auto", int, auto_qa_chain_wrapper),
235
  ]
236
  command_center = CommandCenter(
237
  default_input_type=str,
autoqa_chains.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.pydantic_v1 import BaseModel, Field
2
+ from typing import List
3
+ from langchain_core.output_parsers import JsonOutputParser
4
+ from langchain_core.prompts import PromptTemplate
5
+
6
+
7
+ class QA(BaseModel):
8
+ question: str = Field(description="question")
9
+ answer: str = Field(description="answer")
10
+
11
+
12
+ class AutoQA(BaseModel):
13
+ questions: List[QA] = Field(description="list of question and answers")
14
+
15
+
16
+ qa_prompt_template = """
17
+ Come up with the 10 questions and answers that could be commonly asked by people about the following research paper.
18
+ The question and answers should capture the whole essence of the research paper
19
+ The answers should be a bit detailed and strictly based on the research paper.
20
+ Your response should be recorded in the following json format: {format_instructions}.
21
+
22
+ here is the research paper: ####{paper}####
23
+ """
24
+
25
+ auto_qa_output_parser = JsonOutputParser(pydantic_object=AutoQA)
26
+ qa_prompt = PromptTemplate(
27
+ template=qa_prompt_template,
28
+ input_variables=["paper"],
29
+ partial_variables={
30
+ "format_instructions": auto_qa_output_parser.get_format_instructions()
31
+ },
32
+ )
33
+ auto_qa_chain = lambda model: qa_prompt | model
34
+
35
+
36
+ followup_prompt_template = """
37
+ Question: {question}
38
+ Answer: {answer}
39
+ Based on the above question and answer and the research paper as your context, come up with a followup question and its answer.
40
+ The answer should be a bit detailed and strictly based on the research paper.
41
+ Your response should be recorded in the following json format: {format_instructions}.
42
+
43
+ here is the research paper: ####{paper}####
44
+ """
45
+
46
+ followup_prompt = PromptTemplate(
47
+ template=followup_prompt_template,
48
+ input_variables=["paper"],
49
+ partial_variables={
50
+ "format_instructions": auto_qa_output_parser.get_format_instructions()
51
+ },
52
+ )
53
+
54
+ followup_qa_chain = lambda model: qa_prompt | model
chat_chains.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
2
+ from langchain_core.output_parsers import StrOutputParser
3
+ from langchain_core.runnables import RunnablePassthrough
4
+ import xml.etree.ElementTree as ET
5
+ import re
6
+
7
+ contextualize_q_system_prompt = """Given a chat history and the latest user question \
8
+ which might reference context in the chat history, formulate a standalone question \
9
+ which can be understood without the chat history. Do NOT answer the question, \
10
+ just reformulate it if needed and otherwise return it as is."""
11
+ contextualize_q_prompt = ChatPromptTemplate.from_messages(
12
+ [
13
+ ("system", contextualize_q_system_prompt),
14
+ MessagesPlaceholder(variable_name="chat_history"),
15
+ ("human", "{question}"),
16
+ ]
17
+ )
18
+ contextualize_q_chain = lambda llm: contextualize_q_prompt | llm | StrOutputParser()
19
+
20
+ qa_system_prompt = """As Zeta, your mission is to assist users in navigating the vast sea of machine learning research with ease and insight. When responding to inquiries, adhere to the following guidelines to ensure the utmost accuracy and utility:
21
+
22
+ Contextual Understanding: When presented with a question, apply your understanding of machine learning concepts to interpret the context provided accurately. Utilize this context to guide your search for answers within the specified research papers.
23
+
24
+ Answer Provision: Always provide an answer that is directly supported by the research papers' content. If the information needed to answer the question is not available, clearly state, "I don't know."
25
+
26
+ Citation Requirement: For every answer given, include multiple citations from the research papers. A citation must include a direct quote from the paper that supports your answer, along with the identification (ID) of the paper. This ensures that all provided information can be traced back to its source, maintaining a high level of credibility and transparency.
27
+
28
+ Formatting Guidelines: Present your citations in the following structured format at the end of your answer to maintain clarity and consistency:
29
+
30
+
31
+ <citations>
32
+ <citation><source_id>[Source ID]</source_id><quote>[Direct quote from the source]</quote></citation>
33
+ ...
34
+ </citations>
35
+
36
+
37
+ Conflict Resolution: In cases where multiple sources offer conflicting information, evaluate the context, relevance, and credibility of each source to determine the most accurate answer. Explain your reasoning within the citation section to provide insight into your decision-making process.
38
+
39
+ User Engagement: Encourage user engagement by asking clarifying questions if the initial inquiry is ambiguous or lacks specific context. This helps in providing more targeted and relevant responses.
40
+
41
+ Continual Learning: Although you are not expected to generate new text or insights beyond the provided papers, be open to learning from new information as it becomes available to you through user interactions and queries.
42
+
43
+ By following these guidelines, you ensure that users receive valuable, accurate, and source-backed insights into their inquiries, making their exploration of machine learning research more productive and enlightening.
44
+
45
+ {context}"""
46
+ qa_prompt = ChatPromptTemplate.from_messages(
47
+ [
48
+ ("system", qa_system_prompt),
49
+ MessagesPlaceholder(variable_name="chat_history"),
50
+ ("human", "{question}"),
51
+ ]
52
+ )
53
+
54
+
55
+ def format_docs(docs):
56
+ return "\n\n".join(
57
+ f"{doc.metadata['chunk_id']}: {doc.page_content}" for doc in docs
58
+ )
59
+
60
+
61
+ def contextualized_question(input: dict):
62
+ if input.get("chat_history"):
63
+ return contextualize_q_chain
64
+ else:
65
+ return input["question"]
66
+
67
+
68
+ rag_chain = lambda retriever, llm: (
69
+ RunnablePassthrough.assign(
70
+ context=contextualized_question | retriever | format_docs
71
+ )
72
+ | qa_prompt
73
+ | llm
74
+ )
75
+
76
+
77
+ def parse_model_response(input_string):
78
+ parsed_data = {"answer": "", "citations": []}
79
+ xml_matches = re.findall(r"<citations>.*?</citations>", input_string, re.DOTALL)
80
+ if not xml_matches:
81
+ parsed_data["answer"] = input_string
82
+ return parsed_data
83
+
84
+ outside_text_parts = []
85
+ last_end_pos = 0
86
+
87
+ for xml_string in xml_matches:
88
+ match = re.search(re.escape(xml_string), input_string[last_end_pos:], re.DOTALL)
89
+
90
+ if match:
91
+ outside_text_parts.append(
92
+ input_string[last_end_pos : match.start() + last_end_pos]
93
+ )
94
+ last_end_pos += match.end()
95
+
96
+ root = ET.fromstring(xml_string)
97
+
98
+ for citation in root.findall("citation"):
99
+ source_id = citation.find("source_id").text
100
+ quote = citation.find("quote").text
101
+ parsed_data["citations"].append({"source_id": source_id, "quote": quote})
102
+
103
+ outside_text_parts.append(input_string[last_end_pos:])
104
+
105
+ parsed_data["answer"] = "".join(outside_text_parts)
106
+
107
+ return parsed_data