Spaces:
Build error
Build error
| from functools import lru_cache | |
| import duckdb | |
| import gradio as gr | |
| import polars as pl | |
| from datasets import load_dataset | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| from model2vec import StaticModel | |
# Static embedding model used both to embed dataset columns and to embed the
# search query. It is loaded once, at import time, from the Hugging Face hub
# (in this case the potion-base-8M model).
# (The former module-level ``global df`` statement was removed: ``global`` has
# no effect outside a function body.)
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)
def get_iframe(hub_repo_id):
    """Build the HTML snippet that embeds the Hub dataset viewer for a repo.

    Args:
        hub_repo_id: Dataset repository id, e.g. ``"user/dataset"``.

    Returns:
        An ``<iframe>`` HTML string whose ``src`` is the dataset's
        ``/embed/viewer`` page on huggingface.co.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or ``None``.
    """
    # Function-scope imports keep this block self-contained.
    from html import escape
    from urllib.parse import quote

    if not hub_repo_id:
        raise ValueError("Hub repo id is required")

    # The id is supplied by the search component and ends up inside an HTML
    # attribute. Percent-encode it for the URL path (keeping the "/" between
    # owner and name) and escape the final URL so a crafted id cannot break
    # out of the ``src`` attribute. Ordinary repo ids pass through unchanged.
    repo_path = quote(str(hub_repo_id), safe="/")
    url = escape(
        f"https://huggingface.co/datasets/{repo_path}/embed/viewer",
        quote=True,
    )
    return f"""
<iframe
  src="{url}"
  frameborder="0"
  width="100%"
  height="600px"
></iframe>
"""
def load_dataset_from_hub(hub_repo_id: str):
    """Show a "loading" toast and load the dataset for ``hub_repo_id``.

    The loaded dataset is discarded and nothing is returned; the call is made
    only for its effect (presumably so later ``load_dataset`` calls for the
    same repo hit the local cache -- confirm against ``datasets`` behaviour).

    Args:
        hub_repo_id: Dataset repository id passed straight to ``load_dataset``.
    """
    gr.Info(message="Loading dataset...")
    load_dataset(hub_repo_id)
def get_columns(hub_repo_id: str, split: str):
    """Return a visible dropdown listing the columns of one dataset split.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.
        split: Key of the split to inspect in the loaded dataset.

    Returns:
        A ``gr.Dropdown`` whose choices are the split's column names, with the
        first column pre-selected.
    """
    column_names = load_dataset(hub_repo_id)[split].column_names
    default_column = column_names[0]
    return gr.Dropdown(
        choices=column_names,
        value=default_column,
        label="Select a column",
        visible=True,
    )
def get_splits(hub_repo_id: str):
    """Return a visible dropdown listing the splits of a dataset.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.

    Returns:
        A ``gr.Dropdown`` whose choices are the keys of the loaded dataset,
        with the first key pre-selected.
    """
    split_names = [*load_dataset(hub_repo_id).keys()]
    first_split = split_names[0]
    return gr.Dropdown(
        choices=split_names,
        value=first_split,
        label="Select a split",
        visible=True,
    )
# Bounded cache: without it every search re-embeds the whole column. The size
# is kept small because each entry holds the embeddings of an entire split.
@lru_cache(maxsize=4)
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
    """Embed one column of a dataset split with the module-level ``model``.

    Results are memoized per ``(hub_repo_id, split, column)``, so repeated
    searches over the same column reuse the embeddings instead of recomputing
    them. The cached object is shared between calls; callers must treat it as
    read-only. A cached entry is not refreshed if the dataset changes on the
    Hub while the app is running.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.
        split: Key of the split to embed.
        column: Name of the column whose values are cast to strings and
            encoded.

    Returns:
        Whatever ``model.encode`` returns for the column (presumably one
        embedding per row -- confirm against model2vec).
    """
    gr.Info("Vectorizing dataset...")
    ds = load_dataset(hub_repo_id)
    df = ds[split].to_polars()
    # Cast to string so non-text columns can still be encoded.
    embeddings = model.encode(df[column].cast(str), max_length=512)
    return embeddings
def run_query(hub_repo_id: str, query: str, split: str, column: str):
    """Return the five rows of a split that are closest to ``query``.

    The chosen column is embedded via ``vectorize_dataset``, the embeddings
    are attached to the split as an ``embeddings`` column, and DuckDB orders
    the rows by cosine distance to the embedded query.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.
        query: Free-text search query, embedded with the module-level model.
        split: Key of the split to search.
        column: Name of the column whose embeddings are compared to the query.

    Returns:
        A visible ``gr.Dataframe`` holding the top five matches.

    Raises:
        gr.Error: If encoding the query or running the SQL fails; the original
            exception is chained as the cause.
    """
    embeddings = vectorize_dataset(hub_repo_id, split, column)
    ds = load_dataset(hub_repo_id)
    # NOTE: the SQL below refers to a table called ``df``; DuckDB resolves it
    # from this local variable, so the name must not change.
    df = ds[split].to_polars()
    df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
    try:
        vector = model.encode(query)
        # Take the array size from the query vector instead of hard-coding
        # 256, so switching ``model_name`` to a model with a different
        # embedding width does not break the cast.
        dim = len(vector)
        df_results = duckdb.sql(
            query=f"""
            SELECT *
            FROM df
            ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[{dim}])
            LIMIT 5
            """
        ).to_df()
        return gr.Dataframe(df_results, visible=True)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error running query: {e}") from e
def hide_components():
    """Return hidden updates for both dropdowns, the query box, the button and
    the results table.

    Returns:
        A list of five components, each constructed with ``visible=False``, in
        the order: dropdown, dropdown, textbox, button, dataframe.
    """
    component_types = (
        gr.Dropdown,
        gr.Dropdown,
        gr.Textbox,
        gr.Button,
        gr.Dataframe,
    )
    return [component(visible=False) for component in component_types]
def partial_hide_components():
    """Return hidden updates for the query box, the button and the results
    table.

    Returns:
        A list of three components, each constructed with ``visible=False``,
        in the order: textbox, button, dataframe.
    """
    component_types = (gr.Textbox, gr.Button, gr.Dataframe)
    return [component(visible=False) for component in component_types]
def show_components():
    """Return visible updates for the query box and the search button.

    Returns:
        A two-element list: a visible ``gr.Textbox`` labelled "Query" and a
        visible ``gr.Button`` with the value "Search".
    """
    query_box = gr.Textbox(visible=True, label="Query")
    search_button = gr.Button(visible=True, value="Search")
    return [query_box, search_button]
# UI definition: layout first, then the event wiring between components.
# NOTE(review): the nesting of rows/columns below was reconstructed from a
# source whose indentation was lost -- confirm the intended layout.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
        This app allows you to vector search any Hugging Face dataset.
        You can search for the nearest neighbors of a query vector, or
        perform a similarity search on a dataframe.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                # The component searches datasets (see ``search_type``), so
                # the placeholder says "datasets" rather than "models".
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): "sumbit_on_select" looks misspelled but is
                # kept exactly as written; presumably it is the keyword the
                # component defines -- verify before renaming.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)
        btn_run = gr.Button("Search", visible=False)
    results_output = gr.Dataframe(label="Results", visible=False)

    # Selecting a dataset: show its viewer, load it, reset the dependent
    # widgets, then populate the split and column dropdowns in order.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    # Changing the split refreshes the list of columns.
    split_dropdown.change(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    # Changing the column clears any previous results and re-shows the query
    # box and search button.
    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=show_components, outputs=[query_input, btn_run])

    # Run the vector search and display the matches.
    btn_run.click(
        fn=run_query,
        inputs=[search_in, query_input, split_dropdown, column_dropdown],
        outputs=results_output,
    )

demo.launch()