fizban99
committed on
Commit
•
eaee63c
1
Parent(s):
2da89ac
reranking added
Browse files
- .gitignore +2 -0
- app.py +10 -2
- simiandb.py +2 -2
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
*.pyc
|
app.py
CHANGED
@@ -7,17 +7,25 @@ Created on Wed Mar 22 19:59:54 2023
|
|
7 |
import gradio as gr
|
8 |
from simiandb import Simiandb
|
9 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
10 |
|
11 |
|
12 |
|
13 |
|
14 |
model_name = "all-MiniLM-L6-v2"
|
15 |
hf = HuggingFaceEmbeddings(model_name=model_name)
|
|
|
16 |
|
17 |
documentdb = Simiandb("mystore", embedding_function=hf, mode="a")
|
18 |
|
19 |
def search(query):
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
|
22 |
iface = gr.Interface(fn=search, inputs="text", outputs="text")
|
23 |
-
iface.launch()
|
|
|
|
|
|
7 |
import gradio as gr
|
8 |
from simiandb import Simiandb
|
9 |
from langchain.embeddings import HuggingFaceEmbeddings
|
10 |
+
from sentence_transformers import CrossEncoder
|
11 |
|
12 |
|
13 |
|
14 |
|
15 |
model_name = "all-MiniLM-L6-v2"
|
16 |
hf = HuggingFaceEmbeddings(model_name=model_name)
|
17 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
18 |
|
19 |
documentdb = Simiandb("mystore", embedding_function=hf, mode="a")
|
20 |
|
21 |
def search(query):
|
22 |
+
hits = documentdb.similarity_search(query)
|
23 |
+
cross_inp = [[query, hit] for hit in hits]
|
24 |
+
cross_scores = cross_encoder.predict(cross_inp)
|
25 |
+
hits = [hit for _, hit in sorted(zip(cross_scores, hits), reverse=True)]
|
26 |
+
return hits[0]
|
27 |
|
28 |
iface = gr.Interface(fn=search, inputs="text", outputs="text")
|
29 |
+
iface.launch()
|
30 |
+
|
31 |
+
#print(search("what is the balloon boy hoax"))
|
simiandb.py
CHANGED
@@ -178,7 +178,7 @@ class Simiandb():
|
|
178 |
batch = self._vector_table.chunkshape[0]*25
|
179 |
res = np.ascontiguousarray(np.empty(shape=(count,), dtype="float32"))
|
180 |
end = 0
|
181 |
-
|
182 |
while end!=count:
|
183 |
end += batch
|
184 |
end = end if end <= count else count
|
@@ -189,7 +189,7 @@ class Simiandb():
|
|
189 |
|
190 |
indices = np.argpartition(res, -k)[-k:] #from https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
|
191 |
indices = indices[np.argsort(res[indices])[::-1]]
|
192 |
-
|
193 |
return indices
|
194 |
|
195 |
|
|
|
178 |
batch = self._vector_table.chunkshape[0]*25
|
179 |
res = np.ascontiguousarray(np.empty(shape=(count,), dtype="float32"))
|
180 |
end = 0
|
181 |
+
|
182 |
while end!=count:
|
183 |
end += batch
|
184 |
end = end if end <= count else count
|
|
|
189 |
|
190 |
indices = np.argpartition(res, -k)[-k:] #from https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
|
191 |
indices = indices[np.argsort(res[indices])[::-1]]
|
192 |
+
|
193 |
return indices
|
194 |
|
195 |
|