import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import torch
from PIL import Image
import os
from typing import List, Tuple
import time
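# NOTE: assumed (unpinned) runtime dependencies for this Space:
#   pip install gradio sentence-transformers faiss-cpu datasets tqdm pillow torch numpy
#   (swap faiss-cpu for faiss-gpu if CUDA is available)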
# ============= DATASET SETUP FUNCTION =============

def setup_dataset():
    """Download and prepare the dataset if it does not exist yet."""
    if not os.path.exists("dataset/images"):
        print("📥 First-time setup: downloading dataset...")

        # Import modules only needed for the one-time setup
        from datasets import load_dataset
        from tqdm import tqdm

        # Create directories
        os.makedirs("dataset/images", exist_ok=True)

        # 1. Download images (from download_images_hf.py)
        print("📥 Loading Caltech101 dataset...")
        dataset = load_dataset("flwrlabs/caltech101", split="train")
        dataset = dataset.shuffle(seed=42).select(range(min(500, len(dataset))))

        print(f"💾 Saving {len(dataset)} images locally...")
        for i, item in enumerate(tqdm(dataset)):
            img = item["image"]
            label = item["label"]
            label_name = dataset.features["label"].int2str(label)
            fname = f"{i:05d}_{label_name}.jpg"
            # Convert to RGB so palette/RGBA sources can also be written as JPEG
            img.convert("RGB").save(os.path.join("dataset/images", fname))
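        # The filename encodes the class label (00000_label.jpg);
        # format_results() below parses the label back out for display.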
        # 2. Generate embeddings (from embed_images_clip.py)
        print("🔄 Generating image embeddings...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = SentenceTransformer("clip-ViT-B-32", device=device)

        image_files = [f for f in os.listdir("dataset/images") if f.lower().endswith((".jpg", ".png"))]
        embeddings = []
        for fname in tqdm(image_files, desc="Encoding images"):
            img_path = os.path.join("dataset/images", fname)
            img = Image.open(img_path).convert("RGB")
            emb = model.encode(img, convert_to_numpy=True, show_progress_bar=False, normalize_embeddings=True)
            embeddings.append(emb)

        embeddings = np.array(embeddings, dtype="float32")
        np.save("dataset/image_embeddings.npy", embeddings)
        np.save("dataset/image_filenames.npy", np.array(image_files))

        # 3. Build FAISS index (from build_faiss_index.py)
        print("📦 Building FAISS index...")
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        faiss.write_index(index, "dataset/faiss_index.bin")
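        # With L2-normalized embeddings, inner product equals cosine similarity,
        # so the flat IP index returns exact cosine-ranked neighbors.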
| print("β Dataset setup complete!") | |
| else: | |
| print("β Dataset found, ready to go!") | |
| # Call setup before anything else | |
| setup_dataset() | |
| # Configuration | |
| META_PATH = "dataset/image_filenames.npy" | |
| INDEX_PATH = "dataset/faiss_index.bin" | |
| IMG_DIR = "dataset/images" | |
class MultimodalSearchEngine:
    def __init__(self):
        """Initialize the search engine with the pre-built index and model."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Using device: {self.device}")

        # Load pre-built index and metadata
        self.index = faiss.read_index(INDEX_PATH)
        self.image_files = np.load(META_PATH)

        # Load CLIP model
        self.model = SentenceTransformer("clip-ViT-B-32", device=self.device)
        print(f"✅ Loaded index with {self.index.ntotal} images")

    def search_by_text(self, query: str, k: int = 5) -> List[Tuple[str, float, float]]:
        """Search for images matching a text query."""
        if not query.strip():
            return []

        start_time = time.time()
        query_emb = self.model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
        scores, idxs = self.index.search(query_emb, k)
        search_time = time.time() - start_time

        results = []
        for j, i in enumerate(idxs[0]):
            if i != -1:  # FAISS pads missing hits with -1
                img_path = os.path.join(IMG_DIR, self.image_files[i])
                results.append((img_path, float(scores[0][j]), search_time))
        return results

    def search_by_image(self, image: Image.Image, k: int = 5) -> List[Tuple[str, float, float]]:
        """Search for images visually similar to the given image."""
        if image is None:
            return []

        start_time = time.time()
        # Convert to RGB if necessary
        if image.mode != "RGB":
            image = image.convert("RGB")
        query_emb = self.model.encode(image, convert_to_numpy=True, normalize_embeddings=True)
        query_emb = np.expand_dims(query_emb, axis=0)  # FAISS expects a 2-D batch
        scores, idxs = self.index.search(query_emb, k)
        search_time = time.time() - start_time

        results = []
        for j, i in enumerate(idxs[0]):
            if i != -1:  # FAISS pads missing hits with -1
                img_path = os.path.join(IMG_DIR, self.image_files[i])
                results.append((img_path, float(scores[0][j]), search_time))
        return results
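# Hypothetical standalone usage (outside the Gradio UI), assuming the dataset
# files already exist:
#   engine = MultimodalSearchEngine()
#   for path, score, elapsed in engine.search_by_text("a spotted dog", k=3):
#       print(f"{path}  similarity={score:.3f}  ({elapsed:.3f}s)")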
# Initialize the search engine
try:
    search_engine = MultimodalSearchEngine()
    ENGINE_LOADED = True
except Exception as e:
    print(f"❌ Error loading search engine: {e}")
    ENGINE_LOADED = False


def format_results(results: List[Tuple[str, float, float]]) -> Tuple[List[str], str]:
    """Format search results for Gradio display."""
    if not results:
        return [], "No results found."

    image_paths = [result[0] for result in results]
    search_time = results[0][2]  # the same search time is attached to every hit

    # Create detailed results text
    results_text = f"🔍 **Search Results** (Search time: {search_time:.3f}s)\n\n"
    for i, (path, score, _) in enumerate(results, 1):
        filename = os.path.basename(path)
        # Extract the label from the filename (format: 00000_label.jpg)
        label = filename.split("_", 1)[1].rsplit(".", 1)[0] if "_" in filename else "unknown"
        results_text += f"**{i}.** {label} (similarity: {score:.3f})\n"
    return image_paths, results_text
def text_search_interface(query: str, num_results: int) -> Tuple[List[str], str]:
    """Interface function for text-based search."""
    if not ENGINE_LOADED:
        return [], "❌ Search engine not loaded. Please check that all files are available."
    if not query.strip():
        return [], "Please enter a search query."
    try:
        results = search_engine.search_by_text(query, k=num_results)
        return format_results(results)
    except Exception as e:
        return [], f"❌ Error during search: {e}"


def image_search_interface(image: Image.Image, num_results: int) -> Tuple[List[str], str]:
    """Interface function for image-based search."""
    if not ENGINE_LOADED:
        return [], "❌ Search engine not loaded. Please check that all files are available."
    if image is None:
        return [], "Please upload an image."
    try:
        results = search_engine.search_by_image(image, k=num_results)
        return format_results(results)
    except Exception as e:
        return [], f"❌ Error during search: {e}"


def get_random_examples() -> List[str]:
    """Return a fixed list of example text queries."""
    return [
        "a cat sitting on a chair",
        "airplane in the sky",
        "red car on the road",
        "person playing guitar",
        "dog running in the park",
        "beautiful sunset landscape",
        "computer on a desk",
        "flowers in a garden",
    ]
# Create the Gradio interface
with gr.Blocks(
    title="🔍 Multimodal AI Search Engine",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .gallery img {
        border-radius: 8px;
    }
    """,
) as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 30px;">
        <h1>🔍 Multimodal AI Search Engine</h1>
        <p style="font-size: 18px; color: #666;">
            Search through 500 Caltech101 images using text descriptions or image similarity
        </p>
        <p style="font-size: 14px; color: #888;">
            Powered by CLIP-ViT-B-32 and FAISS for fast similarity search
        </p>
    </div>
    """)

    with gr.Tabs():
        # Text Search Tab
        with gr.Tab("🔍 Text Search", id="text_search"):
            gr.Markdown("### Search images using natural language descriptions")
            with gr.Row():
                with gr.Column(scale=2):
                    text_query = gr.Textbox(
                        label="Search Query",
                        placeholder="Describe what you're looking for (e.g., 'a red car', 'person with guitar')",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    text_num_results = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1,
                        label="Number of Results",
                    )
                    text_search_btn = gr.Button("🔍 Search", variant="primary", size="lg")

            # Examples
            gr.Examples(
                examples=get_random_examples()[:4],
                inputs=text_query,
                label="Example Queries",
            )

            with gr.Row():
                text_results = gr.Gallery(
                    label="Search Results",
                    show_label=True,
                    elem_id="text_gallery",
                    columns=5,
                    rows=1,
                    height="auto",
                    object_fit="contain",
                )
            text_info = gr.Markdown(label="Details")
        # Image Search Tab
        with gr.Tab("🖼️ Image Search", id="image_search"):
            gr.Markdown("### Find visually similar images")
            with gr.Row():
                with gr.Column(scale=2):
                    image_query = gr.Image(
                        label="Upload Query Image",
                        type="pil",
                        height=300,
                    )
                with gr.Column(scale=1):
                    image_num_results = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1,
                        label="Number of Results",
                    )
                    image_search_btn = gr.Button("🔍 Search Similar", variant="primary", size="lg")

            with gr.Row():
                image_results = gr.Gallery(
                    label="Similar Images",
                    show_label=True,
                    elem_id="image_gallery",
                    columns=5,
                    rows=1,
                    height="auto",
                    object_fit="contain",
                )
            image_info = gr.Markdown(label="Details")
        # About Tab
        with gr.Tab("ℹ️ About", id="about"):
            gr.Markdown("""
            ### 🔬 Technical Details

            This multimodal search engine demonstrates core techniques for content-based image retrieval:

            **🧠 Model Architecture:**
            - **CLIP-ViT-B-32**: OpenAI's Contrastive Language-Image Pre-training model
            - **Vision Transformer**: Processes images using attention mechanisms
            - **Dual encoder**: Separate text and image encoders mapping into a shared embedding space

            **⚡ Search Infrastructure:**
            - **FAISS**: Facebook AI Similarity Search for efficient vector operations
            - **Cosine Similarity**: Measures semantic similarity in the embedding space
            - **Inner Product Index**: Exact search, optimized for normalized embeddings

            **📊 Dataset:**
            - **Caltech101**: 500 randomly sampled images from 101 object categories
            - **Preprocessing**: RGB conversion, CLIP-compatible normalization
            - **Embeddings**: One 512-dimensional feature vector per image

            **🚀 Performance Features:**
            - **GPU Acceleration**: CUDA support for faster inference
            - **Real-time Search**: Sub-second query response times
            - **Normalized Embeddings**: L2 normalization for consistent similarity scores

            **🎯 Applications:**
            - Content-based image retrieval
            - Visual search engines
            - Cross-modal similarity matching
            - Dataset exploration and analysis

            ### 🛠️ Implementation Highlights
            - Modular architecture with separate indexing and search components
            - Error handling and graceful degradation
            - Configurable result counts
            - Responsive UI built with Gradio Blocks
            """)
    # Event handlers
    text_search_btn.click(
        fn=text_search_interface,
        inputs=[text_query, text_num_results],
        outputs=[text_results, text_info],
    )

    image_search_btn.click(
        fn=image_search_interface,
        inputs=[image_query, image_num_results],
        outputs=[image_results, image_info],
    )

    # Auto-search on Enter key for text
    text_query.submit(
        fn=text_search_interface,
        inputs=[text_query, text_num_results],
        outputs=[text_results, text_info],
    )

# Launch configuration
if __name__ == "__main__":
    print("\n" + "=" * 50)
    print("🚀 Starting Multimodal AI Search Engine")
    print("=" * 50)

    if ENGINE_LOADED:
        print(f"✅ Search engine ready with {search_engine.index.ntotal} images")
        print(f"✅ Using device: {search_engine.device}")
    else:
        print("❌ Search engine failed to load")

    print("\n💡 Usage Tips:")
    print("- Text search: use natural language descriptions")
    print("- Image search: upload any image to find similar ones")
    print("- Adjust the result count with the slider")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )