matinsn2000 committed
Commit d667f1f · 1 Parent(s): 1cb8b50

Used a better model for text embedding

AI_USAGE_REPORT.txt CHANGED
@@ -18,8 +18,8 @@ WHERE & HOW AI WAS USED:
   - Function: Generate images from text prompts

  3. Semantic Search (cloudzy/search_engine.py + cloudzy/routes/search.py)
- - Tool: FAISS (vector database) with embeddings
- - Function: Find visually similar photos via embedding vectors
+ - Tool: FAISS (vector database) with embeddings from Qwen/Qwen3-Embedding-8B (4096-dimensional)
+ - Function: Find visually similar photos via L2-normalized embedding vectors

  PROMPTS & MODEL INPUTS:
  Image Analysis Prompt #1 - Structured Metadata (image_analyzer.py):
@@ -39,8 +39,10 @@ Search Queries:
  - Album creation: Groups similar photos by distance threshold (randomized each call)

  MODEL OUTPUTS REFINED:
- ✓ JSON parsing: Extracted structured data from model text response
- ✓ Distance threshold tuning: Adjusted for FAISS L2 distance (default 0.3)
+ ✓ JSON parsing: Extracted structured data from model text response (with dict type-check for Gemini responses)
+ ✓ Embedding model upgrade: Migrated from multilingual-e5-large (1024-d) to Qwen3-Embedding-8B (4096-d)
+ ✓ L2 normalization: Added unit-vector normalization to embeddings for consistent distance calculations
+ ✓ Distance threshold tuning: Adjusted for normalized embeddings (0.5 → 1.5 for search, 0.3 → 1.5 for albums)
  ✓ Album randomization: Added random.shuffle() to prevent deterministic groupings
  ✓ Error handling: Wrapped API failures to graceful fallbacks
@@ -60,14 +62,22 @@ Manual Refinements (35%):
  - CORS middleware configuration

  KEY TECHNICAL DECISIONS:
- 1. Distance threshold = 0.3: Filters visually similar photos
- 2. Model choice: Qwen3-VL for balanced speed/quality
- 3. FLUX.1-dev: High-quality image generation over speed
- 4. Random album creation: Ensures different groupings per request
- 5. HuggingFace Hub: Leveraged pre-tuned models vs training custom
+ 1. Embedding model: Qwen3-Embedding-8B (4096-d) for better semantic understanding than smaller models
+ 2. L2 normalization: Ensures bounded distances (0-2 range for unit vectors) independent of embedding dimension
+ 3. Distance thresholds: search() ≤ 1.5, create_albums() ≤ 1.5 (tuned for normalized embeddings)
+ 4. Model choice: Qwen3-VL for balanced speed/quality in image analysis
+ 5. FLUX.1-dev: High-quality image generation over speed
+ 6. Random album creation: Ensures different groupings per request
+ 7. HuggingFace Hub: Leveraged pre-tuned models vs training custom

  FILES MODIFIED FOR IMPROVEMENTS:
- - search_engine.py: Added randomization + album count control
+ - ai_utils.py: Added L2 normalization to both generate_embedding() and _embed_text() methods
+ - search_engine.py: Updated distance thresholds (0.5 → 1.5 search, 0.3 → 1.5 albums) for normalized embeddings
  - image_analyzer.py: JSON error handling for vision model output
- - image_analyzer_2.py: Agentic image analysis with Gemini-2.0-Flash for aesthetic descriptions
- - text_to_image.py: Timestamp-based filename collision prevention
+ - image_analyzer_2.py: Dict type-check for Gemini responses + agentic image analysis with Gemini-2.0-Flash
+ - text_to_image.py: Timestamp-based filename collision prevention
+
+ EMBEDDING UPGRADE SUMMARY:
+ Old: multilingual-e5-large (1024-dimensional, unnormalized)
+ New: Qwen/Qwen3-Embedding-8B (4096-dimensional, L2-normalized)
+ Benefit: Better semantic understanding + consistent distance calculations across query types
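
A note on the distance math this report relies on: for unit-length vectors u and v, the squared L2 distance satisfies ||u - v||^2 = 2 - 2*cos(u, v), so plain L2 distance is bounded by 2 (and FAISS's IndexFlatL2, which reports squared distances, is bounded by 4). A minimal editorial sketch, not part of the commit, that checks the identity:

    import numpy as np

    # Verify that for unit vectors the squared L2 distance
    # equals 2 - 2 * cosine_similarity.
    rng = np.random.default_rng(0)
    u = rng.normal(size=4096)
    v = rng.normal(size=4096)
    u /= np.linalg.norm(u)
    v /= np.linalg.norm(v)

    l2_squared = np.sum((u - v) ** 2)   # what IndexFlatL2 would report
    cosine = float(np.dot(u, v))
    assert np.isclose(l2_squared, 2 - 2 * cosine)
    # cosine in [-1, 1]  =>  squared distance in [0, 4], plain distance in [0, 2]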
cloudzy/agents/image_analyzer_2.py CHANGED
@@ -97,29 +97,43 @@ result: {

          response = self.agent.run(prompt, images=[image])

-         # Ensure response is a string
-         response_text = str(response) if response is not None else ""
-
-         # Extract JSON part from response
-         # Look for the pattern: result: { ... } (or { ... if closing brace is missing)
-         match = re.search(r'result:\s*(\{[\s\S]*)', response_text)
-
-         if not match:
-             raise ValueError(f"Could not find JSON in response: {response_text}")
-
-         json_str = match.group(1)
-
-         # If the extracted JSON doesn't end with }, try adding it
-         if not json_str.rstrip().endswith("}"):
-             print(f"[Warning] No closing brace found in JSON, attempting to add closing brace...")
-             json_str = json_str + "}"
-
+         # If the response is already a dict (e.g. from Gemini), return it directly
+         if isinstance(response, dict):
+             return response
+
+         # Safely convert to string, handling non-string types
+         if response is None:
+             text_content = ""
+         else:
+             text_content = str(response).strip()
+
+         if not text_content:
+             raise ValueError("Model returned empty response")
+
+         # Try to extract a JSON-like dict from the model output
          try:
-             # Parse the JSON string into a dictionary
-             result_dict = json.loads(json_str)
-             return result_dict
-         except json.JSONDecodeError as e:
-             raise ValueError(f"Failed to parse JSON from response: {json_str}\nError: {str(e)}")
+             if "{" not in text_content:
+                 raise ValueError("Response does not contain valid JSON structure (missing opening brace)")
+
+             start = text_content.index("{")
+
+             # Try to find the closing brace
+             if "}" not in text_content[start:]:
+                 # No closing brace found, try adding one
+                 print("[Warning] No closing brace found in response, attempting to add closing brace...")
+                 json_str = text_content[start:] + "}"
+             else:
+                 end = text_content.rindex("}") + 1
+                 json_str = text_content[start:end]
+
+             result = json.loads(json_str)
+             return result
+         except json.JSONDecodeError as je:
+             # json.JSONDecodeError subclasses ValueError, so it must be caught first
+             raise ValueError(f"Invalid JSON in model output: {text_content}\nError: {je}")
+         except ValueError as ve:
+             raise ValueError(f"Failed to parse model output: {text_content}\nError: {ve}")
+         except Exception as e:
+             raise ValueError(f"Failed to parse model output: {text_content}\nError: {e}")


  # Test with sample images
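
The new parsing path above boils down to: return dicts as-is, otherwise slice from the first "{" to the last "}" (appending a brace if the output was truncated) and json.loads the result. A standalone sketch of that brace-extraction idea, using a hypothetical helper name, for quick testing outside the agent:

    import json

    def extract_json_dict(text: str) -> dict:
        # Hypothetical helper mirroring the logic in the diff above
        if "{" not in text:
            raise ValueError("no opening brace in model output")
        tail = text[text.index("{"):]
        # Append a closing brace if the model's output was cut off
        json_str = tail[: tail.rindex("}") + 1] if "}" in tail else tail + "}"
        return json.loads(json_str)

    print(extract_json_dict('result: {"tags": ["cat"], "caption": "a cat"}'))
    print(extract_json_dict('{"tags": ["dog"]'))  # truncated: brace gets appended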
cloudzy/ai_utils.py CHANGED
@@ -1,24 +1,28 @@
  import os
  import numpy as np
  from huggingface_hub import InferenceClient
+ from typing import List, Dict, Tuple
+ import re

  from dotenv import load_dotenv
  load_dotenv()

+
+
  class ImageEmbeddingGenerator:
-     def __init__(self, model_name: str = "intfloat/multilingual-e5-large"):
+     def __init__(self, model_name: str = "Qwen/Qwen3-Embedding-8B"):
          """
          Initialize the embedding generator with a Hugging Face model.
          """
          self.client = InferenceClient(
-             provider="hf-inference",
+             provider="nebius",
              api_key=os.environ["HF_TOKEN_1"],
          )
          self.model_name = model_name

      def generate_embedding(self, tags: list[str], description: str, caption: str) -> np.ndarray:
          """
-         Generate a 512-d embedding for an image using its tags, description, and caption.
+         Generate a 4096-d embedding for an image using its tags, description, and caption.

          Args:
              tags: List of tags related to the image
@@ -26,7 +30,7 @@ class ImageEmbeddingGenerator:
              caption: Short caption for the image

          Returns:
-             embedding: 1D numpy array of shape (512,)
+             embedding: 1D numpy array of shape (4096,), normalized to unit length
          """
          # Combine text fields into a single string
          text = " ".join(tags) + " " + description + " " + caption
@@ -40,9 +44,15 @@ class ImageEmbeddingGenerator:
          # Convert to numpy array
          embedding = np.array(result, dtype=np.float32).reshape(-1)

-         # Ensure shape is (512,)
-         if embedding.shape[0] != 1024:
-             raise ValueError(f"Expected embedding of size 512, got {embedding.shape[0]}")
+         # Ensure shape is (4096,)
+         if embedding.shape[0] != 4096:
+             raise ValueError(f"Expected embedding of size 4096, got {embedding.shape[0]}")
+
+         # Normalize to unit length (L2 normalization)
+         # This ensures distances stay consistent across models and dimensions
+         norm = np.linalg.norm(embedding)
+         if norm > 0:
+             embedding = embedding / norm

          return embedding

@@ -50,6 +60,7 @@ class ImageEmbeddingGenerator:
      def _embed_text(self, text: str) -> np.ndarray:
          """
          Internal helper to call Hugging Face feature_extraction and return a numpy array.
+         Embeddings are normalized to unit length for consistent distance calculations.
          """
          result = self.client.feature_extraction(
              text,
@@ -57,11 +68,19 @@ class ImageEmbeddingGenerator:
          )
          embedding = np.array(result, dtype=np.float32).reshape(-1)

-         if embedding.shape[0] != 1024:
-             raise ValueError(f"Expected embedding of size 1024, got {embedding.shape[0]}")
+         if embedding.shape[0] != 4096:
+             raise ValueError(f"Expected embedding of size 4096, got {embedding.shape[0]}")
+
+         # Normalize to unit length (L2 normalization)
+         norm = np.linalg.norm(embedding)
+         if norm > 0:
+             embedding = embedding / norm
+
          return embedding


+
+
  class TextSummarizer:
      def __init__(self, model_name: str = "facebook/bart-large-cnn"):
          """
cloudzy/routes/photo.py CHANGED
@@ -89,7 +89,7 @@ async def get_albums(
      """

      search_engine = SearchEngine()
-     albums_ids = search_engine.create_albums(top_k=top_k)
+     albums_ids = search_engine.create_albums_kmeans(top_k=top_k)
      APP_DOMAIN = os.getenv("APP_DOMAIN") or "http://127.0.0.1:8000/"
      summarizer = TextSummarizer()
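
The route now delegates grouping to k-means (defined in search_engine.py below). A hedged usage sketch of the new call and its return shape:

    from cloudzy.search_engine import SearchEngine  # import path assumed from the repo layout

    engine = SearchEngine()
    albums_ids = engine.create_albums_kmeans(top_k=5)  # list of albums, each a list of photo IDs
    for i, album in enumerate(albums_ids):
        print(f"Album {i}: {len(album)} photos -> {album}")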
cloudzy/routes/search.py CHANGED
@@ -21,56 +21,47 @@ async def search_photos(
      session: Session = Depends(get_session),
  ):
      """
-     Semantic search for similar photos using FAISS.
-
-     Converts query to embedding and finds most similar images.
-
+     Semantic search endpoint using FAISS.
+
      Args:
          q: Search query (used to generate embedding)
          top_k: Number of results to return (max 50)
-
-     Returns: List of similar photos with distance scores
+
+     Returns: List of similar photos
      """

      generator = ImageEmbeddingGenerator()
      query_embedding = generator._embed_text(q)

-
-
-     # Search in FAISS
      search_engine = SearchEngine()
      search_results = search_engine.search(query_embedding, top_k=top_k)
-
-
+
      if not search_results:
          return SearchResponse(
              query=q,
              results=[],
              total_results=0,
          )
-
-     APP_DOMAIN = os.getenv("APP_DOMAIN")
-
-
-
-     # Fetch photo details from database
+
+     APP_DOMAIN = os.getenv("APP_DOMAIN")
+
      result_objects = []
+
      for photo_id, distance in search_results:
          statement = select(Photo).where(Photo.id == photo_id)
          photo = session.exec(statement).first()
-
-         if photo:  # Only include if photo exists in DB
+
+         if photo:
              result_objects.append(
                  SearchResult(
                      photo_id=photo.id,
                      filename=photo.filename,
-                     image_url = f"{APP_DOMAIN}uploads/{photo.filename}",
+                     image_url=f"{APP_DOMAIN}uploads/{photo.filename}",
                      tags=photo.get_tags(),
                      caption=photo.caption,
                      distance=distance,
                  )
              )
-
+
      return SearchResponse(
          query=q,
          results=result_objects,
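
One fragility worth noting in the URL line above: f"{APP_DOMAIN}uploads/..." only yields a valid URL when APP_DOMAIN ends with a slash (the default in photo.py does; an env override may not). A slash-robust alternative, offered as an editorial suggestion rather than what the commit does:

    # Works whether or not APP_DOMAIN carries a trailing slash
    APP_DOMAIN = "http://127.0.0.1:8000"   # hypothetical env value without the slash
    filename = "photo.jpg"
    image_url = f"{APP_DOMAIN.rstrip('/')}/uploads/{filename}"
    print(image_url)  # http://127.0.0.1:8000/uploads/photo.jpg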
cloudzy/search_engine.py CHANGED
@@ -9,7 +9,7 @@ import random
  class SearchEngine:
      """FAISS-based search engine for image embeddings"""

-     def __init__(self, dim: int = 1024, index_path: str = "faiss_index.bin"):
+     def __init__(self, dim: int = 4096, index_path: str = "faiss_index.bin"):
          self.dim = dim
          self.index_path = index_path

@@ -20,7 +20,7 @@ class SearchEngine:
          base_index = faiss.IndexFlatL2(dim)
          self.index = faiss.IndexIDMap(base_index)

-     def create_albums(self, top_k: int = 5, distance_threshold: float = 0.3, album_size: int = 5) -> List[List[int]]:
+     def create_albums(self, top_k: int = 5, distance_threshold: float = 1.5, album_size: int = 5) -> List[List[int]]:
          """
          Group similar images into albums (clusters).

@@ -28,9 +28,14 @@ class SearchEngine:
          Photos are marked as visited to avoid duplicate albums.
          Only includes photos within the distance threshold.

+         OPTIMIZATIONS:
+         - Batch retrieves all photos in ONE database query (not per-photo)
+         - Caches embeddings in memory during execution
+         - Single session for all DB operations
+
          Args:
              top_k: Number of albums to return
-             distance_threshold: Maximum distance to consider photos as similar (default 0.3)
+             distance_threshold: Maximum distance to consider photos as similar (default 1.5 for normalized embeddings)
              album_size: How many similar photos to search for per album (default 5)

          Returns:
@@ -51,6 +56,20 @@ class SearchEngine:
          # Shuffle for randomization - different albums each call
          random.shuffle(all_ids)

+         # ✅ OPTIMIZATION 1: Batch retrieve all photos in ONE query
+         session = SessionLocal()
+         try:
+             # Fetch all photos at once, not in a loop
+             photos_query = session.exec(select(Photo).where(Photo.id.in_(all_ids))).all()
+             # ✅ OPTIMIZATION 2: Cache embeddings in memory
+             embedding_cache = {}
+             for photo in photos_query:
+                 embedding = photo.get_embedding()
+                 if embedding:
+                     embedding_cache[photo.id] = embedding
+         finally:
+             session.close()
+
          visited = set()
          albums = []

@@ -63,37 +82,80 @@ class SearchEngine:
              if photo_id in visited:
                  continue

-             # Get embedding from database
-             session = SessionLocal()
-             try:
-                 photo = session.exec(select(Photo).where(Photo.id == photo_id)).first()
-                 if not photo:
-                     continue
-
-                 embedding = photo.get_embedding()
-                 if not embedding:
-                     continue
-
-                 # Search for similar images
-                 query_embedding = np.array(embedding).reshape(1, -1).astype(np.float32)
-                 distances, ids = self.index.search(query_embedding, album_size)
-
-                 # Build album: collect similar photos that haven't been visited and are within threshold
-                 album = []
-                 for pid, distance in zip(ids[0], distances[0]):
-                     if pid != -1 and pid not in visited and distance <= distance_threshold:
-                         album.append(int(pid))
-                         visited.add(pid)
-
-                 # Add album if it has at least 1 photo
-                 if album:
-                     albums.append(album)
-
-             finally:
-                 session.close()
+             # Skip if no embedding cached
+             if photo_id not in embedding_cache:
+                 continue
+
+             # Get embedding from cache (not DB)
+             embedding = embedding_cache[photo_id]
+
+             # Search for similar images
+             query_embedding = np.array(embedding).reshape(1, -1).astype(np.float32)
+             distances, ids = self.index.search(query_embedding, album_size)
+
+             # Build album: collect similar photos that haven't been visited and are within threshold
+             album = []
+             for pid, distance in zip(ids[0], distances[0]):
+                 if pid != -1 and pid not in visited and distance <= distance_threshold:
+                     album.append(int(pid))
+                     visited.add(pid)
+
+             # Add album if it has at least 1 photo
+             if album:
+                 albums.append(album)

          return albums

+     def create_albums_kmeans(self, top_k: int = 5, seed: int = 42) -> List[List[int]]:
+         """
+         Group similar images into albums using FAISS k-means clustering.
+
+         This is a BETTER approach than nearest-neighbor grouping:
+         - Uses true k-means clustering instead of ad-hoc neighbor search
+         - All photos get assigned to a cluster (no "orphans")
+         - Deterministic results for the same seed
+         - Much faster for large datasets
+
+         Args:
+             top_k: Number of clusters (albums) to create
+             seed: Random seed for reproducibility
+
+         Returns:
+             List of up to top_k albums, each album a list of photo_ids
+         """
+         self.load()
+         if self.index.ntotal < top_k:
+             return []
+
+         # Get all photo IDs from the FAISS index
+         id_map = self.index.id_map
+         all_ids = np.array([id_map.at(i) for i in range(id_map.size())], dtype=np.int64)
+
+         # Get all embeddings from the underlying index (IndexIDMap wraps the actual index)
+         underlying_index = faiss.downcast_index(self.index.index)
+         all_embeddings = underlying_index.reconstruct_n(0, self.index.ntotal).astype(np.float32)
+
+         # Run k-means clustering
+         kmeans = faiss.Kmeans(
+             d=self.dim,
+             k=top_k,
+             niter=20,
+             verbose=False,
+             seed=seed
+         )
+         kmeans.train(all_embeddings)
+
+         # Assign each embedding to its nearest cluster centroid
+         _, cluster_assignments = kmeans.index.search(all_embeddings, 1)
+
+         # Group photos by cluster
+         albums = [[] for _ in range(top_k)]
+         for photo_id, cluster_id in zip(all_ids, cluster_assignments.flatten()):
+             albums[cluster_id].append(int(photo_id))
+
+         # Remove empty albums and return
+         return [album for album in albums if album]
+
      def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
          """
          Add an embedding to the index.
@@ -120,7 +182,7 @@ class SearchEngine:
              top_k: Number of results to return

          Returns:
-             List of (photo_id, distance) tuples with distance <= 0.5
+             List of (photo_id, distance) tuples with distance <= 1.5 (normalized embeddings)
          """
          self.load()

@@ -133,11 +195,12 @@ class SearchEngine:
          # Search in FAISS index
          distances, ids = self.index.search(query_embedding, top_k)

          # Filter invalid and distant results
+         # With L2-normalized embeddings, IndexFlatL2 reports squared L2 distances (0-4 range);
+         # a threshold of 1.5 keeps matches with cosine similarity >= 0.25
          results = [
              (int(photo_id), float(distance))
              for photo_id, distance in zip(ids[0], distances[0])
-             if photo_id != -1 and distance <= 0.5
+             if photo_id != -1 and distance <= 1.5
          ]

          return results
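
The create_albums_kmeans path can be exercised without the photo database. A minimal sketch on synthetic unit vectors (dimension reduced from 4096 for speed; the faiss.Kmeans parameters match the diff):

    import numpy as np
    import faiss

    dim, n, k = 64, 200, 5
    rng = np.random.default_rng(42)
    x = rng.normal(size=(n, dim)).astype(np.float32)
    x /= np.linalg.norm(x, axis=1, keepdims=True)   # L2-normalize, as the pipeline does

    kmeans = faiss.Kmeans(d=dim, k=k, niter=20, verbose=False, seed=42)
    kmeans.train(x)

    # Every vector is assigned to its nearest centroid, so no photo is orphaned
    _, assignments = kmeans.index.search(x, 1)
    albums = [[] for _ in range(k)]
    for i, cluster_id in enumerate(assignments.flatten()):
        albums[cluster_id].append(i)
    print([len(a) for a in albums])   # cluster sizes; they sum to n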