1inkusFace ford442 committed
Commit bc2be42 · verified · 1 Parent(s): 107b524

Gemini 2.5 (#1)


- Gemini 2.5 (2c1b1b6f7fbd1281408e4cfe8e38e2e895a8b5d7)


Co-authored-by: Noah Cohn <ford442@users.noreply.huggingface.co>

Files changed (1)
  1. pipeline_stable_diffusion_3_ipa.py +82 -68
pipeline_stable_diffusion_3_ipa.py CHANGED
@@ -920,20 +920,27 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
        key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
        print(f"=> loading ip_adapter: {key_name}")

-
+    # <<< START OF ADDED METHOD >>>
    @torch.inference_mode()
-    def encode_clip_image_emb(self, clip_image, device, dtype):
-        if isinstance(clip_image, Image.Image):
-            clip_image = [clip_image]
-        # clip
-        clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
+    def _encode_clip_image_emb(self, clip_image: Image.Image, device, dtype) -> torch.FloatTensor:
+        """
+        Helper method to encode a single PIL image into CLIP embeddings.
+        Resizes the image and returns the hidden states from the penultimate layer.
+        """
+        if not isinstance(clip_image, Image.Image):
+            raise TypeError("clip_image must be a PIL.Image.Image")
+
+        # Resize image
+        clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
+
+        # Process and encode
+        clip_image_tensor = self.clip_image_processor(images=[clip_image], return_tensors="pt").pixel_values
        clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
        clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
-        clip_image_embeds = torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0)
-
+
+        # Returns shape [1, seq_len, embed_dim]
        return clip_image_embeds
-
-
+    # <<< END OF ADDED METHOD >>>

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
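For reference, the helper added above follows the common IP-Adapter recipe of taking the penultimate hidden state of a CLIP vision encoder as the image feature. A minimal standalone sketch, assuming a CLIP ViT-H/14 checkpoint from transformers (this diff does not show which encoder the pipeline actually loads into self.image_encoder):

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Assumption: encoder checkpoint; the pipeline's real image_encoder is configured elsewhere in the repo.
processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
encoder = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.float16
).to("cuda")

image = Image.open("reference.png").convert("RGB")  # placeholder path
# Same squaring step as _encode_clip_image_emb: stretch to max(side) x max(side).
image = image.resize((max(image.size), max(image.size)))

pixel_values = processor(images=[image], return_tensors="pt").pixel_values
pixel_values = pixel_values.to("cuda", dtype=torch.float16)

with torch.inference_mode():
    # Penultimate hidden state, shape [1, seq_len, embed_dim]
    image_embeds = encoder(pixel_values, output_hidden_states=True).hidden_states[-2]
print(image_embeds.shape)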
@@ -1104,6 +1111,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
+
+        # Ensure batch size is 1 for IP-Adapter logic
+        if batch_size != 1 and (clip_image is not None or clip_image_2 is not None):
+            logger.warning("Batch processing with multiple IP-Adapter images is not fully supported. Forcing batch size to 1.")
+            batch_size = 1

        device = self._execution_device
        dtype = self.transformer.dtype
@@ -1141,66 +1153,62 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle

        prompt_embeds = prompt_embeds * text_scale

-        image_prompt_embeds_list = []
-
+        # <<< START OF MODIFIED SECTION >>>
        # 3. prepare clip emb
-        if clip_image != None:
-            print('Using primary image.')
-            clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
-            #clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
-            #with torch.no_grad():
-            clip_image_embeds_1 = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
-            print('clip output shape: ', clip_image_embeds_1.shape)
-            clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
-            clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
-            print('encoder output shape: ', clip_image_embeds_1.shape)
-            clip_image_embeds_1 = clip_image_embeds_1 * scale_1
-            image_prompt_embeds_list.append(clip_image_embeds_1)
-        if clip_image_2 != None:
-            print('Using secondary image.')
-            clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
-            #with torch.no_grad():
-            clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
-            clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
-            clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
-            clip_image_embeds_2 = clip_image_embeds_2 * scale_2
-            image_prompt_embeds_list.append(clip_image_embeds_2)
-        if clip_image_3 != None:
-            print('Using tertiary image.')
-            clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
-            #with torch.no_grad():
-            clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
-            clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
-            clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
-            clip_image_embeds_3 = clip_image_embeds_3 * scale_3
-            image_prompt_embeds_list.append(clip_image_embeds_3)
-        if clip_image_4 != None:
-            print('Using quaternary image.')
-            clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
-            #with torch.no_grad():
-            clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
-            clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
-            clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
-            clip_image_embeds_4 = clip_image_embeds_4 * scale_4
-            image_prompt_embeds_list.append(clip_image_embeds_4)
-        if clip_image_5 != None:
-            print('Using quinary image.')
-            clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
-            #with torch.no_grad():
-            clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
-            clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
-            clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
-            clip_image_embeds_5 = clip_image_embeds_5 * scale_5
-            image_prompt_embeds_list.append(clip_image_embeds_5)
-
-        #cat, but not mean
-        clip_image_embeds_cat = torch.cat(image_prompt_embeds_list)
-        print('catted embeds list without mean: ', clip_image_embeds_cat.shape)
-        zeros_tensor = torch.zeros_like(clip_image_embeds_cat)
-        print('zeros: ',zeros_tensor.shape)
-        clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_cat], dim=1)
-        print('embeds shape: ', clip_image_embeds.shape)
+        image_embeds_list = []
+        scales_list = []
+
+        # Collect all provided image embeddings and their scales
+        if clip_image is not None:
+            print("Processing primary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image, device, dtype))
+            scales_list.append(scale_1)
+        if clip_image_2 is not None:
+            print("Processing secondary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image_2, device, dtype))
+            scales_list.append(scale_2)
+        if clip_image_3 is not None:
+            print("Processing tertiary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image_3, device, dtype))
+            scales_list.append(scale_3)
+        if clip_image_4 is not None:
+            print("Processing quaternary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image_4, device, dtype))
+            scales_list.append(scale_4)
+        if clip_image_5 is not None:
+            print("Processing quinary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image_5, device, dtype))
+            scales_list.append(scale_5)
+
+        if not image_embeds_list:
+            # If no images provided, create a zero tensor.
+            # We need the expected shape. We'll encode a dummy image to get it.
+            print("No IP-Adapter image provided, using zeros.")
+            dummy_image = Image.new('RGB', (256, 256), (0, 0, 0))
+            cond_image_embeds = self._encode_clip_image_emb(dummy_image, device, dtype)
+            cond_image_embeds = torch.zeros_like(cond_image_embeds)
+        else:
+            # Stack all embeddings. Shape: [Num_Images, 1, Seq_Len, Embed_Dim]
+            all_embeds = torch.stack(image_embeds_list, dim=0)
+
+            # Create scales tensor. Shape: [Num_Images]
+            scales = torch.tensor(scales_list, device=device, dtype=dtype)
+            # Reshape scales for broadcasting: [Num_Images, 1, 1, 1]
+            scales = scales.view(-1, 1, 1, 1)
+
+            # Apply scales and then average along the Num_Images dimension
+            scaled_embeds = all_embeds * scales
+            cond_image_embeds = torch.mean(scaled_embeds, dim=0) # Shape: [1, Seq_Len, Embed_Dim]
+
+        # Create unconditional image embeds (zeros)
+        uncond_image_embeds = torch.zeros_like(cond_image_embeds)

+        # Stack for Classifier-Free Guidance. Shape: [2, Seq_Len, Embed_Dim]
+        clip_image_embeds = torch.cat([uncond_image_embeds, cond_image_embeds], dim=0)
+
+        print(f"Final combined image embeds shape: {clip_image_embeds.shape}")
+        # <<< END OF MODIFIED SECTION >>>
+
        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
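As a sanity check on the combination logic above, the scale-weighted mean followed by CFG stacking can be reproduced on dummy tensors; the sizes below are illustrative only, chosen to mirror the shape comments in the hunk:

import torch

seq_len, embed_dim = 257, 1280                 # illustrative sizes, not taken from the repo
scales_list = [0.8, 0.5]                       # e.g. scale_1, scale_2

# Two per-image embeddings, each shaped [1, seq_len, embed_dim] like _encode_clip_image_emb returns
image_embeds_list = [torch.randn(1, seq_len, embed_dim) for _ in scales_list]

all_embeds = torch.stack(image_embeds_list, dim=0)        # [2, 1, seq_len, embed_dim]
scales = torch.tensor(scales_list).view(-1, 1, 1, 1)      # [2, 1, 1, 1], broadcastable
cond_image_embeds = (all_embeds * scales).mean(dim=0)     # [1, seq_len, embed_dim]

uncond_image_embeds = torch.zeros_like(cond_image_embeds)
clip_image_embeds = torch.cat([uncond_image_embeds, cond_image_embeds], dim=0)
print(clip_image_embeds.shape)                            # torch.Size([2, 257, 1280])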
@@ -1230,12 +1238,17 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

+                # This is now correct:
+                # clip_image_embeds has shape [2, Seq, Dim]
+                # timestep has shape [2]
+                # This returns image_prompt_embeds with shape [2, N_Queries, Proj_Dim]
                image_prompt_embeds, timestep_emb = self.image_proj_model(
                    clip_image_embeds,
                    timestep.to(dtype=latents.dtype),
                    need_temb=True
                )

+                # This is also correct. The processor will get the CFG-batched image embeds.
                joint_attention_kwargs = dict(
                    emb_dict=dict(
                        ip_hidden_states=image_prompt_embeds,
@@ -1244,6 +1257,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                    )
                )

+                # The transformer call is also correct, as latent_model_input and prompt_embeds are also CFG-batched.
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
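Taken together, the new code lets the pipeline accept up to five reference images whose embeddings are blended by their per-image scales and then CFG-batched. A hedged usage sketch follows; the keyword names clip_image, clip_image_2, scale_1 and scale_2 come from this diff, while the checkpoint name, the import path, the IP-Adapter weight loading, and the output handling are assumptions about the surrounding repo:

import torch
from PIL import Image
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline  # assumed import path

# Assumption: base SD3 weights; the IP-Adapter and image-encoder components must be
# set up with the repo's own loading code, which is not part of this diff.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

result = pipe(
    prompt="a portrait photo, studio lighting",
    clip_image=Image.open("face_ref.png").convert("RGB"),     # primary reference (placeholder path)
    clip_image_2=Image.open("style_ref.png").convert("RGB"),  # secondary reference (placeholder path)
    scale_1=0.8,
    scale_2=0.4,
    num_inference_steps=28,
)
result.images[0].save("out.png")  # assumes the usual diffusers pipeline output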