terapyon committed on
Commit
1f4ac35
•
1 Parent(s): 648f519

support RAG refs #5

Files changed (2)
  1. app.py  +101 -5
  2. requirements.txt  +1 -0
app.py CHANGED
@@ -1,3 +1,4 @@
+from time import time
 from datetime import datetime, date, timedelta
 from typing import Iterable
 import streamlit as st
@@ -6,6 +7,9 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import Qdrant
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import Filter, FieldCondition, MatchValue, Range
+from langchain.chains import RetrievalQA
+from openai.error import InvalidRequestError
+from langchain.chat_models import ChatOpenAI
 from config import DB_CONFIG
 from model import Issue
 
@@ -23,7 +27,14 @@ def load_embeddings():
     return embeddings
 
 
+@st.cache_resource
+def llm_model(model="gpt-3.5-turbo", temperature=0.2):
+    llm = ChatOpenAI(model=model, temperature=temperature)
+    return llm
+
+
 EMBEDDINGS = load_embeddings()
+LLM = llm_model()
 
 
 def make_filter_obj(options: list[dict[str]]):
@@ -67,14 +78,46 @@ def get_similay(query: str, filter: Filter):
     return docs
 
 
-def main(
+def get_retrieval_qa(filter: Filter):
+    db_url, db_api_key, db_collection_name = DB_CONFIG
+    client = QdrantClient(url=db_url, api_key=db_api_key)
+    db = Qdrant(
+        client=client, collection_name=db_collection_name, embeddings=EMBEDDINGS
+    )
+    retriever = db.as_retriever(
+        search_kwargs={
+            "filter": filter,
+        }
+    )
+    result = RetrievalQA.from_chain_type(
+        llm=LLM,
+        chain_type="stuff",
+        retriever=retriever,
+        return_source_documents=True,
+    )
+    return result
+
+
+def _get_related_url(metadata) -> Iterable[str]:
+    urls = set()
+    for m in metadata:
+        url = m["url"]
+        if url in urls:
+            continue
+        urls.add(url)
+        created_at = datetime.fromtimestamp(m["created_at"])
+        # print(m)
+        yield f'<p>URL: <a href="{url}">{url}</a> (created: {created_at:%Y-%m-%d})</p>'
+
+
+def _get_query_str_filter(
     query: str,
     repo_name: str,
     query_options: str,
     start_date: date,
     end_date: date,
     include_comments: bool,
-) -> Iterable[tuple[Issue, float, str]]:
+) -> tuple[str, Filter]:
     options = [{"key": "metadata.repo_name", "value": repo_name}]
     if start_date is not None and end_date is not None:
         options.append(
@@ -96,6 +139,44 @@ def main(
     if query_options == "Empty":
         query_options = ""
     query_str = f"{query_options}{query}"
+    return query_str, filter
+
+
+def run_qa(
+    query: str,
+    repo_name: str,
+    query_options: str,
+    start_date: date,
+    end_date: date,
+    include_comments: bool,
+) -> tuple[str, str]:
+    now = time()
+    query_str, filter = _get_query_str_filter(
+        query, repo_name, query_options, start_date, end_date, include_comments
+    )
+    qa = get_retrieval_qa(filter)
+    try:
+        result = qa(query_str)
+    except InvalidRequestError as e:
+        return "回答が見つかりませんでした。別な質問をしてみてください", str(e)
+    else:
+        metadata = [s.metadata for s in result["source_documents"]]
+        sec_html = f"<p>実行時間: {(time() - now):.2f}秒</p>"
+        html = "<div>" + sec_html + "\n".join(_get_related_url(metadata)) + "</div>"
+        return result["result"], html
+
+
+def run_search(
+    query: str,
+    repo_name: str,
+    query_options: str,
+    start_date: date,
+    end_date: date,
+    include_comments: bool,
+) -> Iterable[tuple[Issue, float, str]]:
+    query_str, filter = _get_query_str_filter(
+        query, repo_name, query_options, start_date, end_date, include_comments
+    )
     docs = get_similay(query_str, filter)
     for doc, score in docs:
         text = doc.page_content
@@ -153,13 +234,14 @@ with st.form("my_form"):
     )
     include_comments = st.checkbox(label="Include Issue comments", value=True)
 
-    submitted = st.form_submit_button("Submit")
-    if submitted:
+    submit_col1, submit_col2 = st.columns(2)
+    searched = submit_col1.form_submit_button("Search")
+    if searched:
         st.divider()
         st.header("Search Results")
         st.divider()
         with st.spinner("Searching..."):
-            results = main(
+            results = run_search(
                 query, repo_name, query_options, start_date, end_date, include_comments
            )
             for issue, score, text in results:
@@ -182,3 +264,17 @@ with st.form("my_form"):
                 st.write(f"{labels=}")
                 # st.markdown(html, unsafe_allow_html=True)
                 st.divider()
+    qa_searched = submit_col2.form_submit_button("QA Search by OpenAI")
+    if qa_searched:
+        st.divider()
+        st.header("QA Search Results by OpenAI GPT-3")
+        st.divider()
+        with st.spinner("QA Searching..."):
+            results = run_qa(
+                query, repo_name, query_options, start_date, end_date, include_comments
+            )
+            answer, html = results
+            with st.container():
+                st.write(answer)
+                st.markdown(html, unsafe_allow_html=True)
+                st.divider()
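
For reference, the QA path added above follows the standard LangChain RetrievalQA pattern over a Qdrant vector store, with a metadata Filter handed to the retriever. The sketch below is not code from this commit: the Qdrant URL, collection name, embedding model, repo name, and query are illustrative placeholders (app.py reads its real connection values from DB_CONFIG), and an OPENAI_API_KEY environment variable is assumed to be set.

from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Qdrant
from qdrant_client import QdrantClient
from qdrant_client.http.models import FieldCondition, Filter, MatchValue

# Placeholder connection details; app.py gets the real ones from DB_CONFIG.
client = QdrantClient(url="http://localhost:6333")
db = Qdrant(
    client=client,
    collection_name="issues",            # placeholder collection name
    embeddings=HuggingFaceEmbeddings(),  # placeholder; app.py caches its own embeddings
)

# Roughly the shape of filter that make_filter_obj() builds for the repo_name option.
flt = Filter(
    must=[FieldCondition(key="metadata.repo_name", match=MatchValue(value="owner/repo"))]
)

qa = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0.2),
    chain_type="stuff",  # stuff the retrieved issues directly into the prompt
    retriever=db.as_retriever(search_kwargs={"filter": flt}),
    return_source_documents=True,  # lets run_qa() list the related issue URLs
)

result = qa("How do I mark a test as expected to fail?")  # placeholder query
print(result["result"])                                   # generated answer
print([doc.metadata for doc in result["source_documents"]])  # metadata used for the URL list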
requirements.txt CHANGED
@@ -8,3 +8,4 @@ bitsandbytes
 sentence_transformers
 streamlit
 python-dateutil
+openai
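
Because requirements.txt now pulls in the openai client, the app needs an OpenAI API key at runtime; ChatOpenAI reads it from the OPENAI_API_KEY environment variable by default. A minimal pre-flight check along those lines (an assumption about deployment, not part of this commit):

import os

if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("Set OPENAI_API_KEY before running `streamlit run app.py`.")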