Spaces:
Sleeping
Sleeping
support RAG refs #5
Browse files- app.py +101 -5
- requirements.txt +1 -0
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from datetime import datetime, date, timedelta
|
2 |
from typing import Iterable
|
3 |
import streamlit as st
|
@@ -6,6 +7,9 @@ from langchain.embeddings import HuggingFaceEmbeddings
|
|
6 |
from langchain.vectorstores import Qdrant
|
7 |
from qdrant_client import QdrantClient
|
8 |
from qdrant_client.http.models import Filter, FieldCondition, MatchValue, Range
|
|
|
|
|
|
|
9 |
from config import DB_CONFIG
|
10 |
from model import Issue
|
11 |
|
@@ -23,7 +27,14 @@ def load_embeddings():
|
|
23 |
return embeddings
|
24 |
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
EMBEDDINGS = load_embeddings()
|
|
|
27 |
|
28 |
|
29 |
def make_filter_obj(options: list[dict[str]]):
|
@@ -67,14 +78,46 @@ def get_similay(query: str, filter: Filter):
|
|
67 |
return docs
|
68 |
|
69 |
|
70 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
query: str,
|
72 |
repo_name: str,
|
73 |
query_options: str,
|
74 |
start_date: date,
|
75 |
end_date: date,
|
76 |
include_comments: bool,
|
77 |
-
) ->
|
78 |
options = [{"key": "metadata.repo_name", "value": repo_name}]
|
79 |
if start_date is not None and end_date is not None:
|
80 |
options.append(
|
@@ -96,6 +139,44 @@ def main(
|
|
96 |
if query_options == "Empty":
|
97 |
query_options = ""
|
98 |
query_str = f"{query_options}{query}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
docs = get_similay(query_str, filter)
|
100 |
for doc, score in docs:
|
101 |
text = doc.page_content
|
@@ -153,13 +234,14 @@ with st.form("my_form"):
|
|
153 |
)
|
154 |
include_comments = st.checkbox(label="Include Issue comments", value=True)
|
155 |
|
156 |
-
|
157 |
-
|
|
|
158 |
st.divider()
|
159 |
st.header("Search Results")
|
160 |
st.divider()
|
161 |
with st.spinner("Searching..."):
|
162 |
-
results =
|
163 |
query, repo_name, query_options, start_date, end_date, include_comments
|
164 |
)
|
165 |
for issue, score, text in results:
|
@@ -182,3 +264,17 @@ with st.form("my_form"):
|
|
182 |
st.write(f"{labels=}")
|
183 |
# st.markdown(html, unsafe_allow_html=True)
|
184 |
st.divider()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from time import time
|
2 |
from datetime import datetime, date, timedelta
|
3 |
from typing import Iterable
|
4 |
import streamlit as st
|
|
|
7 |
from langchain.vectorstores import Qdrant
|
8 |
from qdrant_client import QdrantClient
|
9 |
from qdrant_client.http.models import Filter, FieldCondition, MatchValue, Range
|
10 |
+
from langchain.chains import RetrievalQA
|
11 |
+
from openai.error import InvalidRequestError
|
12 |
+
from langchain.chat_models import ChatOpenAI
|
13 |
from config import DB_CONFIG
|
14 |
from model import Issue
|
15 |
|
|
|
27 |
return embeddings
|
28 |
|
29 |
|
30 |
+
@st.cache_resource
|
31 |
+
def llm_model(model="gpt-3.5-turbo", temperature=0.2):
|
32 |
+
llm = ChatOpenAI(model=model, temperature=temperature)
|
33 |
+
return llm
|
34 |
+
|
35 |
+
|
36 |
EMBEDDINGS = load_embeddings()
|
37 |
+
LLM = llm_model()
|
38 |
|
39 |
|
40 |
def make_filter_obj(options: list[dict[str]]):
|
|
|
78 |
return docs
|
79 |
|
80 |
|
81 |
+
def get_retrieval_qa(filter: Filter):
|
82 |
+
db_url, db_api_key, db_collection_name = DB_CONFIG
|
83 |
+
client = QdrantClient(url=db_url, api_key=db_api_key)
|
84 |
+
db = Qdrant(
|
85 |
+
client=client, collection_name=db_collection_name, embeddings=EMBEDDINGS
|
86 |
+
)
|
87 |
+
retriever = db.as_retriever(
|
88 |
+
search_kwargs={
|
89 |
+
"filter": filter,
|
90 |
+
}
|
91 |
+
)
|
92 |
+
result = RetrievalQA.from_chain_type(
|
93 |
+
llm=LLM,
|
94 |
+
chain_type="stuff",
|
95 |
+
retriever=retriever,
|
96 |
+
return_source_documents=True,
|
97 |
+
)
|
98 |
+
return result
|
99 |
+
|
100 |
+
|
101 |
+
def _get_related_url(metadata) -> Iterable[str]:
|
102 |
+
urls = set()
|
103 |
+
for m in metadata:
|
104 |
+
url = m["url"]
|
105 |
+
if url in urls:
|
106 |
+
continue
|
107 |
+
urls.add(url)
|
108 |
+
created_at = datetime.fromtimestamp(m["created_at"])
|
109 |
+
# print(m)
|
110 |
+
yield f'<p>URL: <a href="{url}">{url}</a> (created: {created_at:%Y-%m-%d})</p>'
|
111 |
+
|
112 |
+
|
113 |
+
def _get_query_str_filter(
|
114 |
query: str,
|
115 |
repo_name: str,
|
116 |
query_options: str,
|
117 |
start_date: date,
|
118 |
end_date: date,
|
119 |
include_comments: bool,
|
120 |
+
) -> tuple[str, Filter]:
|
121 |
options = [{"key": "metadata.repo_name", "value": repo_name}]
|
122 |
if start_date is not None and end_date is not None:
|
123 |
options.append(
|
|
|
139 |
if query_options == "Empty":
|
140 |
query_options = ""
|
141 |
query_str = f"{query_options}{query}"
|
142 |
+
return query_str, filter
|
143 |
+
|
144 |
+
|
145 |
+
def run_qa(
|
146 |
+
query: str,
|
147 |
+
repo_name: str,
|
148 |
+
query_options: str,
|
149 |
+
start_date: date,
|
150 |
+
end_date: date,
|
151 |
+
include_comments: bool,
|
152 |
+
) -> tuple[str, str]:
|
153 |
+
now = time()
|
154 |
+
query_str, filter = _get_query_str_filter(
|
155 |
+
query, repo_name, query_options, start_date, end_date, include_comments
|
156 |
+
)
|
157 |
+
qa = get_retrieval_qa(filter)
|
158 |
+
try:
|
159 |
+
result = qa(query_str)
|
160 |
+
except InvalidRequestError as e:
|
161 |
+
return "ๅ็ญใ่ฆใคใใใพใใใงใใใๅฅใช่ณชๅใใใฆใฟใฆใใ ใใ", str(e)
|
162 |
+
else:
|
163 |
+
metadata = [s.metadata for s in result["source_documents"]]
|
164 |
+
sec_html = f"<p>ๅฎ่กๆ้: {(time() - now):.2f}็ง</p>"
|
165 |
+
html = "<div>" + sec_html + "\n".join(_get_related_url(metadata)) + "</div>"
|
166 |
+
return result["result"], html
|
167 |
+
|
168 |
+
|
169 |
+
def run_search(
|
170 |
+
query: str,
|
171 |
+
repo_name: str,
|
172 |
+
query_options: str,
|
173 |
+
start_date: date,
|
174 |
+
end_date: date,
|
175 |
+
include_comments: bool,
|
176 |
+
) -> Iterable[tuple[Issue, float, str]]:
|
177 |
+
query_str, filter = _get_query_str_filter(
|
178 |
+
query, repo_name, query_options, start_date, end_date, include_comments
|
179 |
+
)
|
180 |
docs = get_similay(query_str, filter)
|
181 |
for doc, score in docs:
|
182 |
text = doc.page_content
|
|
|
234 |
)
|
235 |
include_comments = st.checkbox(label="Include Issue comments", value=True)
|
236 |
|
237 |
+
submit_col1, submit_col2 = st.columns(2)
|
238 |
+
searched = submit_col1.form_submit_button("Search")
|
239 |
+
if searched:
|
240 |
st.divider()
|
241 |
st.header("Search Results")
|
242 |
st.divider()
|
243 |
with st.spinner("Searching..."):
|
244 |
+
results = run_search(
|
245 |
query, repo_name, query_options, start_date, end_date, include_comments
|
246 |
)
|
247 |
for issue, score, text in results:
|
|
|
264 |
st.write(f"{labels=}")
|
265 |
# st.markdown(html, unsafe_allow_html=True)
|
266 |
st.divider()
|
267 |
+
qa_searched = submit_col2.form_submit_button("QA Search by OpenAI")
|
268 |
+
if qa_searched:
|
269 |
+
st.divider()
|
270 |
+
st.header("QA Search Results by OpenAI GPT-3")
|
271 |
+
st.divider()
|
272 |
+
with st.spinner("QA Searching..."):
|
273 |
+
results = run_qa(
|
274 |
+
query, repo_name, query_options, start_date, end_date, include_comments
|
275 |
+
)
|
276 |
+
answer, html = results
|
277 |
+
with st.container():
|
278 |
+
st.write(answer)
|
279 |
+
st.markdown(html, unsafe_allow_html=True)
|
280 |
+
st.divider()
|
requirements.txt
CHANGED
@@ -8,3 +8,4 @@ bitsandbytes
|
|
8 |
sentence_transformers
|
9 |
streamlit
|
10 |
python-dateutil
|
|
|
|
8 |
sentence_transformers
|
9 |
streamlit
|
10 |
python-dateutil
|
11 |
+
openai
|