Commit
•
5f76c1a
1
Parent(s):
fbdb332
feat: use SOTA model
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ global ds
|
|
9 |
global df
|
10 |
|
11 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
12 |
-
model_name = "minishlab/
|
13 |
model = StaticModel.from_pretrained(model_name)
|
14 |
|
15 |
|
@@ -53,7 +53,7 @@ def vectorize_dataset(split: str, column: str):
|
|
53 |
global df
|
54 |
global ds
|
55 |
df = ds[split].to_polars()
|
56 |
-
embeddings = model.encode(df[column])
|
57 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
58 |
|
59 |
|
@@ -64,7 +64,7 @@ def run_query(query: str):
|
|
64 |
query=f"""
|
65 |
SELECT *
|
66 |
FROM df
|
67 |
-
ORDER BY
|
68 |
LIMIT 5
|
69 |
"""
|
70 |
).to_df()
|
@@ -91,18 +91,16 @@ with gr.Blocks() as demo:
|
|
91 |
)
|
92 |
with gr.Row():
|
93 |
search_out = gr.HTML(label="Search Results")
|
94 |
-
search_in.submit(get_iframe, inputs=search_in, outputs=search_out)
|
95 |
-
|
96 |
-
btn_load_dataset = gr.Button("Load Dataset")
|
97 |
|
98 |
with gr.Row(variant="panel"):
|
99 |
split_dropdown = gr.Dropdown(label="Select a split")
|
100 |
column_dropdown = gr.Dropdown(label="Select a column")
|
101 |
with gr.Row(variant="panel"):
|
102 |
query_input = gr.Textbox(label="Query")
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
106 |
).then(fn=get_splits, outputs=split_dropdown).then(
|
107 |
fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
|
108 |
)
|
|
|
9 |
global df
|
10 |
|
11 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
12 |
+
model_name = "minishlab/potion-base-8M"
|
13 |
model = StaticModel.from_pretrained(model_name)
|
14 |
|
15 |
|
|
|
53 |
global df
|
54 |
global ds
|
55 |
df = ds[split].to_polars()
|
56 |
+
embeddings = model.encode(df[column], max_length=512 * 4)
|
57 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
58 |
|
59 |
|
|
|
64 |
query=f"""
|
65 |
SELECT *
|
66 |
FROM df
|
67 |
+
ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
|
68 |
LIMIT 5
|
69 |
"""
|
70 |
).to_df()
|
|
|
91 |
)
|
92 |
with gr.Row():
|
93 |
search_out = gr.HTML(label="Search Results")
|
|
|
|
|
|
|
94 |
|
95 |
with gr.Row(variant="panel"):
|
96 |
split_dropdown = gr.Dropdown(label="Select a split")
|
97 |
column_dropdown = gr.Dropdown(label="Select a column")
|
98 |
with gr.Row(variant="panel"):
|
99 |
query_input = gr.Textbox(label="Query")
|
100 |
+
search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
|
101 |
+
fn=load_dataset_from_hub,
|
102 |
+
inputs=search_in,
|
103 |
+
show_progress=True,
|
104 |
).then(fn=get_splits, outputs=split_dropdown).then(
|
105 |
fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
|
106 |
)
|
demo.py
CHANGED
@@ -4,20 +4,20 @@ from datasets import load_dataset
|
|
4 |
from model2vec import StaticModel
|
5 |
|
6 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
7 |
-
model_name = "minishlab/
|
8 |
model = StaticModel.from_pretrained(model_name)
|
9 |
|
10 |
# Make embeddings
|
11 |
ds = load_dataset("fka/awesome-chatgpt-prompts")
|
12 |
df = ds["train"].to_polars()
|
13 |
-
embeddings = model.encode(df["
|
14 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
15 |
-
vector = model.encode("
|
16 |
duckdb.sql(
|
17 |
query=f"""
|
18 |
SELECT *
|
19 |
FROM df
|
20 |
-
ORDER BY
|
21 |
-
LIMIT
|
22 |
"""
|
23 |
).show()
|
|
|
4 |
from model2vec import StaticModel
|
5 |
|
6 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
7 |
+
model_name = "minishlab/potion-base-8M"
|
8 |
model = StaticModel.from_pretrained(model_name)
|
9 |
|
10 |
# Make embeddings
|
11 |
ds = load_dataset("fka/awesome-chatgpt-prompts")
|
12 |
df = ds["train"].to_polars()
|
13 |
+
embeddings = model.encode(df["act"])
|
14 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
15 |
+
vector = model.encode("An Ethereum Developer", show_progress_bar=True)
|
16 |
duckdb.sql(
|
17 |
query=f"""
|
18 |
SELECT *
|
19 |
FROM df
|
20 |
+
ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
|
21 |
+
LIMIT 10
|
22 |
"""
|
23 |
).show()
|