Spaces:
Running
Running
from functools import lru_cache | |
import duckdb | |
import gradio as gr | |
import polars as pl | |
from datasets import load_dataset | |
from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
from model2vec import StaticModel | |
# Identifier of the static embedding model on the Hugging Face Hub
# (in this case the potion-base-8M model).
model_name = "minishlab/potion-base-8M"
# Loaded once at import time; vectorize_dataset and run_query both encode
# text through this module-level instance.
model = StaticModel.from_pretrained(model_name)
def get_iframe(hub_repo_id):
    """Build the HTML snippet that embeds the Hub dataset viewer.

    Args:
        hub_repo_id: Dataset repository id, interpolated into the
            ``https://huggingface.co/datasets/<id>/embed/viewer`` URL.

    Returns:
        A string containing a single ``<iframe>`` element pointing at the
        embedded viewer.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or otherwise falsy.
    """
    from html import escape

    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    # The id comes from a free-text search box, so escape it before placing
    # it inside a quoted attribute; ordinary "user/dataset" ids are unchanged.
    safe_url = escape(url, quote=True)
    iframe = f"""
    <iframe
        src="{safe_url}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """
    return iframe
def load_dataset_from_hub(hub_repo_id: str):
    """Show a "Loading dataset..." notice and load the dataset once.

    The loaded dataset is deliberately discarded and nothing is returned;
    the later handlers call ``load_dataset`` again themselves.
    NOTE(review): presumably this call exists to populate the local
    ``datasets`` cache up front -- confirm.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.
    """
    gr.Info(message="Loading dataset...")
    load_dataset(hub_repo_id)
def get_columns(hub_repo_id: str, split: str):
    """Return a visible column dropdown for one split of a dataset.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.
        split: Key of the split whose column names are listed.

    Returns:
        A ``gr.Dropdown`` whose choices are the split's column names, with
        the first column pre-selected.
    """
    column_names = load_dataset(hub_repo_id)[split].column_names
    first_column = column_names[0]
    return gr.Dropdown(
        choices=column_names,
        value=first_column,
        label="Select a column",
        visible=True,
    )
def get_splits(hub_repo_id: str):
    """Return a visible split dropdown for a dataset.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.

    Returns:
        A ``gr.Dropdown`` whose choices are the keys of the loaded dataset,
        with the first key pre-selected.
    """
    split_names = [*load_dataset(hub_repo_id).keys()]
    default_split = split_names[0]
    return gr.Dropdown(
        choices=split_names,
        value=default_split,
        label="Select a split",
        visible=True,
    )
# Only the most recent (repo, split, column) result is kept: repeated searches
# over the same selection no longer re-embed the whole column, while memory
# stays bounded to a single set of embeddings.
@lru_cache(maxsize=1)
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
    """Embed every value of one dataset column with the module-level model.

    Results are memoised on the argument triple, so the "Vectorizing
    dataset..." notice is only shown when embeddings are actually computed.
    Callers receive the cached object itself and must not mutate it.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.
        split: Key of the split to read.
        column: Name of the column to embed; its values are cast to strings.

    Returns:
        Whatever ``model.encode`` returns for the column (one embedding per
        row); ``max_length=512`` is forwarded to the encoder.
    """
    gr.Info("Vectorizing dataset...")
    ds = load_dataset(hub_repo_id)
    df = ds[split].to_polars()
    embeddings = model.encode(df[column].cast(str), max_length=512)
    return embeddings
def run_query(hub_repo_id: str, query: str, split: str, column: str):
    """Return the five rows whose embeddings are closest to the query.

    The selected column is embedded via ``vectorize_dataset``, the result is
    attached to the split as an ``embeddings`` column, and DuckDB orders the
    rows by cosine distance to the encoded query.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.
        query: Free-text query encoded with the module-level model.
        split: Key of the split to search.
        column: Name of the column whose embeddings are compared.

    Returns:
        A visible ``gr.Dataframe`` holding the top five rows.

    Raises:
        gr.Error: If encoding the query or running the SQL fails; the
            original exception is attached as the cause.
    """
    embeddings = vectorize_dataset(hub_repo_id, split, column)
    ds = load_dataset(hub_repo_id)
    # This local must stay named `df`: the SQL below says `FROM df`, which
    # DuckDB resolves by looking up a Python variable of that name.
    df = ds[split].to_polars()
    df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
    try:
        vector = model.encode(query)
        values = vector.tolist()
        # Take the array width from the query vector instead of hard-coding
        # 256, so switching to a model of another size keeps the cast valid.
        dim = len(values)
        df_results = duckdb.sql(
            query=f"""
            SELECT *
            FROM df
            ORDER BY array_cosine_distance(embeddings, {values}::FLOAT[{dim}])
            LIMIT 5
            """
        ).to_df()
        return gr.Dataframe(df_results, visible=True)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error running query: {e}") from e
def hide_components():
    """Return updates that hide two dropdowns, a textbox, a button and a dataframe.

    The order of the returned list matches the order of the ``outputs`` it
    is wired to: dropdown, dropdown, textbox, button, dataframe.
    """
    hidden_dropdowns = [gr.Dropdown(visible=False) for _ in range(2)]
    return [
        *hidden_dropdowns,
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]
def partial_hide_components():
    """Return updates that hide a textbox, a button and a dataframe, in that order."""
    component_types = (gr.Textbox, gr.Button, gr.Dataframe)
    return [component(visible=False) for component in component_types]
def show_components():
    """Return updates that reveal the "Query" textbox and the "Search" button."""
    query_box = gr.Textbox(visible=True, label="Query")
    search_button = gr.Button(visible=True, value="Search")
    return [query_box, search_button]
# UI definition: a Hub dataset search box, the embedded dataset viewer, split
# and column pickers, and a query box whose results are shown in a dataframe.
# NOTE(review): the source this was taken from had lost its indentation, so
# the Row/Column nesting below is a reconstruction -- confirm the layout.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
            This app allows you to vector search any Hugging Face dataset.
            You can search for the nearest neighbors of a query vector, or
            perform a similarity search on a dataframe.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                # The component searches datasets (search_type below), so the
                # placeholder says "datasets" rather than "models".
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # Spelling kept as-is. NOTE(review): this appears to be the
                # keyword the component really exposes -- confirm before
                # renaming it.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)
        btn_run = gr.Button("Search", visible=False)
    results_output = gr.Dataframe(label="Results", visible=False)

    # Selecting a dataset: show the viewer, load the dataset, hide every
    # downstream control, then repopulate the split and column pickers.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    # Changing the split refreshes the list of columns.
    split_dropdown.change(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    # Changing the column hides the query box, button and old results, then
    # shows the query box and button again.
    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=show_components, outputs=[query_input, btn_run])

    # Clicking "Search" runs the vector search and fills the results table.
    btn_run.click(
        fn=run_query,
        inputs=[search_in, query_input, split_dropdown, column_dropdown],
        outputs=results_output,
    )

demo.launch()