Spaces: Running on Zero

Gemini 2.5 (#1)
Gemini 2.5 (2c1b1b6f7fbd1281408e4cfe8e38e2e895a8b5d7)
Co-authored-by: Noah Cohn <ford442@users.noreply.huggingface.co>
pipeline_stable_diffusion_3_ipa.py
CHANGED

@@ -920,20 +920,27 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
@@ -1104,6 +1111,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
@@ -1141,66 +1153,62 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
-        image_prompt_embeds_list.append(clip_image_embeds_5)
-        #cat, but not mean
-        clip_image_embeds_cat = torch.cat(image_prompt_embeds_list)
-        print('catted embeds list without mean: ', clip_image_embeds_cat.shape)
-        zeros_tensor = torch.zeros_like(clip_image_embeds_cat)
-        print('zeros: ',zeros_tensor.shape)
-        clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_cat], dim=1)
-        print('embeds shape: ', clip_image_embeds.shape)
@@ -1230,12 +1238,17 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
@@ -1244,6 +1257,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle

         key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
         print(f"=> loading ip_adapter: {key_name}")

+    # <<< START OF ADDED METHOD >>>
     @torch.inference_mode()
+    def _encode_clip_image_emb(self, clip_image: Image.Image, device, dtype) -> torch.FloatTensor:
+        """
+        Helper method to encode a single PIL image into CLIP embeddings.
+        Resizes the image and returns the hidden states from the penultimate layer.
+        """
+        if not isinstance(clip_image, Image.Image):
+            raise TypeError("clip_image must be a PIL.Image.Image")
+
+        # Resize image
+        clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
+
+        # Process and encode
+        clip_image_tensor = self.clip_image_processor(images=[clip_image], return_tensors="pt").pixel_values
         clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
         clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
+
+        # Returns shape [1, seq_len, embed_dim]
         return clip_image_embeds
+    # <<< END OF ADDED METHOD >>>

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
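
For reference, the encoding pattern the new helper relies on (CLIP preprocessing, then the penultimate hidden states of the vision tower) can be reproduced standalone. A minimal sketch, assuming a stock openai/clip-vit-large-patch14 checkpoint rather than the image encoder this pipeline actually loads:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Assumption: a generic CLIP vision tower stands in for self.clip_image_processor / self.image_encoder.
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

image = Image.new("RGB", (512, 384), (128, 128, 128))  # stand-in for a real reference image
# Stretch to a square with side max(width, height), mirroring the helper's resize.
image = image.resize((max(image.size), max(image.size)))

pixel_values = processor(images=[image], return_tensors="pt").pixel_values
with torch.inference_mode():
    output = encoder(pixel_values, output_hidden_states=True)

# Penultimate transformer layer, shape [1, seq_len, embed_dim].
clip_image_embeds = output.hidden_states[-2]
print(clip_image_embeds.shape)
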

             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
+
+        # Ensure batch size is 1 for IP-Adapter logic
+        if batch_size != 1 and (clip_image is not None or clip_image_2 is not None):
+            logger.warning("Batch processing with multiple IP-Adapter images is not fully supported. Forcing batch size to 1.")
+            batch_size = 1

         device = self._execution_device
         dtype = self.transformer.dtype

         prompt_embeds = prompt_embeds * text_scale

+        # <<< START OF MODIFIED SECTION >>>
         # 3. prepare clip emb
+        image_embeds_list = []
+        scales_list = []
+
+        # Collect all provided image embeddings and their scales
+        if clip_image is not None:
+            print("Processing primary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image, device, dtype))
+            scales_list.append(scale_1)
+        if clip_image_2 is not None:
+            print("Processing secondary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image_2, device, dtype))
+            scales_list.append(scale_2)
+        if clip_image_3 is not None:
+            print("Processing tertiary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image_3, device, dtype))
+            scales_list.append(scale_3)
+        if clip_image_4 is not None:
+            print("Processing quaternary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image_4, device, dtype))
+            scales_list.append(scale_4)
+        if clip_image_5 is not None:
+            print("Processing quinary image.")
+            image_embeds_list.append(self._encode_clip_image_emb(clip_image_5, device, dtype))
+            scales_list.append(scale_5)
+
+        if not image_embeds_list:
+            # If no images provided, create a zero tensor.
+            # We need the expected shape. We'll encode a dummy image to get it.
+            print("No IP-Adapter image provided, using zeros.")
+            dummy_image = Image.new('RGB', (256, 256), (0, 0, 0))
+            cond_image_embeds = self._encode_clip_image_emb(dummy_image, device, dtype)
+            cond_image_embeds = torch.zeros_like(cond_image_embeds)
+        else:
+            # Stack all embeddings. Shape: [Num_Images, 1, Seq_Len, Embed_Dim]
+            all_embeds = torch.stack(image_embeds_list, dim=0)
+
+            # Create scales tensor. Shape: [Num_Images]
+            scales = torch.tensor(scales_list, device=device, dtype=dtype)
+            # Reshape scales for broadcasting: [Num_Images, 1, 1, 1]
+            scales = scales.view(-1, 1, 1, 1)
+
+            # Apply scales and then average along the Num_Images dimension
+            scaled_embeds = all_embeds * scales
+            cond_image_embeds = torch.mean(scaled_embeds, dim=0)  # Shape: [1, Seq_Len, Embed_Dim]
+
+        # Create unconditional image embeds (zeros)
+        uncond_image_embeds = torch.zeros_like(cond_image_embeds)

+        # Stack for Classifier-Free Guidance. Shape: [2, Seq_Len, Embed_Dim]
+        clip_image_embeds = torch.cat([uncond_image_embeds, cond_image_embeds], dim=0)
+
+        print(f"Final combined image embeds shape: {clip_image_embeds.shape}")
+        # <<< END OF MODIFIED SECTION >>>
+
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
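
The removed code concatenated the per-image embeddings and a zeros tensor along the sequence dimension; the rewritten block instead takes a scale-weighted mean across images and concatenates the unconditional and conditional embeddings along the batch dimension. A minimal shape sketch with dummy tensors, assuming each encoded image has shape [1, seq_len, embed_dim]:

import torch

seq_len, embed_dim = 257, 1280  # placeholder sizes; the real values depend on the image encoder
image_embeds_list = [torch.randn(1, seq_len, embed_dim) for _ in range(3)]  # e.g. three reference images
scales_list = [1.0, 0.6, 0.3]

# Stack to [num_images, 1, seq_len, embed_dim] and broadcast the scales as [num_images, 1, 1, 1]
all_embeds = torch.stack(image_embeds_list, dim=0)
scales = torch.tensor(scales_list).view(-1, 1, 1, 1)

# Scale-weighted mean over the image dimension -> [1, seq_len, embed_dim]
cond_image_embeds = (all_embeds * scales).mean(dim=0)
uncond_image_embeds = torch.zeros_like(cond_image_embeds)

# CFG batch: row 0 unconditional, row 1 conditional -> [2, seq_len, embed_dim]
clip_image_embeds = torch.cat([uncond_image_embeds, cond_image_embeds], dim=0)
print(clip_image_embeds.shape)  # torch.Size([2, 257, 1280])

Note that torch.mean divides by the number of images rather than by the sum of the scales, so low scales shrink the combined embedding instead of renormalizing it.
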

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

+                # This is now correct:
+                # clip_image_embeds has shape [2, Seq, Dim]
+                # timestep has shape [2]
+                # This returns image_prompt_embeds with shape [2, N_Queries, Proj_Dim]
                 image_prompt_embeds, timestep_emb = self.image_proj_model(
                     clip_image_embeds,
                     timestep.to(dtype=latents.dtype),
                     need_temb=True
                 )

+                # This is also correct. The processor will get the CFG-batched image embeds.
                 joint_attention_kwargs = dict(
                     emb_dict=dict(
                         ip_hidden_states=image_prompt_embeds,
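
The comments in this hunk rely on everything sharing the same CFG batch layout: index 0 is the unconditional branch, index 1 the conditional one. The guidance combination itself falls outside these hunks; in diffusers-style pipelines it is typically applied to the batched prediction roughly as in the following sketch (the tensor sizes and guidance_scale are placeholders, not values from this file):

import torch

guidance_scale = 7.0
noise_pred = torch.randn(2, 16, 128, 128)  # placeholder for the CFG-batched transformer output

# Split back into the unconditional / conditional halves and combine them.
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 16, 128, 128])
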

                     )
                 )

+                # The transformer call is also correct, as latent_model_input and prompt_embeds are also CFG-batched.
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
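
Taken together, the changes let a caller pass up to five reference images, each weighted by its own scale. A hypothetical invocation, assuming the Space exposes the patched pipeline as pipe (its construction is not part of this diff) and that unspecified parameters keep their defaults:

from PIL import Image

# Hypothetical usage of the patched StableDiffusion3Pipeline defined in this file.
ref_a = Image.open("reference_a.png").convert("RGB")
ref_b = Image.open("reference_b.png").convert("RGB")

result = pipe(
    prompt="a portrait in the style of the reference images",
    clip_image=ref_a,    # primary reference, weighted by scale_1
    scale_1=1.0,
    clip_image_2=ref_b,  # secondary reference, weighted by scale_2
    scale_2=0.5,
    num_inference_steps=28,
)
image = result.images[0]  # assumes the usual diffusers pipeline output object
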