Spaces:
Running
Running
import html
from urllib.parse import quote

import duckdb
import gradio as gr
import polars as pl
from datasets import load_dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from model2vec import StaticModel
# Module-level state shared by the Gradio callbacks below:
#   ds - the DatasetDict loaded by load_dataset_from_hub()
#   df - the polars frame with an "embeddings" column built by vectorize_dataset()
# A `global` statement at module scope is a no-op and would leave these names
# undefined, so initialise them explicitly to None ("nothing loaded yet").
ds = None
df = None

# Load the static embedding model from the Hugging Face Hub once at start-up
# (here the multilingual M2V_multilingual_output model).
model_name = "minishlab/M2V_multilingual_output"
model = StaticModel.from_pretrained(model_name)
def get_iframe(hub_repo_id):
    """Build an HTML snippet that embeds the Hub dataset viewer for a repo.

    Args:
        hub_repo_id: Dataset repository id as typed/selected in the search box,
            e.g. ``"user/dataset"``.

    Returns:
        An ``<iframe>`` HTML string pointing at
        ``https://huggingface.co/datasets/<repo>/embed/viewer``.

    Raises:
        ValueError: If ``hub_repo_id`` is empty, ``None`` or only whitespace.
    """
    repo_id = (hub_repo_id or "").strip()
    if not repo_id:
        raise ValueError("Hub repo id is required")
    # The id comes from a free-text input, so percent-encode it for the URL
    # path (keeping "/" between owner and name) and HTML-escape the result
    # before placing it inside an attribute; this prevents markup injection.
    url = f"https://huggingface.co/datasets/{quote(repo_id, safe='/')}/embed/viewer"
    src = html.escape(url, quote=True)
    return f"""
<iframe
    src="{src}"
    frameborder="0"
    width="100%"
    height="600px"
></iframe>
"""
def load_dataset_from_hub(hub_repo_id):
    """Load every split of a Hub dataset into the module-level ``ds``.

    Any embeddings previously built by ``vectorize_dataset`` belong to the old
    dataset, so the module-level ``df`` is cleared once the new dataset loads.

    Args:
        hub_repo_id: Repository id of the dataset on the Hugging Face Hub.

    Raises:
        ValueError: If ``hub_repo_id`` is empty.
    """
    global ds, df
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    # Load into a local first so a failed download leaves the previous
    # dataset (and its embeddings) untouched.
    loaded = load_dataset(hub_repo_id)
    ds = loaded
    # Stale embeddings would let run_query silently search the old dataset.
    df = None
def get_columns(split: str):
    """Build the column selector for one split of the loaded dataset.

    Reads the module-level ``ds``.

    Args:
        split: Name of a split in ``ds``.

    Returns:
        A ``gr.Dropdown`` listing the split's column names with the first
        column preselected.
    """
    column_names = ds[split].column_names
    return gr.Dropdown(
        label="Select a column",
        choices=column_names,
        value=column_names[0],
    )
def get_splits():
    """Build the split selector for the currently loaded dataset.

    Reads the module-level ``ds`` and preselects its first split.

    Returns:
        A ``gr.Dropdown`` listing every split name of ``ds``.
    """
    split_names = list(ds.keys())
    default_split = split_names[0]
    return gr.Dropdown(
        label="Select a split",
        choices=split_names,
        value=default_split,
    )
def vectorize_dataset(split: str, column: str):
    """Embed one column of a loaded split and publish it as the searchable ``df``.

    The chosen split of the module-level ``ds`` is converted to polars. The
    text of ``column`` is encoded with ``model``, and the vectors are added as
    an ``embeddings`` column. The result is stored in the module-level ``df``
    used by ``run_query``.

    Args:
        split: Name of a split in ``ds``.
        column: Name of the column whose values are embedded.
    """
    global df
    frame = ds[split].to_polars()
    # The model encodes strings: cast non-text columns to string and replace
    # nulls with "" so every row is encodable and the row count is preserved.
    texts = frame[column].cast(pl.Utf8).fill_null("").to_list()
    embeddings = model.encode(texts)
    # Assign the global only after encoding succeeds, so a failure never
    # leaves df without an "embeddings" column.
    df = frame.with_columns(pl.Series("embeddings", embeddings))
def run_query(query: str, k: int = 5):
    """Return the ``k`` rows of the vectorized dataset closest to ``query``.

    Args:
        query: Free-text query, embedded with the same ``model`` as the data.
        k: Number of nearest rows to return (defaults to 5).

    Returns:
        The DataFrame produced by DuckDB's ``to_df()``. It holds the ``k`` rows
        of the module-level ``df``, ordered by ascending ``array_distance``
        between their ``embeddings`` and the query vector.

    Raises:
        ValueError: If no dataset has been vectorized yet.
    """
    # globals().get avoids a NameError when df was never assigned.
    embedded = globals().get("df")
    if embedded is None:
        raise ValueError("Load and vectorize a dataset before running a query")
    vector = model.encode(query)
    # Take the fixed array size from the model output rather than hard-coding
    # 256, so the cast matches the embeddings column for any model dimension.
    dim = len(vector)
    # DuckDB resolves `embedded` to the local polars frame via replacement scan.
    # The vector is model output (floats), not user text, so inlining it is safe.
    return duckdb.sql(
        query=f"""
        SELECT *
        FROM embedded
        ORDER BY array_distance(embeddings, {vector.tolist()}::FLOAT[{dim}])
        LIMIT {int(k)}
        """
    ).to_df()
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
        This app allows you to vector search any Hugging Face dataset.
        You can search for the nearest neighbors of a query vector, or
        perform a similarity search on a dataframe.
        </p>
        <p>
        This app uses the <a href="https://huggingface.co/minishlab/M2V_multilingual_output">M2V_multilingual_output</a> model from the Hugging Face Hub.
        </p>
        """
    )

    # Step 1: pick a dataset on the Hub and preview it in the embedded viewer.
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                # The component searches datasets (see search_type), not models.
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): keyword kept exactly as in the original call;
                # believed to be the HuggingfaceHubSearch parameter's own
                # spelling - confirm against the component before renaming.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out)

    # Step 2: load the dataset, choose a split and column, then embed it.
    btn_load_dataset = gr.Button("Load Dataset")
    with gr.Row(variant="panel"):
        split_dropdown = gr.Dropdown(label="Select a split")
        column_dropdown = gr.Dropdown(label="Select a column")
        btn_vectorize_dataset = gr.Button("Vectorize")
    # After loading, refresh the split choices, then the column choices.
    # The explicit get_columns step presumably covers the case where the new
    # first split has the same name as before (no value change) - confirm.
    btn_load_dataset.click(
        load_dataset_from_hub, inputs=search_in, show_progress=True
    ).then(fn=get_splits, outputs=split_dropdown).then(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )
    split_dropdown.change(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )
    btn_vectorize_dataset.click(
        fn=vectorize_dataset,
        inputs=[split_dropdown, column_dropdown],
        show_progress=True,
    )

    # Step 3: run a nearest-neighbour query against the embedded column.
    with gr.Row(variant="panel"):
        query_input = gr.Textbox(label="Query")
        btn_run = gr.Button("Run")
    results_output = gr.Dataframe(label="Results")
    btn_run.click(fn=run_query, inputs=query_input, outputs=results_output)

demo.launch()