Spaces:
Running
on
Zero
Running
on
Zero
Update main.py
Browse files
main.py
CHANGED
@@ -1,56 +1,78 @@
|
|
1 |
-
import json
|
2 |
|
3 |
-
import gradio
|
4 |
-
import datasets
|
5 |
|
6 |
-
import numpy
|
7 |
-
import pandas
|
8 |
-
import sentence_transformers
|
9 |
-
import faiss
|
10 |
|
11 |
-
model
|
|
|
12 |
|
13 |
-
|
|
|
14 |
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
data = full_data[~pandas.Series(filter)]
|
17 |
data.reset_index(inplace=True)
|
18 |
|
19 |
-
|
20 |
-
index = faiss.IndexFlatL2(
|
21 |
index.metric_type = faiss.METRIC_INNER_PRODUCT
|
|
|
22 |
|
23 |
-
vectors = numpy.stack(data['embedding'].to_list(), axis=0)
|
24 |
-
|
25 |
-
index.add(vectors)
|
26 |
|
27 |
-
|
|
|
28 |
query = numpy.expand_dims(model.encode(query), axis=0)
|
29 |
_, I = index.search(query, k)
|
30 |
top_five = data.loc[I[0]]
|
31 |
search_results = ""
|
32 |
|
33 |
for i in range(k):
|
34 |
-
search_results +=
|
35 |
if top_five["pub_url"].values[i] is not None:
|
36 |
search_results += "[Full Text](" + top_five["pub_url"].values[i] + ") "
|
37 |
if top_five["citedby_url"].values[i] is not None:
|
38 |
-
search_results +=
|
|
|
|
|
39 |
if top_five["url_related_articles"].values[i] is not None:
|
40 |
-
search_results +=
|
|
|
|
|
|
|
|
|
|
|
41 |
search_results += "\n\n```bibtex\n"
|
42 |
-
search_results +=
|
|
|
|
|
|
|
|
|
|
|
43 |
search_results += "```\n"
|
44 |
return search_results
|
45 |
|
46 |
|
47 |
with gradio.Blocks() as demo:
|
48 |
with gradio.Group():
|
49 |
-
query = gradio.Textbox(
|
|
|
|
|
50 |
with gradio.Accordion("Settings", open=False):
|
51 |
k = gradio.Number(5.0, label="Number of results", precision=0)
|
52 |
results = gradio.Markdown()
|
53 |
query.change(fn=search, inputs=[query, k], outputs=results)
|
54 |
k.change(fn=search, inputs=[query, k], outputs=results)
|
55 |
|
56 |
-
demo.launch(debug=True)
|
|
|
1 |
+
import json # For stringifying a dict
|
2 |
|
3 |
+
import gradio # GUI framework
|
4 |
+
import datasets # Used to load publication dataset
|
5 |
|
6 |
+
import numpy # For a few simple matrix operations
|
7 |
+
import pandas # Needed for operating on dataset
|
8 |
+
import sentence_transformers # Needed for query embedding
|
9 |
+
import faiss # Needed for fast similarity search
|
10 |
|
11 |
+
# Load the sentence-embedding model once at import time; `search` reuses it
# for every query.
model = sentence_transformers.SentenceTransformer("allenai-specter")

# Load the publication dataset and convert its "train" split to a DataFrame.
full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()

# Base URL for Google Scholar. It is prepended to the "citedby_url" and
# "url_related_articles" values when links are rendered.
SCHOLAR_URL = "https://scholar.google.com"

# Keep only publications that have an abstract. The abstract is read directly
# from each "bib_dict" entry instead of serialising the dict to JSON and
# substring-matching, and the mask gets a descriptive name so the builtin
# `filter` is no longer shadowed.
# NOTE(review): assumes every "bib_dict" value is a dict-like with a
# top-level "abstract" key — confirm against the dataset schema.
has_abstract = numpy.array(
    [bibdict.get("abstract") is not None for bibdict in full_data["bib_dict"].values],
    dtype=bool,
)

# Renumber the surviving rows 0..n-1 so that row positions line up with the
# ids FAISS assigns on `index.add`. The previous index is kept as an "index"
# column, exactly as `reset_index(inplace=True)` did.
data = full_data[has_abstract].reset_index()

# Create a FAISS index for fast similarity search. IndexFlatIP is the exact
# (brute-force) inner-product index; constructing it directly replaces the
# former IndexFlatL2 + metric_type override.
embedding_dim = len(data["embedding"].iloc[0])
index = faiss.IndexFlatIP(embedding_dim)
index.add(numpy.stack(data["embedding"].to_list(), axis=0))
|
32 |
|
|
|
|
|
|
|
33 |
|
34 |
+
# Define the search function
|
35 |
+
def _format_result(row) -> str:
    """Render one publication row as a Markdown snippet.

    The snippet holds a title heading, whichever links are available, and
    the BibTeX entry in a fenced code block.

    Args:
        row: A row of the publications table, indexable by column name.
            It is read for "bib_dict" (with a "title" entry), "pub_url",
            "citedby_url", "url_related_articles" and "bibtex".

    Returns:
        The Markdown text for this single result.
    """
    parts = ["### " + row["bib_dict"]["title"] + "\n\n"]

    # Links are optional; pandas.notna treats both None and NaN as missing.
    if pandas.notna(row["pub_url"]):
        parts.append("[Full Text](" + row["pub_url"] + ") ")
    if pandas.notna(row["citedby_url"]):
        parts.append("[Cited By](" + SCHOLAR_URL + row["citedby_url"] + ") ")
    if pandas.notna(row["url_related_articles"]):
        parts.append(
            "[Related Articles](" + SCHOLAR_URL + row["url_related_articles"] + ") "
        )

    # Emit the BibTeX text verbatim. Round-tripping it through json.dumps and
    # then un-escaping "\\n"/"\\t" corrupted LaTeX commands such as \nu or
    # \textbf and left \" and \uXXXX escapes in the output.
    # NOTE(review): assumes "bibtex" holds a plain string — confirm.
    bibtex = row["bibtex"]
    bibtex_text = "" if pandas.isna(bibtex) else str(bibtex).rstrip()
    parts.append("\n\n```bibtex\n")
    parts.append(bibtex_text + "\n")  # keep the closing fence on its own line
    parts.append("```\n")
    return "".join(parts)


def search(query: str, k: int) -> str:
    """Return the ``k`` publications most similar to ``query`` as Markdown.

    The query is embedded with the module-level ``model`` and looked up in
    the module-level FAISS ``index``; matching rows are read from ``data``.

    Args:
        query: Free-text search terms.
        k: Number of results requested. It is capped at the number of
            indexed publications.

    Returns:
        Markdown with one section per hit, or an empty string when the
        query is blank or ``k`` is missing or smaller than one.
    """
    # The UI fires on every change, so blank text or a cleared number field
    # can reach this function; neither is a meaningful search.
    if k is None or not query or not query.strip():
        return ""
    k = min(int(k), index.ntotal)
    if k < 1:
        return ""

    # index.search is given a 2-D batch, so add a leading batch axis.
    query_vector = numpy.expand_dims(model.encode(query), axis=0)
    _, neighbor_ids = index.search(query_vector, k)

    # FAISS ids are row positions in `data`; -1 marks a missing neighbour.
    return "".join(
        _format_result(data.iloc[int(row_id)])
        for row_id in neighbor_ids[0]
        if row_id >= 0
    )
|
65 |
|
66 |
|
67 |
# Build the user interface: a search box, a collapsible settings panel and a
# Markdown area that shows the formatted results.
# NOTE(review): the original indentation was not recoverable from this view;
# the nesting below (results and event wiring directly under Blocks) is the
# most plausible reading — confirm against the real file.
with gradio.Blocks() as demo:
    with gradio.Group():
        query = gradio.Textbox(
            placeholder="Enter search terms...",
            show_label=False,
            lines=1,
            max_lines=1,
        )
        with gradio.Accordion("Settings", open=False):
            k = gradio.Number(5.0, label="Number of results", precision=0)

    results = gradio.Markdown()

    # Re-run the search whenever either input changes.
    for component in (query, k):
        component.change(fn=search, inputs=[query, k], outputs=results)

# Start the app.
demo.launch(debug=True)
|