Update pipeline_stable_diffusion_3_ipa.py
pipeline_stable_diffusion_3_ipa.py
CHANGED
@@ -965,7 +965,17 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
 
         # ipa
         clip_image=None,
+        clip_image_2=None,
+        clip_image_3=None,
+        clip_image_4=None,
+        clip_image_5=None,
+        text_scale=1.0,
         ipadapter_scale=1.0,
+        scale_1=1.0,
+        scale_2=1.0,
+        scale_3=1.0,
+        scale_4=1.0,
+        scale_5=1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
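
The expanded signature accepts up to five reference images plus per-image weights. A minimal usage sketch, assuming a pipeline object `pipe` already loaded with the SD3 IP-Adapter weights; the prompt, file names, and scale values are illustrative placeholders, not part of this change:

    from PIL import Image

    # Hypothetical call exercising the new keyword arguments introduced above.
    image = pipe(
        prompt="a watercolor portrait, soft lighting",
        clip_image=Image.open("ref_main.png").convert("RGB"),     # primary reference image
        clip_image_2=Image.open("ref_style.png").convert("RGB"),  # optional secondary reference
        text_scale=1.0,        # multiplier applied to the text prompt embeddings
        ipadapter_scale=0.8,   # overall IP-Adapter strength
        scale_2=0.5,           # weight for the second reference embedding
        num_inference_steps=28,
        guidance_scale=7.0,
    ).images[0]
    image.save("out.png")
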
@@ -1126,10 +1136,42 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
+
+        prompt_embeds = prompt_embeds * text_scale
+
+        image_prompt_embeds_list = []
+
         # 3. prepare clip emb
         clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
-
+        clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
+        image_prompt_embeds_list.append(clip_image_embeds_1)
+
+        if clip_image_2 is not None:
+            print('Using secondary image.')
+            clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
+            image_prompt_embeds_2 = self.encode_clip_image_emb(clip_image_2, device, dtype)
+            image_prompt_embeds_2 = image_prompt_embeds_2 * scale_2
+            image_prompt_embeds_list.append(image_prompt_embeds_2)
+        if clip_image_3 is not None:
+            print('Using tertiary image.')
+            clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
+            image_prompt_embeds_3 = self.encode_clip_image_emb(clip_image_3, device, dtype)
+            image_prompt_embeds_3 = image_prompt_embeds_3 * scale_3
+            image_prompt_embeds_list.append(image_prompt_embeds_3)
+        if clip_image_4 is not None:
+            print('Using quaternary image.')
+            clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
+            image_prompt_embeds_4 = self.encode_clip_image_emb(clip_image_4, device, dtype)
+            image_prompt_embeds_4 = image_prompt_embeds_4 * scale_4
+            image_prompt_embeds_list.append(image_prompt_embeds_4)
+        if clip_image_5 is not None:
+            print('Using quinary image.')
+            clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
+            image_prompt_embeds_5 = self.encode_clip_image_emb(clip_image_5, device, dtype)
+            image_prompt_embeds_5 = image_prompt_embeds_5 * scale_5
+            image_prompt_embeds_list.append(image_prompt_embeds_5)
+
+        clip_image_embeds = torch.cat(image_prompt_embeds_list).mean(dim=0).unsqueeze(0)
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
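
In the second hunk, each optional reference image is squared by resizing to its longest side, encoded with `encode_clip_image_emb`, weighted by its `scale_*` argument, and the per-image embeddings are concatenated and averaged into a single `clip_image_embeds` tensor; note that `scale_1` is accepted but not currently applied to the first image. A hedged sketch of how the five near-identical branches could be written as one loop; the helper name and signature are hypothetical, only `encode_clip_image_emb` and the averaging step come from the diff:

    import torch

    def build_clip_image_embeds(pipe, images, scales, device, dtype):
        # Hypothetical helper, not part of the change: averages the CLIP
        # embeddings of every non-None reference image. Unlike the diff,
        # this version also applies scales[0] to the first image.
        embeds = []
        for image, scale in zip(images, scales):
            if image is None:
                continue
            # Square the image by resizing to its longest side, as in the diff.
            image = image.resize((max(image.size), max(image.size)))
            embeds.append(pipe.encode_clip_image_emb(image, device, dtype) * scale)
        # Concatenate along the leading axis and average across references,
        # keeping a batch dimension of 1, mirroring the diff's
        # torch.cat(...).mean(dim=0).unsqueeze(0).
        return torch.cat(embeds).mean(dim=0).unsqueeze(0)

Inside the pipeline this would be called roughly as build_clip_image_embeds(self, [clip_image, clip_image_2, clip_image_3, clip_image_4, clip_image_5], [scale_1, scale_2, scale_3, scale_4, scale_5], device, dtype), replacing the repeated if-blocks above.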