davidberenstein1957's picture
fix: avoid global usage
b4d283f
from functools import lru_cache
from urllib.parse import quote

import duckdb
import gradio as gr
import polars as pl
from datasets import load_dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from model2vec import StaticModel
# Load a static embedding model from the HuggingFace hub (potion-base-8M).
# It is loaded once at import time and reused by the vectorisation and query
# handlers. (The former module-level `global df` was a no-op and has been
# dropped: each handler builds its own local dataframe.)
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)
def get_iframe(hub_repo_id):
    """Return an HTML snippet embedding the Hub dataset viewer for *hub_repo_id*.

    The repo id comes straight from a user-editable search box, so it is
    percent-encoded (keeping ``/`` as the owner/name separator) before being
    placed inside the ``src`` attribute. Valid repo ids contain only URL-safe
    characters and are therefore unchanged.

    Raises:
        ValueError: if *hub_repo_id* is empty or ``None``.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    # Encoding neutralises quotes/angle brackets that could break out of the
    # attribute and inject markup into the page.
    safe_repo_id = quote(hub_repo_id, safe="/")
    url = f"https://huggingface.co/datasets/{safe_repo_id}/embed/viewer"
    iframe = f"""
    <iframe
        src="{url}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """
    return iframe
def load_dataset_from_hub(hub_repo_id: str):
    """Show a "loading" toast and load *hub_repo_id* once, discarding the result.

    Used as a step in the UI event chain; the loaded dataset object is not
    returned (the function returns ``None``). NOTE(review): presumably relies
    on the ``datasets`` on-disk cache so later handlers that call
    ``load_dataset`` again are fast — confirm.
    """
    gr.Info(message="Loading dataset...")
    load_dataset(hub_repo_id)
def get_columns(hub_repo_id: str, split: str):
    """Build a visible column-picker dropdown for *split* of *hub_repo_id*.

    Choices are the split's column names; the first one is preselected.
    """
    columns = load_dataset(hub_repo_id)[split].column_names
    first_column = columns[0]
    return gr.Dropdown(
        label="Select a column",
        choices=columns,
        value=first_column,
        visible=True,
    )
def get_splits(hub_repo_id: str):
    """Build a visible split-picker dropdown listing every split of *hub_repo_id*.

    The first split is preselected.
    """
    available = [split_name for split_name in load_dataset(hub_repo_id)]
    return gr.Dropdown(
        label="Select a split",
        choices=available,
        value=available[0],
        visible=True,
    )
@lru_cache
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
    """Embed every value of *column* in *split* of *hub_repo_id* with ``model``.

    Results are memoised per ``(hub_repo_id, split, column)`` by ``lru_cache``
    (default ``maxsize=128``), so the "Vectorizing dataset..." toast and the
    encoding work only happen on a cache miss. NOTE(review): every cache entry
    holds a full embedding matrix; consider a smaller ``maxsize`` if memory on
    the host is tight.

    Returns:
        Whatever ``model.encode`` returns for the column's values, one
        embedding per row (used as a per-row column by the caller).
    """
    gr.Info("Vectorizing dataset...")
    ds = load_dataset(hub_repo_id)
    df = ds[split].to_polars()
    # Casting keeps nulls as null; fill them with "" so the tokenizer never
    # receives None, and hand encode a plain list[str].
    texts = df[column].cast(pl.Utf8).fill_null("").to_list()
    embeddings = model.encode(texts, max_length=512)
    return embeddings
def run_query(hub_repo_id: str, query: str, split: str, column: str):
    """Return a visible results table with the 5 rows nearest to *query*.

    Rows of *split* are ranked by cosine distance between the query embedding
    and the cached embedding of each row's *column* value.

    Raises:
        gr.Error: if the query is blank, or if encoding/searching fails.
    """
    # Fail fast on an empty query: it carries no signal and would otherwise
    # trigger (possibly expensive) vectorisation for a meaningless ranking.
    if not (query and query.strip()):
        raise gr.Error("Query is required")
    embeddings = vectorize_dataset(hub_repo_id, split, column)
    ds = load_dataset(hub_repo_id)
    df = ds[split].to_polars()
    df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
    try:
        vector = model.encode(query)
        # Derive the embedding width from the model instead of hard-coding it,
        # so switching `model_name` to a model of another size keeps working.
        dim = int(vector.shape[-1])
        # `FROM df` is resolved by DuckDB's replacement scan against the local
        # polars dataframe named `df` in this function.
        df_results = duckdb.sql(
            query=f"""
            SELECT *
            FROM df
            ORDER BY array_cosine_distance(
                embeddings::FLOAT[{dim}], {vector.tolist()}::FLOAT[{dim}]
            )
            LIMIT 5
            """
        ).to_df()
        return gr.Dataframe(df_results, visible=True)
    except Exception as e:
        raise gr.Error(f"Error running query: {e}") from e
def hide_components():
    """Return updates hiding, in order: split dropdown, column dropdown,
    query textbox, search button and results table."""
    hidden = {"visible": False}
    split_picker = gr.Dropdown(**hidden)
    column_picker = gr.Dropdown(**hidden)
    query_box = gr.Textbox(**hidden)
    search_button = gr.Button(**hidden)
    results_table = gr.Dataframe(**hidden)
    return [split_picker, column_picker, query_box, search_button, results_table]
def partial_hide_components():
    """Return updates hiding, in order: query textbox, search button and
    results table (the split/column pickers are left alone)."""
    hidden = {"visible": False}
    query_box = gr.Textbox(**hidden)
    search_button = gr.Button(**hidden)
    results_table = gr.Dataframe(**hidden)
    return [query_box, search_button, results_table]
def show_components():
    """Return updates revealing, in order: the "Query" textbox and the
    "Search" button."""
    query_box = gr.Textbox(label="Query", visible=True)
    search_button = gr.Button(value="Search", visible=True)
    return [query_box, search_button]
# Build the UI: pick a Hub dataset, preview it, choose split/column, then run
# a nearest-neighbour text search over the chosen column.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
        This app allows you to vector search any Hugging Face dataset.
        You can search for the nearest neighbors of a query vector, or
        perform a similarity search on a dataframe.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                # search_type is "dataset", so the prompt must say datasets.
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): "sumbit" appears to be the keyword's spelling in
                # gradio_huggingfacehub_search itself — verify against the
                # installed version before "correcting" it.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)
        btn_run = gr.Button("Search", visible=False)
    results_output = gr.Dataframe(label="Results", visible=False)

    # Selecting a dataset: preview it, load it, reset the downstream widgets,
    # then repopulate the split and column pickers.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
        # Explicit refresh: if the new split name equals the previous one the
        # split dropdown's change event below would not fire.
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )
    split_dropdown.change(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )
    # A new column invalidates previous results; re-show only the query inputs.
    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=show_components, outputs=[query_input, btn_run])
    btn_run.click(
        fn=run_query,
        inputs=[search_in, query_input, split_dropdown, column_dropdown],
        outputs=results_output,
    )
demo.launch()