Sbnos committed on
Commit 1bc894f
1 Parent(s): 180cbdd
Files changed (1)
  1. app.py +199 -88
app.py CHANGED
@@ -1,159 +1,270 @@
import streamlit as st
import os
- import asyncio
- from langchain_community.vectorstores import Chroma
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_together import Together
from langchain import hub
from operator import itemgetter
from langchain.schema import format_document
- from langchain.prompts import ChatPromptTemplate, PromptTemplate
- from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain.memory import ConversationBufferMemory
- from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

# Load the embedding function
model_name = "BAAI/bge-base-en"
- encode_kwargs = {'normalize_embeddings': True}
embedding_function = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs
)

- # Initialize the LLMs
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    top_k=12,
-     together_api_key=os.environ['pilotikval'],
-     max_tokens=200
)

llmc = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    top_k=3,
-     together_api_key=os.environ['pilotikval'],
-     max_tokens=200
)

- # Memory setup
msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)

- # Define the prompt templates
- CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(
-     """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.
- Chat History:
- {chat_history}
- Follow Up Input: {question}
- Standalone question:"""
- )

- ANSWER_PROMPT = ChatPromptTemplate.from_template(
-     """You are helping a doctor. Answer based on the provided context:
- {context}
- Question: {question}"""
- )

- # Function to combine documents
- def _combine_documents(docs, document_prompt=PromptTemplate.from_template("{page_content}"), document_separator="\n\n"):
-     doc_strings = [format_document(doc, document_prompt) for doc in docs]
-     return document_separator.join(doc_strings)

- # Function to store chat history
chistory = []

def store_chat_history(role: str, content: str):
    chistory.append({"role": role, "content": content})

- # Define the chain using LCEL
- def create_conversational_qa_chain(retriever, condense_llm, answer_llm):
-     condense_question_chain = (
-         RunnableLambda(lambda x: {"chat_history": chistory, "question": x['question']})
-         | CONDENSE_QUESTION_PROMPT
-         | RunnableLambda(lambda x: {"standalone_question": x})
-     )

-     retrieval_chain = (
-         RunnableLambda(lambda x: {"standalone_question": x['standalone_question']})
-         | retriever
-         | RunnableLambda(lambda x: {"context": _combine_documents(x)})
-     )
-
-     answer_chain = ANSWER_PROMPT | answer_llm

-     return RunnableParallel(
-         condense_question=condense_question_chain,
-         retrieve=retrieval_chain,
-         generate_answer=answer_chain
-     )

- # Asynchronous function to handle streaming responses
- async def stream_response(conversational_qa_chain, prompts2, chistory):
-     response_chunks = []
-     async for chunk in conversational_qa_chain.astream(
-         {
-             "question": prompts2,
-             "chat_history": chistory,
-         }
-     ):
-         response_chunks.append(chunk['generate_answer'])
-         st.write("".join(response_chunks))
-     return "".join(response_chunks)

- # Define the Streamlit app
- def app():
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox(
            'Which retriever would you like to use?',
            ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
        )

-     # Define retrievers based on option
-     persist_directory = {
-         'General Medicine': "./oxfordmedbookdir/",
-         'RespiratoryFishman': "./respfishmandbcud/",
-         'RespiratoryMurray': "./respmurray/",
-         'MedMRCP2': "./medmrcp2store/",
-         'OldMedicine': "./mrcpchromadb/"
-     }.get(option, "./mrcpchromadb/")
-
-     collection_name = {
-         'General Medicine': "oxfordmed",
-         'RespiratoryFishman': "fishmannotescud",
-         'RespiratoryMurray': "respmurraynotes",
-         'MedMRCP2': "medmrcp2notes",
-         'OldMedicine': "mrcppassmednotes"
-     }.get(option, "mrcppassmednotes")
-
-     vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name=collection_name)
-     retriever = vectordb.as_retriever(search_kwargs={"k": 5})

-     if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

    st.header("Ask Away!")
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            store_chat_history(message["role"], message["content"])

    prompts2 = st.chat_input("Say something")

    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)

    if st.session_state.messages[-1]["role"] != "assistant":
-         conversational_qa_chain = create_conversational_qa_chain(retriever, llmc, llm)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
-                 final_response = asyncio.run(stream_response(conversational_qa_chain, prompts2, chistory))
-                 message = {"role": "assistant", "content": final_response}
                st.session_state.messages.append(message)

if __name__ == '__main__':
-     app()
import streamlit as st
import os
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_together import Together
from langchain import hub
from operator import itemgetter
+ from langchain.schema.runnable import RunnableParallel
from langchain.schema import format_document
+ from typing import List, Tuple
+ from langchain.chains import LLMChain
+ from langchain.chains import RetrievalQA
+ from langchain.schema.output_parser import StrOutputParser
+ from langchain.memory import StreamlitChatMessageHistory
from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory import ConversationSummaryMemory
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
+ from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+

# Load the embedding function
model_name = "BAAI/bge-base-en"
+ encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
+
embedding_function = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs
)

+ # Load the ChromaDB vector store
+ # persist_directory="./mrcpchromadb/"
+ # vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function,collection_name="mrcppassmednotes")
+
+ # Load the LLM
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    top_k=12,
+     together_api_key=os.environ['pilotikval']
)

+ # Load the summarizeLLM
llmc = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    top_k=3,
+     together_api_key=os.environ['pilotikval']
)

msgs = StreamlitChatMessageHistory(key="langchain_messages")
memory = ConversationBufferMemory(chat_memory=msgs)

+ DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+
+ def _combine_documents(
+     docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+ ):
+     doc_strings = [format_document(doc, document_prompt) for doc in docs]
+     return document_separator.join(doc_strings)
+

chistory = []

def store_chat_history(role: str, content: str):
+     # Append the new message to the chat history
    chistory.append({"role": role, "content": content})

+ # Define the Streamlit app
+ def app():
    with st.sidebar:
+
        st.title("dochatter")
+         # Create a dropdown selection box
        option = st.selectbox(
            'Which retriever would you like to use?',
            ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
        )
+     # Depending on the selected option, choose the appropriate retriever
+     if option == 'RespiratoryFishman':
+         persist_directory="./respfishmandbcud/"
+         vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function,collection_name="fishmannotescud")
+         retriever = vectordb.as_retriever(search_kwargs={"k": 5})
+         retriever = retriever # replace with your actual retriever

+     if option == 'RespiratoryMurray':
+         persist_directory="./respmurray/"
+         vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function,collection_name="respmurraynotes")
+         retriever = vectordb.as_retriever(search_kwargs={"k": 5})
+         retriever = retriever
+
+     if option == 'MedMRCP2':
+         persist_directory="./medmrcp2store/"
+         vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function,collection_name="medmrcp2notes")
+         retriever = vectordb.as_retriever(search_kwargs={"k": 5})
+         retriever = retriever
+
+     if option == 'General Medicine':
+         persist_directory="./oxfordmedbookdir/"
+         vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function,collection_name="oxfordmed")
+         retriever = vectordb.as_retriever(search_kwargs={"k": 7})
+         retriever = retriever
+
+     else:
+         persist_directory="./mrcpchromadb/"
+         vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function,collection_name="mrcppassmednotes")
+         retriever = vectordb.as_retriever(search_kwargs={"k": 5})
+         retriever = retriever # replace with your actual retriever
+     retriever = retriever # replace with your actual retriever
+
+     #template = """You are an AI chatbot having a conversation with a human. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
+     #{context}
+     #{history}
+     #Human: {human_input}
+     #AI: """
+     #prompt = PromptTemplate(input_variables=["history", "question"], template=template)
+     #template = st.text_area("Template", value=template, height=180)
+     #prompt2 = ChatPromptTemplate.from_template(template)

+     # Session State
+     # Store LLM generated responses
+     if "messages" not in st.session_state.keys():
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

+     ## Retry lets go
+
+     _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question which contains the themes of the conversation. Do not write the question. Do not write the answer.
+ Chat History:
+ {chat_history}
+ Follow Up Input: {question}
+ Standalone question:"""
+     CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+     template = """You are helping a doctor. Answer with what you know from the context provided. Please be as detailed and thorough. Answer the question based on the following context:
+ {context}
+ Question: {question}
+ """
+     ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
+
+     _inputs = RunnableParallel(
+         standalone_question=RunnablePassthrough.assign(
+             chat_history=lambda x: chistory
+         )
+         | CONDENSE_QUESTION_PROMPT
+         | llmc
+         | StrOutputParser(),
+     )
+     _context = {
+         "context": itemgetter("standalone_question") | retriever | _combine_documents,
+         "question": lambda x: x["standalone_question"],
+     }
+     conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | llm

    st.header("Ask Away!")
+     # Display the messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            store_chat_history(message["role"], message["content"])
+
+     # prompt = hub.pull("rlm/rag-prompt")

    prompts2 = st.chat_input("Say something")

+     # Implement using different book sources, if statements
+
    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)

    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
+                 response = conversational_qa_chain.invoke(
+                     {
+                         "question": prompts2,
+                         "chat_history": chistory,
+                     }
+                 )
+                 st.write(response)
+                 message = {"role": "assistant", "content": response}
                st.session_state.messages.append(message)

+     # Create a button to submit the question

+ # Initialize history
+ history = []
+
if __name__ == '__main__':
+     app()