Commit
•
b4d283f
1
Parent(s):
41b224c
fix: avoid global usage
Browse files
app.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import duckdb
|
2 |
import gradio as gr
|
3 |
import polars as pl
|
@@ -5,7 +7,6 @@ from datasets import load_dataset
|
|
5 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
6 |
from model2vec import StaticModel
|
7 |
|
8 |
-
global ds
|
9 |
global df
|
10 |
|
11 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
@@ -28,14 +29,13 @@ def get_iframe(hub_repo_id):
|
|
28 |
return iframe
|
29 |
|
30 |
|
31 |
-
def load_dataset_from_hub(hub_repo_id):
|
32 |
-
gr.Info("Loading dataset...")
|
33 |
-
global ds
|
34 |
ds = load_dataset(hub_repo_id)
|
35 |
|
36 |
|
37 |
-
def get_columns(split: str):
|
38 |
-
|
39 |
ds_split = ds[split]
|
40 |
return gr.Dropdown(
|
41 |
choices=ds_split.column_names,
|
@@ -45,33 +45,35 @@ def get_columns(split: str):
|
|
45 |
)
|
46 |
|
47 |
|
48 |
-
def get_splits():
|
49 |
-
|
50 |
splits = list(ds.keys())
|
51 |
return gr.Dropdown(
|
52 |
choices=splits, value=splits[0], label="Select a split", visible=True
|
53 |
)
|
54 |
|
55 |
|
56 |
-
|
|
|
57 |
gr.Info("Vectorizing dataset...")
|
58 |
-
|
59 |
-
global ds
|
60 |
df = ds[split].to_polars()
|
61 |
embeddings = model.encode(df[column].cast(str), max_length=512)
|
62 |
-
|
63 |
|
64 |
|
65 |
-
def run_query(query: str, column: str):
|
|
|
|
|
|
|
|
|
66 |
try:
|
67 |
-
global df
|
68 |
-
|
69 |
vector = model.encode(query)
|
70 |
df_results = duckdb.sql(
|
71 |
query=f"""
|
72 |
SELECT *
|
73 |
FROM df
|
74 |
-
ORDER BY array_cosine_distance(
|
75 |
LIMIT 5
|
76 |
"""
|
77 |
).to_df()
|
@@ -134,6 +136,7 @@ with gr.Blocks() as demo:
|
|
134 |
query_input = gr.Textbox(label="Query", visible=False)
|
135 |
|
136 |
btn_run = gr.Button("Search", visible=False)
|
|
|
137 |
results_output = gr.Dataframe(label="Results", visible=False)
|
138 |
|
139 |
search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
|
@@ -143,23 +146,23 @@ with gr.Blocks() as demo:
|
|
143 |
).then(
|
144 |
fn=hide_components,
|
145 |
outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
|
146 |
-
).then(fn=get_splits, outputs=split_dropdown).then(
|
147 |
-
fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
|
148 |
)
|
149 |
|
150 |
split_dropdown.change(
|
151 |
-
fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
|
152 |
)
|
153 |
|
154 |
column_dropdown.change(
|
155 |
fn=partial_hide_components,
|
156 |
outputs=[query_input, btn_run, results_output],
|
157 |
-
).then(fn=
|
158 |
-
fn=show_components, outputs=[query_input, btn_run]
|
159 |
-
)
|
160 |
|
161 |
btn_run.click(
|
162 |
-
fn=run_query,
|
|
|
|
|
163 |
)
|
164 |
|
165 |
demo.launch()
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
|
3 |
import duckdb
|
4 |
import gradio as gr
|
5 |
import polars as pl
|
|
|
7 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
8 |
from model2vec import StaticModel
|
9 |
|
|
|
10 |
global df
|
11 |
|
12 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
|
|
29 |
return iframe
|
30 |
|
31 |
|
32 |
+
def load_dataset_from_hub(hub_repo_id: str):
|
33 |
+
gr.Info(message="Loading dataset...")
|
|
|
34 |
ds = load_dataset(hub_repo_id)
|
35 |
|
36 |
|
37 |
+
def get_columns(hub_repo_id: str, split: str):
|
38 |
+
ds = load_dataset(hub_repo_id)
|
39 |
ds_split = ds[split]
|
40 |
return gr.Dropdown(
|
41 |
choices=ds_split.column_names,
|
|
|
45 |
)
|
46 |
|
47 |
|
48 |
+
def get_splits(hub_repo_id: str):
|
49 |
+
ds = load_dataset(hub_repo_id)
|
50 |
splits = list(ds.keys())
|
51 |
return gr.Dropdown(
|
52 |
choices=splits, value=splits[0], label="Select a split", visible=True
|
53 |
)
|
54 |
|
55 |
|
56 |
+
@lru_cache
|
57 |
+
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
|
58 |
gr.Info("Vectorizing dataset...")
|
59 |
+
ds = load_dataset(hub_repo_id)
|
|
|
60 |
df = ds[split].to_polars()
|
61 |
embeddings = model.encode(df[column].cast(str), max_length=512)
|
62 |
+
return embeddings
|
63 |
|
64 |
|
65 |
+
def run_query(hub_repo_id: str, query: str, split: str, column: str):
|
66 |
+
embeddings = vectorize_dataset(hub_repo_id, split, column)
|
67 |
+
ds = load_dataset(hub_repo_id)
|
68 |
+
df = ds[split].to_polars()
|
69 |
+
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
70 |
try:
|
|
|
|
|
71 |
vector = model.encode(query)
|
72 |
df_results = duckdb.sql(
|
73 |
query=f"""
|
74 |
SELECT *
|
75 |
FROM df
|
76 |
+
ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
|
77 |
LIMIT 5
|
78 |
"""
|
79 |
).to_df()
|
|
|
136 |
query_input = gr.Textbox(label="Query", visible=False)
|
137 |
|
138 |
btn_run = gr.Button("Search", visible=False)
|
139 |
+
|
140 |
results_output = gr.Dataframe(label="Results", visible=False)
|
141 |
|
142 |
search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
|
|
|
146 |
).then(
|
147 |
fn=hide_components,
|
148 |
outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
|
149 |
+
).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
|
150 |
+
fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
|
151 |
)
|
152 |
|
153 |
split_dropdown.change(
|
154 |
+
fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
|
155 |
)
|
156 |
|
157 |
column_dropdown.change(
|
158 |
fn=partial_hide_components,
|
159 |
outputs=[query_input, btn_run, results_output],
|
160 |
+
).then(fn=show_components, outputs=[query_input, btn_run])
|
|
|
|
|
161 |
|
162 |
btn_run.click(
|
163 |
+
fn=run_query,
|
164 |
+
inputs=[search_in, query_input, split_dropdown, column_dropdown],
|
165 |
+
outputs=results_output,
|
166 |
)
|
167 |
|
168 |
demo.launch()
|