Commit
·
5f76c1a
1
Parent(s):
fbdb332
feat: use SOTA model
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ global ds
|
|
| 9 |
global df
|
| 10 |
|
| 11 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
| 12 |
-
model_name = "minishlab/
|
| 13 |
model = StaticModel.from_pretrained(model_name)
|
| 14 |
|
| 15 |
|
|
@@ -53,7 +53,7 @@ def vectorize_dataset(split: str, column: str):
|
|
| 53 |
global df
|
| 54 |
global ds
|
| 55 |
df = ds[split].to_polars()
|
| 56 |
-
embeddings = model.encode(df[column])
|
| 57 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
| 58 |
|
| 59 |
|
|
@@ -64,7 +64,7 @@ def run_query(query: str):
|
|
| 64 |
query=f"""
|
| 65 |
SELECT *
|
| 66 |
FROM df
|
| 67 |
-
ORDER BY
|
| 68 |
LIMIT 5
|
| 69 |
"""
|
| 70 |
).to_df()
|
|
@@ -91,18 +91,16 @@ with gr.Blocks() as demo:
|
|
| 91 |
)
|
| 92 |
with gr.Row():
|
| 93 |
search_out = gr.HTML(label="Search Results")
|
| 94 |
-
search_in.submit(get_iframe, inputs=search_in, outputs=search_out)
|
| 95 |
-
|
| 96 |
-
btn_load_dataset = gr.Button("Load Dataset")
|
| 97 |
|
| 98 |
with gr.Row(variant="panel"):
|
| 99 |
split_dropdown = gr.Dropdown(label="Select a split")
|
| 100 |
column_dropdown = gr.Dropdown(label="Select a column")
|
| 101 |
with gr.Row(variant="panel"):
|
| 102 |
query_input = gr.Textbox(label="Query")
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
| 106 |
).then(fn=get_splits, outputs=split_dropdown).then(
|
| 107 |
fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
|
| 108 |
)
|
|
|
|
| 9 |
global df
|
| 10 |
|
| 11 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
| 12 |
+
model_name = "minishlab/potion-base-8M"
|
| 13 |
model = StaticModel.from_pretrained(model_name)
|
| 14 |
|
| 15 |
|
|
|
|
| 53 |
global df
|
| 54 |
global ds
|
| 55 |
df = ds[split].to_polars()
|
| 56 |
+
embeddings = model.encode(df[column], max_length=512 * 4)
|
| 57 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
| 58 |
|
| 59 |
|
|
|
|
| 64 |
query=f"""
|
| 65 |
SELECT *
|
| 66 |
FROM df
|
| 67 |
+
ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
|
| 68 |
LIMIT 5
|
| 69 |
"""
|
| 70 |
).to_df()
|
|
|
|
| 91 |
)
|
| 92 |
with gr.Row():
|
| 93 |
search_out = gr.HTML(label="Search Results")
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
with gr.Row(variant="panel"):
|
| 96 |
split_dropdown = gr.Dropdown(label="Select a split")
|
| 97 |
column_dropdown = gr.Dropdown(label="Select a column")
|
| 98 |
with gr.Row(variant="panel"):
|
| 99 |
query_input = gr.Textbox(label="Query")
|
| 100 |
+
search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
|
| 101 |
+
fn=load_dataset_from_hub,
|
| 102 |
+
inputs=search_in,
|
| 103 |
+
show_progress=True,
|
| 104 |
).then(fn=get_splits, outputs=split_dropdown).then(
|
| 105 |
fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
|
| 106 |
)
|
demo.py
CHANGED
|
@@ -4,20 +4,20 @@ from datasets import load_dataset
|
|
| 4 |
from model2vec import StaticModel
|
| 5 |
|
| 6 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
| 7 |
-
model_name = "minishlab/
|
| 8 |
model = StaticModel.from_pretrained(model_name)
|
| 9 |
|
| 10 |
# Make embeddings
|
| 11 |
ds = load_dataset("fka/awesome-chatgpt-prompts")
|
| 12 |
df = ds["train"].to_polars()
|
| 13 |
-
embeddings = model.encode(df["
|
| 14 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
| 15 |
-
vector = model.encode("
|
| 16 |
duckdb.sql(
|
| 17 |
query=f"""
|
| 18 |
SELECT *
|
| 19 |
FROM df
|
| 20 |
-
ORDER BY
|
| 21 |
-
LIMIT
|
| 22 |
"""
|
| 23 |
).show()
|
|
|
|
| 4 |
from model2vec import StaticModel
|
| 5 |
|
| 6 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
| 7 |
+
model_name = "minishlab/potion-base-8M"
|
| 8 |
model = StaticModel.from_pretrained(model_name)
|
| 9 |
|
| 10 |
# Make embeddings
|
| 11 |
ds = load_dataset("fka/awesome-chatgpt-prompts")
|
| 12 |
df = ds["train"].to_polars()
|
| 13 |
+
embeddings = model.encode(df["act"])
|
| 14 |
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
| 15 |
+
vector = model.encode("An Ethereum Developer", show_progress_bar=True)
|
| 16 |
duckdb.sql(
|
| 17 |
query=f"""
|
| 18 |
SELECT *
|
| 19 |
FROM df
|
| 20 |
+
ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
|
| 21 |
+
LIMIT 10
|
| 22 |
"""
|
| 23 |
).show()
|