matinsn2000 committed on
Commit
cbab173
·
1 Parent(s): f1911ec

Added an image-embedding playground script and rolled back the create_album endpoint so that it no longer uses k-means clustering

Browse files
cloudzy/embedding/image_embedding.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoProcessor
2
+ from PIL import Image
3
+ import requests
4
+ import numpy as np
5
+ import torch
6
+ from io import BytesIO
7
+
8
+ # Load model and processor directly
9
+ model = AutoModel.from_pretrained("jinaai/jina-clip-v2", trust_remote_code=True)
10
+ processor = AutoProcessor.from_pretrained("jinaai/jina-clip-v2", trust_remote_code=True)
11
+
12
+ texts = ["Woman taking pictures on a road trip.", "delicious fruits glowing under sunlight"]
13
+ # Process and encode text
14
+ text_inputs = processor(text=texts, return_tensors="pt", padding=True)
15
+ with torch.no_grad():
16
+ text_embeddings = model.get_text_features(**text_inputs)
17
+ text_embeddings = text_embeddings.cpu().numpy()
18
+ print("Text embeddings shape:", text_embeddings.shape)
19
+
20
+ image_paths = [
21
+ "/Users/komeilfathi/Documents/hf_deploy_test/cloudzy_ai_challenge/uploads/img_1_20251026_014959_886.jpg",
22
+ "/Users/komeilfathi/Documents/hf_deploy_test/cloudzy_ai_challenge/uploads/img_9_20251024_185602_319.webp"
23
+ ]
24
+ images = []
25
+ for path in image_paths:
26
+ try:
27
+ img = Image.open(path).convert("RGB")
28
+ images.append(img)
29
+ print(f"✓ Loaded image from {path}")
30
+ except Exception as e:
31
+ print(f"✗ Failed to load image from {path}: {e}")
32
+
33
+ # Process and encode images
34
+ if images:
35
+ image_inputs = processor(images=images, return_tensors="pt")
36
+ with torch.no_grad():
37
+ image_embeddings = model.get_image_features(**image_inputs)
38
+ image_embeddings = image_embeddings.cpu().numpy()
39
+ print("Image embeddings shape:", image_embeddings.shape)
40
+ else:
41
+ print("⚠ No images loaded successfully")
42
+ image_embeddings = np.array([])
43
+
44
def cosine_similarity(a, b):
    """Return the cosine similarity between two 1-D vectors.

    Args:
        a: First vector (array-like accepted by ``np.dot`` / ``np.linalg.norm``).
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (||a|| * ||b||)``. When either vector has zero norm the
        similarity is undefined; ``0.0`` is returned instead of dividing by
        zero (which would yield NaN and a runtime warning).
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    # Guard the degenerate case so callers formatting the result never see NaN.
    if denom == 0:
        return 0.0
    return np.dot(a, b) / denom
46
+
47
+ if len(image_embeddings) > 0:
48
+ for i, t_emb in enumerate(text_embeddings):
49
+ for j, i_emb in enumerate(image_embeddings):
50
+ sim = cosine_similarity(t_emb, i_emb)
51
+ print(f"Similarity between text {i} and image {j}: {sim:.4f}")
52
+ else:
53
+ print("No images to compare similarity with")
cloudzy/models.py CHANGED
@@ -12,6 +12,7 @@ class Photo(SQLModel, table=True):
12
  filepath: str # Full path to stored image
13
  tags: str = Field(default="[]") # JSON string of tags
14
  caption: str = Field(default="")
 
15
  embedding: Optional[str] = Field(default=None) # JSON string of embedding vector
16
  created_at: datetime = Field(default_factory=datetime.utcnow)
17
 
 
12
  filepath: str # Full path to stored image
13
  tags: str = Field(default="[]") # JSON string of tags
14
  caption: str = Field(default="")
15
+ description: str = Field(default="")
16
  embedding: Optional[str] = Field(default=None) # JSON string of embedding vector
17
  created_at: datetime = Field(default_factory=datetime.utcnow)
18
 
cloudzy/routes/photo.py CHANGED
@@ -37,7 +37,7 @@ async def get_photo(
37
  image_url = f"{APP_DOMAIN}uploads/{photo.filename}",
38
  tags=photo.get_tags(),
39
  caption=photo.caption,
40
- # embedding=photo.get_embedding(),
41
  created_at=photo.created_at,
42
  )
43
 
@@ -72,7 +72,7 @@ async def list_photos(
72
  image_url = f"{APP_DOMAIN}uploads/{photo.filename}",
73
  tags=photo.get_tags(),
74
  caption=photo.caption,
75
- # embedding=photo.get_embedding(),
76
  created_at=photo.created_at,
77
  )
78
  for photo in photos
@@ -89,7 +89,7 @@ async def get_albums(
89
  """
90
 
91
  search_engine = SearchEngine()
92
- albums_ids = search_engine.create_albums_kmeans(top_k=top_k)
93
  APP_DOMAIN = os.getenv("APP_DOMAIN") or "http://127.0.0.1:8000/"
94
  summarizer = TextSummarizer()
95
 
@@ -127,6 +127,7 @@ async def get_albums(
127
  image_url=f"{APP_DOMAIN}uploads/{photo.filename}",
128
  tags=photo.get_tags(),
129
  caption=photo.caption,
 
130
  distance=float(distance_val),
131
  )
132
  )
 
37
  image_url = f"{APP_DOMAIN}uploads/{photo.filename}",
38
  tags=photo.get_tags(),
39
  caption=photo.caption,
40
+ description=photo.description,
41
  created_at=photo.created_at,
42
  )
43
 
 
72
  image_url = f"{APP_DOMAIN}uploads/{photo.filename}",
73
  tags=photo.get_tags(),
74
  caption=photo.caption,
75
+ description=photo.description,
76
  created_at=photo.created_at,
77
  )
78
  for photo in photos
 
89
  """
90
 
91
  search_engine = SearchEngine()
92
+ albums_ids = search_engine.create_albums(top_k=top_k)
93
  APP_DOMAIN = os.getenv("APP_DOMAIN") or "http://127.0.0.1:8000/"
94
  summarizer = TextSummarizer()
95
 
 
127
  image_url=f"{APP_DOMAIN}uploads/{photo.filename}",
128
  tags=photo.get_tags(),
129
  caption=photo.caption,
130
+ description=photo.description,
131
  distance=float(distance_val),
132
  )
133
  )
cloudzy/routes/search.py CHANGED
@@ -58,6 +58,7 @@ async def search_photos(
58
  image_url=f"{APP_DOMAIN}uploads/{photo.filename}",
59
  tags=photo.get_tags(),
60
  caption=photo.caption,
 
61
  distance=distance,
62
  )
63
  )
 
58
  image_url=f"{APP_DOMAIN}uploads/{photo.filename}",
59
  tags=photo.get_tags(),
60
  caption=photo.caption,
61
+ description=photo.description,
62
  distance=distance,
63
  )
64
  )
cloudzy/routes/upload.py CHANGED
@@ -127,6 +127,7 @@ def process_image_in_background(photo_id: int, filepath: str):
127
  photo = session.exec(select(Photo).where(Photo.id == photo_id)).first()
128
  if photo:
129
  photo.caption = caption
 
130
  photo.set_tags(tags)
131
  photo.set_embedding(embedding.tolist())
132
  session.add(photo)
 
127
  photo = session.exec(select(Photo).where(Photo.id == photo_id)).first()
128
  if photo:
129
  photo.caption = caption
130
+ photo.description = description
131
  photo.set_tags(tags)
132
  photo.set_embedding(embedding.tolist())
133
  session.add(photo)
cloudzy/schemas.py CHANGED
@@ -9,6 +9,7 @@ class PhotoResponse(BaseModel):
9
  id: int
10
  filename: str
11
  image_url: str
 
12
  tags: List[str]
13
  caption: str
14
  created_at: datetime
@@ -31,6 +32,7 @@ class SearchResult(BaseModel):
31
  image_url: str
32
  tags: List[str]
33
  caption: str
 
34
  distance: float # L2 distance (lower is more similar)
35
 
36
  class Config:
@@ -60,6 +62,7 @@ class PhotoItem(BaseModel):
60
  image_url: str
61
  tags: List[str]
62
  caption: str
 
63
  distance: float
64
 
65
  class AlbumItem(BaseModel):
 
9
  id: int
10
  filename: str
11
  image_url: str
12
+ description: str
13
  tags: List[str]
14
  caption: str
15
  created_at: datetime
 
32
  image_url: str
33
  tags: List[str]
34
  caption: str
35
+ description: str
36
  distance: float # L2 distance (lower is more similar)
37
 
38
  class Config:
 
62
  image_url: str
63
  tags: List[str]
64
  caption: str
65
+ description: str
66
  distance: float
67
 
68
  class AlbumItem(BaseModel):