# cloudzy/search_engine.py
"""FAISS-based semantic search engine using ID-mapped index"""
import os
import random
from typing import List, Tuple

import faiss
import numpy as np


class SearchEngine:
"""FAISS-based search engine for image embeddings"""

    def __init__(self, dim: int = 4096, index_path: str = "faiss_index.bin"):
self.dim = dim
self.index_path = index_path
# Load existing index or create a new one
if os.path.exists(index_path):
self.index = faiss.read_index(index_path)
else:
base_index = faiss.IndexFlatL2(dim)
self.index = faiss.IndexIDMap(base_index)
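
    # IndexIDMap lets vectors be addressed by external photo IDs instead of
    # FAISS's internal sequential positions, so lookups stay stable as the
    # index grows.
    #
    # Minimal usage sketch (hypothetical values; embeddings must have `dim`
    # components, which FAISS enforces at add/search time):
    #
    #   engine = SearchEngine(dim=4096, index_path="faiss_index.bin")
    #   engine.add_embedding(photo_id=1, embedding=some_4096d_vector)
    #   print(engine.search(some_4096d_query, top_k=3))
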
def create_albums(self, top_k: int = 5, distance_threshold: float = 1.5, album_size: int = 5) -> List[List[int]]:
"""
Group similar images into albums (clusters).
Returns up to top_k albums, each containing up to album_size similar photos.
Photos are marked as visited to avoid duplicate albums.
Only includes photos within the distance threshold.
Automatically adjusts if fewer images than requested albums.
OPTIMIZATIONS:
- Batch retrieves all photos in ONE database query (not per-photo)
- Caches embeddings in memory during execution
- Single session for all DB operations
Args:
top_k: Number of albums to return (returns fewer if not enough images)
            distance_threshold: Maximum distance to consider photos similar (default 1.5; with L2-normalized embeddings, L2 distances range from 0 to 2)
album_size: How many similar photos to search for per album (default 5)
Returns:
List of up to top_k albums, each album is a list of photo_ids (randomized order each call)
Returns empty list if no images exist.
"""
from cloudzy.database import SessionLocal
from cloudzy.models import Photo
from sqlmodel import select
self.load()
if self.index.ntotal == 0:
return []

        # Get all photo IDs from the FAISS ID map (as plain Python ints)
        all_ids = [int(i) for i in faiss.vector_to_array(self.index.id_map)]
# Shuffle for randomization - different albums each call
random.shuffle(all_ids)
# ✅ OPTIMIZATION 1: Batch retrieve all photos in ONE query
session = SessionLocal()
try:
            # Fetch all photos at once, not in a loop
            photos = session.exec(select(Photo).where(Photo.id.in_(all_ids))).all()

            # ✅ OPTIMIZATION 2: Cache embeddings in memory
            embedding_cache = {}
            for photo in photos:
                embedding = photo.get_embedding()
                if embedding:
                    embedding_cache[photo.id] = embedding
finally:
session.close()
visited = set()
albums = []
for photo_id in all_ids:
# Stop if we have enough albums
if len(albums) >= top_k:
break
# Skip if already in an album
if photo_id in visited:
continue
# Skip if no embedding cached
if photo_id not in embedding_cache:
continue
# Get embedding from cache (not DB)
embedding = embedding_cache[photo_id]
# Search for similar images
query_embedding = np.array(embedding).reshape(1, -1).astype(np.float32)
distances, ids = self.index.search(query_embedding, album_size)
# Build album: collect similar photos that haven't been visited and are within threshold
album = []
for pid, distance in zip(ids[0], distances[0]):
if pid != -1 and pid not in visited and distance <= distance_threshold:
album.append(int(pid))
                    visited.add(int(pid))
# Add album if it has at least 1 photo
if album:
albums.append(album)
return albums
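
    # Hedged usage sketch (IDs are made up):
    #
    #   albums = engine.create_albums(top_k=3, distance_threshold=1.5)
    #   # -> e.g. [[12, 7, 31], [4, 19], [23]]: up to top_k albums of photo
    #   #    IDs, each seeded by an unvisited photo and capped at album_size
    #   #    members; order varies per call because the seeds are shuffled.
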
def create_albums_kmeans(self, top_k: int = 5, seed: int = 42) -> List[List[int]]:
"""
Group similar images into albums using FAISS k-means clustering.
This is a BETTER approach than nearest-neighbor grouping:
- Uses true k-means clustering instead of ad-hoc neighbor search
- All photos get assigned to a cluster (no "orphans")
- Deterministic results for same seed
- Much faster for large datasets
- Automatically adjusts if fewer images than requested clusters
Args:
top_k: Number of clusters (albums) to create
seed: Random seed for reproducibility
Returns:
List of albums, each album is a list of photo_ids.
Returns up to top_k albums, or fewer if total images < top_k.
Returns empty list if no images exist.
"""
self.load()
if self.index.ntotal == 0:
return []
# Adjust k to not exceed total number of images
actual_k = min(top_k, self.index.ntotal)

        # Get all photo IDs from the FAISS ID map (int64 numpy array)
        all_ids = faiss.vector_to_array(self.index.id_map)
# Get all embeddings from the underlying index (IndexIDMap wraps the actual index)
underlying_index = faiss.downcast_index(self.index.index)
all_embeddings = underlying_index.reconstruct_n(0, self.index.ntotal).astype(np.float32)
# ✅ Run k-means clustering with adjusted k
kmeans = faiss.Kmeans(
d=self.dim,
k=actual_k,
niter=20,
verbose=False,
seed=seed
)
kmeans.train(all_embeddings)
# Assign each embedding to nearest cluster
distances, cluster_assignments = kmeans.index.search(all_embeddings, 1)
# Group photos by cluster
albums = [[] for _ in range(actual_k)]
for photo_id, cluster_id in zip(all_ids, cluster_assignments.flatten()):
albums[cluster_id].append(int(photo_id))
# Remove empty albums and return
return [album for album in albums if album]
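
    # Hedged usage sketch (IDs are made up):
    #
    #   albums = engine.create_albums_kmeans(top_k=5, seed=42)
    #   # -> e.g. [[1, 8, 40], [3, 42], ...]: every indexed photo lands in
    #   #    exactly one album, and the same seed with the same index
    #   #    contents reproduces the same clustering.
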
def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
"""
Add an embedding to the index.
Args:
photo_id: Unique photo identifier
embedding: 1D numpy array of shape (dim,)
"""
# Ensure embedding is float32 and correct shape
embedding = embedding.astype(np.float32).reshape(1, -1)
# Add embedding with its ID
self.index.add_with_ids(embedding, np.array([photo_id], dtype=np.int64))
# Save index to disk
self.save()
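
    # Note: `add_with_ids` also accepts batches, so a hypothetical bulk
    # loader could add n embeddings with a single disk write instead of one
    # write per photo:
    #
    #   vecs = np.stack(embeddings).astype(np.float32)   # shape (n, dim)
    #   ids = np.asarray(photo_ids, dtype=np.int64)      # shape (n,)
    #   self.index.add_with_ids(vecs, ids)
    #   self.save()
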
    def search(self, query_embedding: np.ndarray, top_k: int = 5, distance_threshold: float = 1.5) -> List[Tuple[int, float]]:
        """
        Search for similar embeddings.
        Args:
            query_embedding: 1D numpy array of shape (dim,)
            top_k: Number of results to return
            distance_threshold: Maximum L2 distance for a match (default 1.5)
        Returns:
            List of (photo_id, distance) tuples with distance <= distance_threshold
        """
        self.load()
        if self.index.ntotal == 0:
            return []

        # Ensure query is float32 and correct shape
        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)

        # Search in FAISS index
        distances, ids = self.index.search(query_embedding, top_k)

        # Filter out -1 placeholders and overly distant results.
        # With L2-normalized embeddings, distances range from 0 to 2.
        results = [
            (int(photo_id), float(distance))
            for photo_id, distance in zip(ids[0], distances[0])
            if photo_id != -1 and distance <= distance_threshold
        ]
        return results
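
    # Hedged usage sketch (values are made up):
    #
    #   hits = engine.search(query_vec, top_k=5)
    #   # -> e.g. [(17, 0.08), (3, 0.41)], nearest first; -1 placeholders
    #   #    and matches beyond distance_threshold are dropped.
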
def save(self) -> None:
"""Save FAISS index to disk"""
faiss.write_index(self.index, self.index_path)

    def load(self) -> None:
"""Load FAISS index from disk"""
if os.path.exists(self.index_path):
self.index = faiss.read_index(self.index_path)
else:
# Recreate empty ID-mapped index if missing
base_index = faiss.IndexFlatL2(self.dim)
self.index = faiss.IndexIDMap(base_index)

    def get_stats(self) -> dict:
"""Get index statistics"""
return {
"total_embeddings": self.index.ntotal,
"dimension": self.dim,
"index_type": type(self.index).__name__,
}
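

if __name__ == "__main__":
    # Minimal smoke test - a sketch under assumed defaults: small random
    # vectors stand in for real image embeddings, and a throwaway index path
    # keeps the demo from touching a real faiss_index.bin.
    import tempfile

    rng = np.random.default_rng(0)
    with tempfile.TemporaryDirectory() as tmp:
        engine = SearchEngine(dim=8, index_path=os.path.join(tmp, "demo.bin"))
        for photo_id in range(1, 6):
            vec = rng.random(8, dtype=np.float32)
            vec /= np.linalg.norm(vec)  # normalize so L2 distances fall in [0, 2]
            engine.add_embedding(photo_id, vec)
        print(engine.get_stats())

        query = rng.random(8, dtype=np.float32)
        query /= np.linalg.norm(query)
        print(engine.search(query, top_k=3))
        print(engine.create_albums_kmeans(top_k=2))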