Spaces:
Build error
Build error
Commit
·
b4d283f
1
Parent(s):
41b224c
fix: avoid global usage
Browse files
app.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import duckdb
|
| 2 |
import gradio as gr
|
| 3 |
import polars as pl
|
|
@@ -5,7 +7,6 @@ from datasets import load_dataset
|
|
| 5 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
| 6 |
from model2vec import StaticModel
|
| 7 |
|
| 8 |
-
global ds
|
| 9 |
global df
|
| 10 |
|
| 11 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
|
@@ -28,14 +29,13 @@ def get_iframe(hub_repo_id):
|
|
| 28 |
return iframe
|
| 29 |
|
| 30 |
|
| 31 |
-
def load_dataset_from_hub(hub_repo_id):
|
| 32 |
-
gr.Info("Loading dataset...")
|
| 33 |
-
global ds
|
| 34 |
ds = load_dataset(hub_repo_id)
|
| 35 |
|
| 36 |
|
| 37 |
-
def get_columns(split: str):
|
| 38 |
-
|
| 39 |
ds_split = ds[split]
|
| 40 |
return gr.Dropdown(
|
| 41 |
choices=ds_split.column_names,
|
|
@@ -45,33 +45,35 @@ def get_columns(split: str):
|
|
| 45 |
)
|
| 46 |
|
| 47 |
|
| 48 |
-
def get_splits():
|
| 49 |
-
|
| 50 |
splits = list(ds.keys())
|
| 51 |
return gr.Dropdown(
|
| 52 |
choices=splits, value=splits[0], label="Select a split", visible=True
|
| 53 |
)
|
| 54 |
|
| 55 |
|
| 56 |
-
|
|
|
|
| 57 |
gr.Info("Vectorizing dataset...")
|
| 58 |
-
|
| 59 |
-
global ds
|
| 60 |
df = ds[split].to_polars()
|
| 61 |
embeddings = model.encode(df[column].cast(str), max_length=512)
|
| 62 |
-
|
| 63 |
|
| 64 |
|
| 65 |
-
def run_query(query: str, column: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
try:
|
| 67 |
-
global df
|
| 68 |
-
|
| 69 |
vector = model.encode(query)
|
| 70 |
df_results = duckdb.sql(
|
| 71 |
query=f"""
|
| 72 |
SELECT *
|
| 73 |
FROM df
|
| 74 |
-
ORDER BY array_cosine_distance(
|
| 75 |
LIMIT 5
|
| 76 |
"""
|
| 77 |
).to_df()
|
|
@@ -134,6 +136,7 @@ with gr.Blocks() as demo:
|
|
| 134 |
query_input = gr.Textbox(label="Query", visible=False)
|
| 135 |
|
| 136 |
btn_run = gr.Button("Search", visible=False)
|
|
|
|
| 137 |
results_output = gr.Dataframe(label="Results", visible=False)
|
| 138 |
|
| 139 |
search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
|
|
@@ -143,23 +146,23 @@ with gr.Blocks() as demo:
|
|
| 143 |
).then(
|
| 144 |
fn=hide_components,
|
| 145 |
outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
|
| 146 |
-
).then(fn=get_splits, outputs=split_dropdown).then(
|
| 147 |
-
fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
|
| 148 |
)
|
| 149 |
|
| 150 |
split_dropdown.change(
|
| 151 |
-
fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
|
| 152 |
)
|
| 153 |
|
| 154 |
column_dropdown.change(
|
| 155 |
fn=partial_hide_components,
|
| 156 |
outputs=[query_input, btn_run, results_output],
|
| 157 |
-
).then(fn=
|
| 158 |
-
fn=show_components, outputs=[query_input, btn_run]
|
| 159 |
-
)
|
| 160 |
|
| 161 |
btn_run.click(
|
| 162 |
-
fn=run_query,
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
|
| 165 |
demo.launch()
|
|
|
|
| 1 |
+
from functools import lru_cache
|
| 2 |
+
|
| 3 |
import duckdb
|
| 4 |
import gradio as gr
|
| 5 |
import polars as pl
|
|
|
|
| 7 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
| 8 |
from model2vec import StaticModel
|
| 9 |
|
|
|
|
| 10 |
global df
|
| 11 |
|
| 12 |
# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
|
|
|
|
| 29 |
return iframe
|
| 30 |
|
| 31 |
|
| 32 |
+
def load_dataset_from_hub(hub_repo_id: str):
|
| 33 |
+
gr.Info(message="Loading dataset...")
|
|
|
|
| 34 |
ds = load_dataset(hub_repo_id)
|
| 35 |
|
| 36 |
|
| 37 |
+
def get_columns(hub_repo_id: str, split: str):
|
| 38 |
+
ds = load_dataset(hub_repo_id)
|
| 39 |
ds_split = ds[split]
|
| 40 |
return gr.Dropdown(
|
| 41 |
choices=ds_split.column_names,
|
|
|
|
| 45 |
)
|
| 46 |
|
| 47 |
|
| 48 |
+
def get_splits(hub_repo_id: str):
|
| 49 |
+
ds = load_dataset(hub_repo_id)
|
| 50 |
splits = list(ds.keys())
|
| 51 |
return gr.Dropdown(
|
| 52 |
choices=splits, value=splits[0], label="Select a split", visible=True
|
| 53 |
)
|
| 54 |
|
| 55 |
|
| 56 |
+
@lru_cache
|
| 57 |
+
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
|
| 58 |
gr.Info("Vectorizing dataset...")
|
| 59 |
+
ds = load_dataset(hub_repo_id)
|
|
|
|
| 60 |
df = ds[split].to_polars()
|
| 61 |
embeddings = model.encode(df[column].cast(str), max_length=512)
|
| 62 |
+
return embeddings
|
| 63 |
|
| 64 |
|
| 65 |
+
def run_query(hub_repo_id: str, query: str, split: str, column: str):
|
| 66 |
+
embeddings = vectorize_dataset(hub_repo_id, split, column)
|
| 67 |
+
ds = load_dataset(hub_repo_id)
|
| 68 |
+
df = ds[split].to_polars()
|
| 69 |
+
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
| 70 |
try:
|
|
|
|
|
|
|
| 71 |
vector = model.encode(query)
|
| 72 |
df_results = duckdb.sql(
|
| 73 |
query=f"""
|
| 74 |
SELECT *
|
| 75 |
FROM df
|
| 76 |
+
ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
|
| 77 |
LIMIT 5
|
| 78 |
"""
|
| 79 |
).to_df()
|
|
|
|
| 136 |
query_input = gr.Textbox(label="Query", visible=False)
|
| 137 |
|
| 138 |
btn_run = gr.Button("Search", visible=False)
|
| 139 |
+
|
| 140 |
results_output = gr.Dataframe(label="Results", visible=False)
|
| 141 |
|
| 142 |
search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
|
|
|
|
| 146 |
).then(
|
| 147 |
fn=hide_components,
|
| 148 |
outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
|
| 149 |
+
).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
|
| 150 |
+
fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
|
| 151 |
)
|
| 152 |
|
| 153 |
split_dropdown.change(
|
| 154 |
+
fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
|
| 155 |
)
|
| 156 |
|
| 157 |
column_dropdown.change(
|
| 158 |
fn=partial_hide_components,
|
| 159 |
outputs=[query_input, btn_run, results_output],
|
| 160 |
+
).then(fn=show_components, outputs=[query_input, btn_run])
|
|
|
|
|
|
|
| 161 |
|
| 162 |
btn_run.click(
|
| 163 |
+
fn=run_query,
|
| 164 |
+
inputs=[search_in, query_input, split_dropdown, column_dropdown],
|
| 165 |
+
outputs=results_output,
|
| 166 |
)
|
| 167 |
|
| 168 |
demo.launch()
|