Spaces:
Runtime error
Runtime error
| import glob | |
| import gradio as gr | |
| import pandas as pd | |
| import faiss | |
| import clip | |
| import torch | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
# Markdown heading rendered at the top of the demo (passed to gr.Markdown).
title = r"""
<h1 align="center" id="space-title"> 🔍 Search Similar Text/Image in the Dataset</h1>
"""
# Markdown description rendered under the title (passed to gr.Markdown).
description = r"""
Find text or images similar to your query text with this demo. Currently, it supports text search only.<br>
In this demo, we use a subset of [danbooru22](https://huggingface.co/datasets/animelover/danbooru2022) or [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
<br>
The content will be updated to include image search once LAION is available.
The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval) and [autofaiss](https://github.com/criteo/autofaiss)
"""
# Alternative (disabled): load the index and captions from a local folder
# instead of downloading them from the Hugging Face Hub.
# INDEX_DIR = "dataset/diffusiondb/text_index_folder"
# IND = faiss.read_index(f"{INDEX_DIR}/text.index")
# TEXT_LIST = pd.concat(
#     pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
# )['caption'].tolist()
def _download_index_files(dataset, index_dir="data/faiss_index"):
    """Fetch the FAISS index and caption shards for *dataset* into *index_dir*.

    Args:
        dataset: UI dataset name; must be a key of DATASET_NAME.
        index_dir: Local directory the Hub files are mirrored into.

    Returns:
        (index_path, metadata_dir): local path of the downloaded
        ``text.index`` file and the directory expected to hold the caption
        ``*.parquet`` shards.
    """
    subfolder = DATASET_NAME[dataset]
    index_path = hf_hub_download(
        repo_id="Eun02/text_image_faiss_index",
        subfolder=subfolder,
        filename="text.index",
        repo_type="dataset",
        local_dir=index_dir,
    )
    # NOTE(review): the caption shards live under "<subfolder>/metadata/"; this
    # pattern presumably still matches them because huggingface_hub uses
    # fnmatch-style matching where "*" also spans "/". Confirm if it changes.
    snapshot_download(
        repo_id="Eun02/text_image_faiss_index",
        allow_patterns=f"{subfolder}/*.parquet",
        repo_type="dataset",
        local_dir=index_dir,
    )
    return index_path, f"{index_dir}/{subfolder}/metadata"


def download_all_index(dataset_dict):
    """Download, without loading into memory, the index files of every dataset.

    Args:
        dataset_dict: Mapping whose keys are UI dataset names (keys of
            DATASET_NAME); only the keys are used.
    """
    for dataset in dataset_dict:
        _download_index_files(dataset)


def load_faiss_index(dataset):
    """Download (if needed) and load the FAISS index and caption list for *dataset*.

    Args:
        dataset: UI dataset name; must be a key of DATASET_NAME.

    Returns:
        (index, text_list): the FAISS index read from ``text.index`` and the
        ``caption`` column of all metadata shards concatenated into a list.

    Raises:
        FileNotFoundError: If no caption parquet shard is present locally.
    """
    index_path, metadata_dir = _download_index_files(dataset)
    index = faiss.read_index(index_path)

    # Lexicographic order is kept deliberately: search ids index into
    # text_list, so the shard order must match the order the embeddings were
    # added to the index (presumably lexicographic too — verify before changing).
    parquet_files = sorted(glob.glob(f"{metadata_dir}/*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(f"No caption parquet files found in {metadata_dir}")
    # Only the caption column is used, so skip reading the other columns.
    text_list = pd.concat(
        pd.read_parquet(file, columns=["caption"]) for file in parquet_files
    )["caption"].tolist()
    return index, text_list
def change_index(dataset):
    """Make *dataset* the active index, reloading only when it is not already active.

    Rebinds the module globals INDEX, TEXT_LIST and PREV_DATASET and shows
    Gradio toasts around the reload.

    Args:
        dataset: UI dataset name selected in the dropdown.

    Returns:
        None (this handler's output is bound to the results textbox).
    """
    global INDEX, TEXT_LIST, PREV_DATASET
    if PREV_DATASET == dataset:
        # Requested dataset is already loaded; nothing to swap.
        return None

    gr.Info("Load index...")
    INDEX, TEXT_LIST = load_faiss_index(dataset)
    PREV_DATASET = dataset
    gr.Info("Done!!")
    return None
def get_emb(text, device="cpu"):
    """Encode *text* with the global CLIP model into an L2-normalised float32 array.

    Args:
        text: Query string. Tokenized with ``truncate=True`` so over-long
            input is cut to CLIP's context length instead of raising.
        device: Device the token tensor is moved to; should match the device
            CLIP_MODEL was loaded on.

    Returns:
        A NumPy float32 array of shape (1, embedding_dim) whose row has unit
        L2 norm, suitable as a FAISS query.
    """
    # Inference only: without no_grad the output would require grad (CLIP's
    # parameters do), and Tensor.numpy() raises on tensors that require grad.
    with torch.no_grad():
        text_tokens = clip.tokenize([text], truncate=True).to(device)
        text_features = CLIP_MODEL.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # astype also upcasts fp16 features (e.g. CLIP on GPU) to FAISS's float32.
    return text_features.cpu().numpy().astype("float32")
def search_text(top_k, show_score, numbering_prefix, output_file, query_text):
    """Retrieve the captions nearest to *query_text* from the active FAISS index.

    Args:
        top_k: Number of neighbours to request (slider value; coerced to int
            because ``faiss.Index.search`` requires an integer k).
        show_score: Append the similarity score to each caption.
        numbering_prefix: Put a numbered separator line before each caption.
        output_file: Also write the result text to ``./output.txt``.
        query_text: Free-text query embedded with CLIP.

    Returns:
        (result_str, output_path): the newline-joined captions, and the path
        of the written file, or None when ``output_file`` is false.

    Raises:
        gr.Error: If the query is missing, empty or whitespace-only.
    """
    if not query_text or not query_text.strip():
        raise gr.Error("Query text is missing")

    text_embeddings = get_emb(query_text, device)
    scores, ids = INDEX.search(text_embeddings, int(top_k))

    seen = set()
    result_list = []
    for score, ind in zip(scores[0], ids[0]):
        # FAISS pads with id -1 when fewer than k neighbours exist; indexing
        # the caption list with -1 would silently return the last caption.
        if ind < 0:
            continue
        caption = TEXT_LIST[ind]
        # Null captions from parquet come back as non-str values; skip them.
        if not isinstance(caption, str):
            continue
        item_str = caption.strip()
        # Results arrive best-first, so keep the first hit of each caption.
        if not item_str or item_str in seen:
            continue
        seen.add(item_str)
        result_list.append((item_str, score))

    # Postprocessing text
    lines = []
    for count, (item_str, score) in enumerate(result_list, start=1):
        if numbering_prefix:
            item_str = f"###################### {count} ######################\n {item_str}"
        if show_score:
            item_str += f", {score:0.2f}"
        lines.append(f"{item_str}\n")
    result_str = "".join(lines)

    output_path = None
    if output_file:
        output_path = "./output.txt"
        # Explicit UTF-8: captions may contain characters the locale cannot encode.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result_str)
    return result_str, output_path
# Load the CLIP text encoder once at startup; every search request reuses it.
# Uses the GPU when one is available, otherwise falls back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)

# UI dataset name -> sub-folder of the index repository on the Hub.
DATASET_NAME = {
    "danbooru22": "booru22_000-300",
    "DiffusionDB": "diffusiondb",
}
DEFAULT_DATASET = "danbooru22"
# Name of the dataset currently held in INDEX / TEXT_LIST. Derived from
# DEFAULT_DATASET so the two cannot disagree; a mismatch would make
# change_index skip a needed reload and search the wrong index.
PREV_DATASET = DEFAULT_DATASET

# Download every dataset's index files at startup.
download_all_index(DATASET_NAME)
# Load the default dataset's index and captions into memory.
INDEX, TEXT_LIST = load_faiss_index(DEFAULT_DATASET)
# Build and launch the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Choices come from DATASET_NAME so the dropdown cannot list a
        # dataset that load_faiss_index does not know about.
        dataset = gr.Dropdown(label="dataset", choices=list(DATASET_NAME), value=DEFAULT_DATASET)
        top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8)
        with gr.Column():
            show_score = gr.Checkbox(label="Show score", value=False)
            numbering_prefix = gr.Checkbox(label="Add numbering prefix", value=True)
            output_file = gr.Checkbox(label="Return text file", value=True)
    query_text = gr.Textbox(label="query text")
    btn = gr.Button()
    result_text = gr.Textbox(label="retrieved text", interactive=False)
    result_file = gr.File(label="output file", visible=True)
    # The index is swapped on click rather than on dropdown change. The search
    # is chained with .success, so it runs only if the swap did not raise.
    btn.click(
        fn=change_index,
        inputs=[dataset],
        outputs=[result_text],
    ).success(
        fn=search_text,
        inputs=[top_k, show_score, numbering_prefix, output_file, query_text],
        outputs=[result_text, result_file],
    )

demo.launch()