"""FAISS-based semantic search engine using ID-mapped index""" import faiss import numpy as np from typing import List, Tuple import os import random class SearchEngine: """FAISS-based search engine for image embeddings""" def __init__(self, dim: int = 4096, index_path: str = "faiss_index.bin"): self.dim = dim self.index_path = index_path # Load existing index or create a new one if os.path.exists(index_path): self.index = faiss.read_index(index_path) else: base_index = faiss.IndexFlatL2(dim) self.index = faiss.IndexIDMap(base_index) def create_albums(self, top_k: int = 5, distance_threshold: float = 1.5, album_size: int = 5) -> List[List[int]]: """ Group similar images into albums (clusters). Returns up to top_k albums, each containing up to album_size similar photos. Photos are marked as visited to avoid duplicate albums. Only includes photos within the distance threshold. Automatically adjusts if fewer images than requested albums. OPTIMIZATIONS: - Batch retrieves all photos in ONE database query (not per-photo) - Caches embeddings in memory during execution - Single session for all DB operations Args: top_k: Number of albums to return (returns fewer if not enough images) distance_threshold: Maximum distance to consider photos as similar (default 1.0 for normalized embeddings) album_size: How many similar photos to search for per album (default 5) Returns: List of up to top_k albums, each album is a list of photo_ids (randomized order each call) Returns empty list if no images exist. """ from cloudzy.database import SessionLocal from cloudzy.models import Photo from sqlmodel import select self.load() if self.index.ntotal == 0: return [] # Get all photo IDs from FAISS index id_map = self.index.id_map all_ids = [id_map.at(i) for i in range(id_map.size())] # Shuffle for randomization - different albums each call random.shuffle(all_ids) # ✅ OPTIMIZATION 1: Batch retrieve all photos in ONE query session = SessionLocal() try: # Fetch all photos at once, not in a loop photos_query = session.exec(select(Photo).where(Photo.id.in_(all_ids))).all() # ✅ OPTIMIZATION 2: Cache embeddings in memory embedding_cache = {} for photo in photos_query: embedding = photo.get_embedding() if embedding: embedding_cache[photo.id] = embedding finally: session.close() visited = set() albums = [] for photo_id in all_ids: # Stop if we have enough albums if len(albums) >= top_k: break # Skip if already in an album if photo_id in visited: continue # Skip if no embedding cached if photo_id not in embedding_cache: continue # Get embedding from cache (not DB) embedding = embedding_cache[photo_id] # Search for similar images query_embedding = np.array(embedding).reshape(1, -1).astype(np.float32) distances, ids = self.index.search(query_embedding, album_size) # Build album: collect similar photos that haven't been visited and are within threshold album = [] for pid, distance in zip(ids[0], distances[0]): if pid != -1 and pid not in visited and distance <= distance_threshold: album.append(int(pid)) visited.add(pid) # Add album if it has at least 1 photo if album: albums.append(album) return albums def create_albums_kmeans(self, top_k: int = 5, seed: int = 42) -> List[List[int]]: """ Group similar images into albums using FAISS k-means clustering. 
    def create_albums_kmeans(self, top_k: int = 5, seed: int = 42) -> List[List[int]]:
        """
        Group similar images into albums using FAISS k-means clustering.

        This is a BETTER approach than nearest-neighbor grouping:
        - Uses true k-means clustering instead of ad-hoc neighbor search
        - All photos get assigned to a cluster (no "orphans")
        - Deterministic results for the same seed
        - Much faster for large datasets
        - Automatically adjusts if there are fewer images than requested clusters

        Args:
            top_k: Number of clusters (albums) to create
            seed: Random seed for reproducibility

        Returns:
            List of albums; each album is a list of photo_ids. Returns up to
            top_k albums, or fewer if total images < top_k. Returns an empty
            list if no images exist.
        """
        self.load()

        if self.index.ntotal == 0:
            return []

        # Adjust k so it does not exceed the total number of images
        actual_k = min(top_k, self.index.ntotal)

        # Get all photo IDs from the FAISS index
        id_map = self.index.id_map
        all_ids = np.array(
            [id_map.at(i) for i in range(id_map.size())], dtype=np.int64
        )

        # Get all embeddings from the underlying index
        # (IndexIDMap wraps the actual index)
        underlying_index = faiss.downcast_index(self.index.index)
        all_embeddings = underlying_index.reconstruct_n(0, self.index.ntotal).astype(
            np.float32
        )

        # ✅ Run k-means clustering with the adjusted k
        kmeans = faiss.Kmeans(
            d=self.dim,
            k=actual_k,
            niter=20,
            verbose=False,
            seed=seed,
        )
        kmeans.train(all_embeddings)

        # Assign each embedding to its nearest cluster centroid
        distances, cluster_assignments = kmeans.index.search(all_embeddings, 1)

        # Group photos by cluster
        albums = [[] for _ in range(actual_k)]
        for photo_id, cluster_id in zip(all_ids, cluster_assignments.flatten()):
            albums[cluster_id].append(int(photo_id))

        # Remove empty albums and return
        return [album for album in albums if album]

    def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
        """
        Add an embedding to the index.

        Args:
            photo_id: Unique photo identifier
            embedding: 1D numpy array of shape (dim,)
        """
        # Ensure the embedding is float32 with shape (1, dim)
        embedding = embedding.astype(np.float32).reshape(1, -1)

        # Add the embedding with its ID
        self.index.add_with_ids(embedding, np.array([photo_id], dtype=np.int64))

        # Persist the index to disk
        self.save()

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Tuple[int, float]]:
        """
        Search for similar embeddings.

        Args:
            query_embedding: 1D numpy array of shape (dim,)
            top_k: Number of results to return

        Returns:
            List of (photo_id, distance) tuples with distance <= 1.5
            (embeddings are normalized, so L2 distances fall in the 0-2 range)
        """
        self.load()

        if self.index.ntotal == 0:
            return []

        # Ensure the query is float32 with shape (1, dim)
        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)

        # Search the FAISS index
        distances, ids = self.index.search(query_embedding, top_k)

        # Filter out invalid (-1) and distant results.
        # With normalized embeddings the L2 distance range is 0-2,
        # so a threshold of 1.5 works well.
        results = [
            (int(photo_id), float(distance))
            for photo_id, distance in zip(ids[0], distances[0])
            if photo_id != -1 and distance <= 1.5
        ]

        return results

    def save(self) -> None:
        """Save the FAISS index to disk."""
        faiss.write_index(self.index, self.index_path)

    def load(self) -> None:
        """Load the FAISS index from disk."""
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
        else:
            # Recreate an empty ID-mapped index if the file is missing
            base_index = faiss.IndexFlatL2(self.dim)
            self.index = faiss.IndexIDMap(base_index)

    def get_stats(self) -> dict:
        """Get index statistics."""
        return {
            "total_embeddings": self.index.ntotal,
            "dimension": self.dim,
            "index_type": type(self.index).__name__,
        }
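

if __name__ == "__main__":
    # Minimal smoke test - a sketch, not part of the production flow. It
    # indexes random normalized vectors so search() and create_albums_kmeans()
    # can be exercised without the cloudzy database (which create_albums()
    # requires). "demo_index.bin" is a throwaway path, not the real index.
    engine = SearchEngine(dim=4096, index_path="demo_index.bin")
    rng = np.random.default_rng(0)

    for photo_id in range(10):
        vec = rng.standard_normal(4096).astype(np.float32)
        vec /= np.linalg.norm(vec)  # normalize, matching the thresholds above
        engine.add_embedding(photo_id, vec)

    print(engine.get_stats())

    query = rng.standard_normal(4096).astype(np.float32)
    query /= np.linalg.norm(query)
    print(engine.search(query, top_k=3))
    print(engine.create_albums_kmeans(top_k=3))

    os.remove("demo_index.bin")  # clean up the throwaway index file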