sitatech committed
Commit 060ca74 · 1 Parent(s): 0651cc5

[vtry] Fix and finalize virtual_try
.gitignore CHANGED

@@ -1,2 +1,3 @@
 __pycache__
-.DS_Store
+.DS_Store
+.env
llm/app.py CHANGED

@@ -59,6 +59,8 @@ def serve_llm():
         str(VLLM_PORT),
         "--api-key",
         os.environ["API_KEY"],
+        "--tensor-parallel-size",
+        str(N_GPU),
     ]

     subprocess.Popen(cmd)
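
The list above is the argv for vLLM's OpenAI-compatible server; the new `--tensor-parallel-size` flag shards the model weights across the node's GPUs. A minimal sketch of how the full command plausibly fits together; the command head and `MODEL_NAME` are assumptions, since the hunk only shows the tail of the list:

    import os
    import subprocess

    MODEL_NAME = "RedHatAI/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"  # assumed
    VLLM_PORT = 8000  # assumed value; defined elsewhere in the file
    N_GPU = 2         # assumed value; defined elsewhere in the file

    def serve_llm():
        # vLLM's OpenAI-compatible server. --tensor-parallel-size splits each
        # layer across N_GPU GPUs instead of serving from a single device.
        cmd = [
            "vllm", "serve", MODEL_NAME,
            "--port", str(VLLM_PORT),
            "--api-key", os.environ["API_KEY"],
            "--tensor-parallel-size", str(N_GPU),
        ]
        subprocess.Popen(cmd)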
mcp_host/agent.py CHANGED

@@ -44,7 +44,7 @@ If a tool requires an input that you don't have based on your knowledge and the
         self,
         model_name: str = "RedHatAI/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic",
         openai_api_key: str = os.getenv("OPENAI_API_KEY", ""),
-        openai_api_base_url: str = "TODO",
+        openai_api_base_url: str = os.getenv("OPENAI_API_BASE_URL", ""),
         image_uploader: ImageUploader = ImageUploader(),
     ):
         self.agora_client = AgoraMCPClient(unique_name="Agora")
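
This replaces the `"TODO"` placeholder with an environment-driven base URL, letting the agent target the self-hosted vLLM endpoint instead of api.openai.com. A sketch of the wiring these constructor arguments presumably feed into; the `AsyncOpenAI` usage is an assumption, as the diff only shows the parameters:

    import os
    from openai import AsyncOpenAI  # standard OpenAI SDK client

    # E.g. OPENAI_API_BASE_URL=http://<vllm-host>:8000/v1, with OPENAI_API_KEY
    # matching the --api-key that the vLLM server was started with.
    client = AsyncOpenAI(
        api_key=os.getenv("OPENAI_API_KEY", ""),
        base_url=os.getenv("OPENAI_API_BASE_URL", ""),
    )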
mcp_server.py CHANGED

@@ -47,7 +47,7 @@ def try_item_with_masking(
     [IMAGE2] The same skirt is worn by a woman standing in a realistic lifestyle setting, the skirt fits naturally.

     Args:
-        prompt: A prompt for the diffusion model to use for inpainting.
+        prompt: A prompt for the diffusion model to use for inpainting. Be specific, e.g. for a short dress, say short dress, not just dress.
         item_image_url: URL of the item image to try.
         target_image_url: URL of the target image where the item will be tried.
         mask_image_url: Optional URL of a mask image to use.

@@ -85,14 +85,14 @@ def try_item_with_auto_masking(
     [IMAGE2] The same sofa is shown in a living room in a realistic lifestyle setting, the sofa fits in naturally with the room decor.

     For cases where a similar item is present but masking it won't cover enough area for the item to be applied, use a composite mask prompt if you can.
-    For example, if the item is a long-sleeved shirt and the target image is a person wearing a short-sleeved t-shirt, the masking prompt could be "t-shirt, arms".
+    For example, if the item is a long-sleeved shirt and the target image is a person wearing a short-sleeved t-shirt, the masking prompt could be "t-shirt, arms, neck".
     If the item is a dress and the target image is a person wearing a t-shirt and jeans, the masking prompt could be "t-shirt, jeans, arms, legs".
     Make sure the mask prompt includes all the parts where the item will be applied.

     This tool requires a similar item to be present in the target image, so it can generate a mask of the item using the masking_prompt.

     Args:
-        prompt: A prompt for the diffusion model to use for inpainting.
+        prompt: A prompt for the diffusion model to use for inpainting. Be specific, e.g. for a long-sleeved shirt, say long-sleeved shirt, not just shirt.
         item_image_url: URL of the item image to try.
         target_image_url: URL of the target image where the item will be tried.
         masking_prompt: Prompt for generating a mask of the corresponding item in the target image. It needs to be short and descriptive, e.g. "red dress", "blue sofa", "tire", "skirt, legs", etc.
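
To make the composite-masking guidance concrete, here is a hypothetical call to the tool documented above; the direct invocation and the URLs are illustrative only:

    # Hypothetical example: try a long-sleeved shirt on a model who is
    # wearing a short-sleeved t-shirt in the target photo.
    result = try_item_with_auto_masking(
        prompt=(
            "The pair of images highlights a long-sleeved shirt and its styling on a model; "
            "[IMAGE1] Detailed product shot of a long-sleeved shirt. "
            "[IMAGE2] The same shirt is worn by a model in a realistic lifestyle setting."
        ),
        item_image_url="https://example.com/shirt.jpg",    # placeholder
        target_image_url="https://example.com/model.jpg",  # placeholder
        # Composite mask prompt: cover the existing t-shirt plus the bare
        # arms and neck that the new sleeves and collar will occupy.
        masking_prompt="t-shirt, arms, neck",
    )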
virtual_try/.gitignore ADDED

@@ -0,0 +1 @@
+test_data
virtual_try/app.py CHANGED

@@ -14,9 +14,9 @@ with image.imports():
     from nunchaku.utils import get_precision
     from nunchaku.lora.flux.compose import compose_lora

-    from virtual_try.auto_masker import AutoInpaintMaskGenerator
+    from auto_masker import AutoInpaintMaskGenerator

-    TransformType = Callable[[Image.Image | np.ndarray], torch.Tensor]
+    TransformType = Callable[[Image.Image | np.ndarray], torch.Tensor]

 app = modal.App("vibe-shopping")

@@ -120,9 +120,7 @@ class VirtualTryModel:
         mask_tensor = mask_preprocessor(mask)

         # Create concatenated images along the width axis
-        inpaint_image = torch.cat(
-            [item_to_try_tensor, image_tensor], dim=2
-        )
+        inpaint_image = torch.cat([item_to_try_tensor, image_tensor], dim=2)
         extended_mask = torch.cat([torch.zeros_like(mask_tensor), mask_tensor], dim=2)

         prompt = prompt or (

@@ -148,3 +146,50 @@ class VirtualTryModel:
         byte_stream = BytesIO()
         output_image.save(byte_stream, format="WEBP", quality=90)
         return byte_stream.getvalue()
+
+
+###### ------ FOR TESTING PURPOSES ONLY ------ ######
+@app.local_entrypoint()
+def main(twice: bool = True):
+    import time
+    from pathlib import Path
+
+    test_data_dir = Path(__file__).parent / "test_data"
+    with open(test_data_dir / "target_image.jpg", "rb") as f:
+        target_image_bytes = f.read()
+    with open(test_data_dir / "item_to_try.jpg", "rb") as f:
+        item_to_try_bytes = f.read()
+    with open(test_data_dir / "item_to_try2.png", "rb") as f:
+        item_to_try_2_bytes = f.read()
+
+    prompt = (
+        "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; "
+        "[IMAGE1] Detailed product shot of a clothing "
+        "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
+    )
+
+    t0 = time.time()
+    image_bytes = VirtualTryModel().try_it.remote(
+        prompt=prompt,
+        image_bytes=target_image_bytes,
+        item_to_try_bytes=item_to_try_bytes,
+        masking_prompt="t-shirt, arms, neck",
+    )
+    output_path = test_data_dir / "output1.jpg"
+    output_path.parent.mkdir(exist_ok=True, parents=True)
+    output_path.write_bytes(image_bytes)
+    print(f"🎨 first inference latency: {time.time() - t0:.2f} seconds")
+
+    if twice:
+        t0 = time.time()
+        image_bytes = VirtualTryModel().try_it.remote(
+            prompt=prompt,
+            image_bytes=target_image_bytes,
+            item_to_try_bytes=item_to_try_2_bytes,
+            masking_prompt="t-shirt, arms",
+        )
+        print(f"🎨 second inference latency: {time.time() - t0:.2f} seconds")
+
+        output_path = test_data_dir / "output2.jpg"
+        output_path.parent.mkdir(exist_ok=True, parents=True)
+        output_path.write_bytes(image_bytes)
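
The inpainting input here is built by placing the product shot and the target photo side by side along the width axis, with the mask zeroed over the product half: the diffusion model repaints only the target half while conditioning on the untouched reference item, which is why the prompts describe an [IMAGE1]/[IMAGE2] pair. A toy shape-level illustration of that layout (dimensions are made up):

    import torch

    # Stand-ins for the preprocessed (C, H, W) tensors in try_it.
    item_to_try_tensor = torch.rand(3, 512, 384)  # reference product shot
    image_tensor = torch.rand(3, 512, 384)        # target photo to edit
    mask_tensor = torch.ones(1, 512, 384)         # 1 = region to repaint

    # dim=2 is the width axis for (C, H, W): a side-by-side canvas.
    inpaint_image = torch.cat([item_to_try_tensor, image_tensor], dim=2)
    # Zero mask over the left (reference) half so only the target half changes.
    extended_mask = torch.cat([torch.zeros_like(mask_tensor), mask_tensor], dim=2)

    assert inpaint_image.shape == (3, 512, 768)
    assert extended_mask.shape == (1, 512, 768)

Since `main` is registered with `@app.local_entrypoint()`, the smoke test runs with `modal run virtual_try/app.py`, provided the `test_data` images exist locally.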
virtual_try/auto_masker.py CHANGED

@@ -1,3 +1,4 @@
+import cv2
 import numpy as np
 from PIL import Image
 from huggingface_hub import hf_hub_download

@@ -46,19 +47,26 @@ class AutoInpaintMaskGenerator:
         )[0]

         masks = result["masks"]  # (N, H, W)
-        scores = result["mask_scores"]  # (N,)
+        scores = np.atleast_1d(result["mask_scores"])  # Ensure it's always at least 1D
+
+        # If only one mask returned, expand dims
+        if masks.ndim == 2:
+            masks = masks[np.newaxis, :, :]  # Make it (1, H, W)

         if len(masks) == 0:
             raise ValueError("No masks found.")

         # Filter masks by score threshold
-        valid_indices = np.where(scores >= threshold)[0]
+        valid_indices = scores >= threshold
         if len(valid_indices) == 0:
             raise ValueError("No masks scored the required threshold.")
-
-        best_idx = valid_indices[np.argmax(scores[valid_indices])]
-        mask = masks[best_idx]
+
+        combined_mask = np.any(masks[valid_indices], axis=0)

         # Convert to uint8 binary mask for inpainting
-        binary_mask = (mask.astype(np.uint8)) * 255  # 0 or 255
-        return binary_mask
+        binary_mask = (combined_mask.astype(np.uint8)) * 255  # 0 or 255
+
+        # Apply dilation
+        kernel = np.ones((10, 10), np.uint8)
+        dilated_mask = cv2.dilate(binary_mask, kernel, iterations=1)
+        return dilated_mask
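
The behavioral change: instead of keeping only the single best-scoring mask, the generator now unions every mask that clears the threshold and dilates the result, so the inpainting region generously covers garment boundaries. A self-contained toy run of that post-processing (array sizes and scores invented); note that once `valid_indices` is a boolean selector, `len()` is always N, so `.any()` is the emptiness check that actually fires:

    import cv2
    import numpy as np

    masks = np.zeros((3, 8, 8), dtype=bool)  # pretend N=3 masks of an 8x8 image
    masks[0, 2:5, 2:5] = True
    masks[1, 4:7, 4:7] = True
    masks[2, 0, 0] = True
    scores = np.array([0.9, 0.8, 0.2])

    valid = scores >= 0.5                     # boolean selector: [True, True, False]
    if not valid.any():                       # strict check; len(valid) is always N
        raise ValueError("No masks scored the required threshold.")
    combined = np.any(masks[valid], axis=0)   # pixel-wise union of the kept masks

    binary = combined.astype(np.uint8) * 255  # 0/255 mask for inpainting
    kernel = np.ones((3, 3), np.uint8)        # the real code uses a 10x10 kernel
    dilated = cv2.dilate(binary, kernel, iterations=1)
    assert (dilated > 0).sum() > (binary > 0).sum()  # dilation grows the region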
virtual_try/configs.py CHANGED

@@ -1,18 +1,23 @@
 import modal
+from pathlib import Path

 image = (
     modal.Image.debian_slim(python_version="3.12")
+    .apt_install("git")
     .pip_install(
         "torch==2.7.0",
         "torchvision",
         "diffusers==0.33.1",
         "transformers==4.52.4",
         "accelerate==1.7.0",
+        "opencv-python-headless",
         "huggingface_hub[hf_transfer]==0.32.4",
-        "git+https://github.com/luca-medeiros/lang-segment-anything.git@e9af744d999d85eb4d0bd59a83342ecdc2bd2461",
-        "https://github.com/mit-han-lab/nunchaku/releases/download/v0.3.0/nunchaku-0.3.0+torch2.7-cp312-cp312-linux_x86_64.whl#sha256=ed28665515075050c8ef1bacd16845b85aa4335f6c760d6fa716d3b090909d8d7",
+        "git+https://github.com/sitatec/lang-segment-anything.git",
+        "https://github.com/mit-han-lab/nunchaku/releases/download/v0.3.1dev20250609/nunchaku-0.3.1.dev20250609+torch2.7-cp312-cp312-linux_x86_64.whl#sha256=1518f6c02358545fd0336a6a74547e2c875603b381d5ce75b1664f981105b141",
     )
     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
+    .add_local_file(str(Path(__file__).resolve()), "/root/configs.py")
+    .add_local_file(str(Path(__file__).parent / "auto_masker.py"), "/root/auto_masker.py")
 )

 hf_cache_vol = modal.Volume.from_name(

@@ -28,8 +33,6 @@ MINUTE = 60
 modal_class_config = {
     "image": image,
     "gpu": "A100-40GB",
-    "cpu": 4,  # 8vCPUs
-    "memory": 16,  # 16 GB RAM
     "volumes": {
         "/root/.cache/huggingface": hf_cache_vol,
     },
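
The two `add_local_file` calls explain the import change in virtual_try/app.py: both files land in `/root`, the container's working directory, so `configs` and `auto_masker` resolve as flat, top-level modules regardless of the repo's package layout. A stripped-down sketch of the pattern (the module name `helper.py` is hypothetical):

    from pathlib import Path
    import modal

    image = (
        modal.Image.debian_slim(python_version="3.12")
        # Ship a sibling source file into the image; it lands in /root,
        # which is importable inside the container.
        .add_local_file(str(Path(__file__).parent / "helper.py"), "/root/helper.py")
    )

    app = modal.App("add-local-file-demo", image=image)

    @app.function()
    def use_helper():
        import helper  # resolves as a top-level import in the container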