Yinhong Liu committed on
Commit bf58491 · 1 Parent(s): 497e718

fix sd3 pipeline
sid/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.89 kB).
 
sid/__pycache__/pipeline_output.cpython-312.pyc ADDED
Binary file (1.13 kB).
 
sid/__pycache__/pipeline_sid_flux.cpython-312.pyc ADDED
Binary file (46.2 kB).
 
sid/__pycache__/pipeline_sid_sana.cpython-312.pyc ADDED
Binary file (46.6 kB).
 
sid/__pycache__/pipeline_sid_sd3.cpython-312.pyc ADDED
Binary file (32.4 kB).
sid/pipeline_sid_flux.py CHANGED
@@ -0,0 +1,990 @@
1
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
32
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+ from .pipeline_output import SiDPipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> import torch
60
+ >>> from diffusers import FluxPipeline
61
+
62
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
63
+ >>> pipe.to("cuda")
64
+ >>> prompt = "A cat holding a sign that says hello world"
65
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
66
+ >>> # Refer to the pipeline documentation for more details.
67
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
68
+ >>> image.save("flux.png")
69
+ ```
70
+ """
71
+
72
+
73
+ def calculate_shift(
74
+ image_seq_len,
75
+ base_seq_len: int = 256,
76
+ max_seq_len: int = 4096,
77
+ base_shift: float = 0.5,
78
+ max_shift: float = 1.15,
79
+ ):
80
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
81
+ b = base_shift - m * base_seq_len
82
+ mu = image_seq_len * m + b
83
+ return mu
84
+
85
+
86
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
87
+ def retrieve_timesteps(
88
+ scheduler,
89
+ num_inference_steps: Optional[int] = None,
90
+ device: Optional[Union[str, torch.device]] = None,
91
+ timesteps: Optional[List[int]] = None,
92
+ sigmas: Optional[List[float]] = None,
93
+ **kwargs,
94
+ ):
95
+ r"""
96
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
97
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
98
+
99
+ Args:
100
+ scheduler (`SchedulerMixin`):
101
+ The scheduler to get timesteps from.
102
+ num_inference_steps (`int`):
103
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
104
+ must be `None`.
105
+ device (`str` or `torch.device`, *optional*):
106
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
107
+ timesteps (`List[int]`, *optional*):
108
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
109
+ `num_inference_steps` and `sigmas` must be `None`.
110
+ sigmas (`List[float]`, *optional*):
111
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
112
+ `num_inference_steps` and `timesteps` must be `None`.
113
+
114
+ Returns:
115
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
116
+ second element is the number of inference steps.
117
+ """
118
+ if timesteps is not None and sigmas is not None:
119
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
120
+ if timesteps is not None:
121
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
122
+ if not accepts_timesteps:
123
+ raise ValueError(
124
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
125
+ f" timestep schedules. Please check whether you are using the correct scheduler."
126
+ )
127
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
128
+ timesteps = scheduler.timesteps
129
+ num_inference_steps = len(timesteps)
130
+ elif sigmas is not None:
131
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
132
+ if not accept_sigmas:
133
+ raise ValueError(
134
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
135
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
136
+ )
137
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
138
+ timesteps = scheduler.timesteps
139
+ num_inference_steps = len(timesteps)
140
+ else:
141
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ return timesteps, num_inference_steps
144
+
145
+
146
+ class SiDFluxPipeline(
147
+ DiffusionPipeline,
148
+ FluxLoraLoaderMixin,
149
+ FromSingleFileMixin,
150
+ TextualInversionLoaderMixin,
151
+ FluxIPAdapterMixin,
152
+ ):
153
+ r"""
154
+ The Flux pipeline for text-to-image generation.
155
+
156
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
157
+
158
+ Args:
159
+ transformer ([`FluxTransformer2DModel`]):
160
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
161
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
162
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
163
+ vae ([`AutoencoderKL`]):
164
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
165
+ text_encoder ([`CLIPTextModel`]):
166
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
167
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
168
+ text_encoder_2 ([`T5EncoderModel`]):
169
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
170
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
171
+ tokenizer (`CLIPTokenizer`):
172
+ Tokenizer of class
173
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
174
+ tokenizer_2 (`T5TokenizerFast`):
175
+ Second Tokenizer of class
176
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
177
+ """
178
+
179
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
180
+ _optional_components = ["image_encoder", "feature_extractor"]
181
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
182
+
183
+ def __init__(
184
+ self,
185
+ scheduler: FlowMatchEulerDiscreteScheduler,
186
+ vae: AutoencoderKL,
187
+ text_encoder: CLIPTextModel,
188
+ tokenizer: CLIPTokenizer,
189
+ text_encoder_2: T5EncoderModel,
190
+ tokenizer_2: T5TokenizerFast,
191
+ transformer: FluxTransformer2DModel,
192
+ image_encoder: CLIPVisionModelWithProjection = None,
193
+ feature_extractor: CLIPImageProcessor = None,
194
+ ):
195
+ super().__init__()
196
+
197
+ self.register_modules(
198
+ vae=vae,
199
+ text_encoder=text_encoder,
200
+ text_encoder_2=text_encoder_2,
201
+ tokenizer=tokenizer,
202
+ tokenizer_2=tokenizer_2,
203
+ transformer=transformer,
204
+ scheduler=scheduler,
205
+ image_encoder=image_encoder,
206
+ feature_extractor=feature_extractor,
207
+ )
208
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
209
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
210
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
211
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
212
+ self.tokenizer_max_length = (
213
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
214
+ )
215
+ self.default_sample_size = 128
216
+
217
+ def _get_t5_prompt_embeds(
218
+ self,
219
+ prompt: Union[str, List[str]] = None,
220
+ num_images_per_prompt: int = 1,
221
+ max_sequence_length: int = 512,
222
+ device: Optional[torch.device] = None,
223
+ dtype: Optional[torch.dtype] = None,
224
+ ):
225
+ device = device or self._execution_device
226
+ dtype = dtype or self.text_encoder.dtype
227
+
228
+ prompt = [prompt] if isinstance(prompt, str) else prompt
229
+ batch_size = len(prompt)
230
+
231
+ if isinstance(self, TextualInversionLoaderMixin):
232
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
233
+
234
+ text_inputs = self.tokenizer_2(
235
+ prompt,
236
+ padding="max_length",
237
+ max_length=max_sequence_length,
238
+ truncation=True,
239
+ return_length=False,
240
+ return_overflowing_tokens=False,
241
+ return_tensors="pt",
242
+ )
243
+ text_input_ids = text_inputs.input_ids
244
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
245
+
246
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
247
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
248
+ logger.warning(
249
+ "The following part of your input was truncated because `max_sequence_length` is set to "
250
+ f" {max_sequence_length} tokens: {removed_text}"
251
+ )
252
+
253
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
254
+
255
+ dtype = self.text_encoder_2.dtype
256
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
257
+
258
+ _, seq_len, _ = prompt_embeds.shape
259
+
260
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
261
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
262
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
263
+
264
+ return prompt_embeds
265
+
266
+ def _get_clip_prompt_embeds(
267
+ self,
268
+ prompt: Union[str, List[str]],
269
+ num_images_per_prompt: int = 1,
270
+ device: Optional[torch.device] = None,
271
+ ):
272
+ device = device or self._execution_device
273
+
274
+ prompt = [prompt] if isinstance(prompt, str) else prompt
275
+ batch_size = len(prompt)
276
+
277
+ if isinstance(self, TextualInversionLoaderMixin):
278
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
279
+
280
+ text_inputs = self.tokenizer(
281
+ prompt,
282
+ padding="max_length",
283
+ max_length=self.tokenizer_max_length,
284
+ truncation=True,
285
+ return_overflowing_tokens=False,
286
+ return_length=False,
287
+ return_tensors="pt",
288
+ )
289
+
290
+ text_input_ids = text_inputs.input_ids
291
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
292
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
293
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
294
+ logger.warning(
295
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
296
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
297
+ )
298
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
299
+
300
+ # Use pooled output of CLIPTextModel
301
+ prompt_embeds = prompt_embeds.pooler_output
302
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
303
+
304
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
305
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
306
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
307
+
308
+ return prompt_embeds
309
+
310
+ def encode_prompt(
311
+ self,
312
+ prompt: Union[str, List[str]],
313
+ prompt_2: Optional[Union[str, List[str]]] = None,
314
+ device: Optional[torch.device] = None,
315
+ num_images_per_prompt: int = 1,
316
+ prompt_embeds: Optional[torch.FloatTensor] = None,
317
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
318
+ max_sequence_length: int = 512,
319
+ lora_scale: Optional[float] = None,
320
+ ):
321
+ r"""
322
+
323
+ Args:
324
+ prompt (`str` or `List[str]`, *optional*):
325
+ prompt to be encoded
326
+ prompt_2 (`str` or `List[str]`, *optional*):
327
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
328
+ used in all text-encoders
329
+ device: (`torch.device`):
330
+ torch device
331
+ num_images_per_prompt (`int`):
332
+ number of images that should be generated per prompt
333
+ prompt_embeds (`torch.FloatTensor`, *optional*):
334
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
335
+ provided, text embeddings will be generated from `prompt` input argument.
336
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
337
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
338
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
339
+ lora_scale (`float`, *optional*):
340
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
341
+ """
342
+ device = device or self._execution_device
343
+
344
+ # set lora scale so that monkey patched LoRA
345
+ # function of text encoder can correctly access it
346
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
347
+ self._lora_scale = lora_scale
348
+
349
+ # dynamically adjust the LoRA scale
350
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
351
+ scale_lora_layers(self.text_encoder, lora_scale)
352
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
353
+ scale_lora_layers(self.text_encoder_2, lora_scale)
354
+
355
+ prompt = [prompt] if isinstance(prompt, str) else prompt
356
+
357
+ if prompt_embeds is None:
358
+ prompt_2 = prompt_2 or prompt
359
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
360
+
361
+ # We only use the pooled prompt output from the CLIPTextModel
362
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
363
+ prompt=prompt,
364
+ device=device,
365
+ num_images_per_prompt=num_images_per_prompt,
366
+ )
367
+ prompt_embeds = self._get_t5_prompt_embeds(
368
+ prompt=prompt_2,
369
+ num_images_per_prompt=num_images_per_prompt,
370
+ max_sequence_length=max_sequence_length,
371
+ device=device,
372
+ )
373
+
374
+ if self.text_encoder is not None:
375
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
376
+ # Retrieve the original scale by scaling back the LoRA layers
377
+ unscale_lora_layers(self.text_encoder, lora_scale)
378
+
379
+ if self.text_encoder_2 is not None:
380
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
381
+ # Retrieve the original scale by scaling back the LoRA layers
382
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
383
+
384
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
385
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
386
+
387
+ return prompt_embeds, pooled_prompt_embeds, text_ids
388
+
389
+ def encode_image(self, image, device, num_images_per_prompt):
390
+ dtype = next(self.image_encoder.parameters()).dtype
391
+
392
+ if not isinstance(image, torch.Tensor):
393
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
394
+
395
+ image = image.to(device=device, dtype=dtype)
396
+ image_embeds = self.image_encoder(image).image_embeds
397
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
398
+ return image_embeds
399
+
400
+ def prepare_ip_adapter_image_embeds(
401
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
402
+ ):
403
+ image_embeds = []
404
+ if ip_adapter_image_embeds is None:
405
+ if not isinstance(ip_adapter_image, list):
406
+ ip_adapter_image = [ip_adapter_image]
407
+
408
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
409
+ raise ValueError(
410
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
411
+ )
412
+
413
+ for single_ip_adapter_image in ip_adapter_image:
414
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
415
+ image_embeds.append(single_image_embeds[None, :])
416
+ else:
417
+ if not isinstance(ip_adapter_image_embeds, list):
418
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
419
+
420
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
421
+ raise ValueError(
422
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
423
+ )
424
+
425
+ for single_image_embeds in ip_adapter_image_embeds:
426
+ image_embeds.append(single_image_embeds)
427
+
428
+ ip_adapter_image_embeds = []
429
+ for single_image_embeds in image_embeds:
430
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
431
+ single_image_embeds = single_image_embeds.to(device=device)
432
+ ip_adapter_image_embeds.append(single_image_embeds)
433
+
434
+ return ip_adapter_image_embeds
435
+
436
+ def check_inputs(
437
+ self,
438
+ prompt,
439
+ prompt_2,
440
+ height,
441
+ width,
442
+ negative_prompt=None,
443
+ negative_prompt_2=None,
444
+ prompt_embeds=None,
445
+ negative_prompt_embeds=None,
446
+ pooled_prompt_embeds=None,
447
+ negative_pooled_prompt_embeds=None,
448
+ callback_on_step_end_tensor_inputs=None,
449
+ max_sequence_length=None,
450
+ ):
451
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
452
+ logger.warning(
453
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
454
+ )
455
+
456
+ if callback_on_step_end_tensor_inputs is not None and not all(
457
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
458
+ ):
459
+ raise ValueError(
460
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
461
+ )
462
+
463
+ if prompt is not None and prompt_embeds is not None:
464
+ raise ValueError(
465
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
466
+ " only forward one of the two."
467
+ )
468
+ elif prompt_2 is not None and prompt_embeds is not None:
469
+ raise ValueError(
470
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
471
+ " only forward one of the two."
472
+ )
473
+ elif prompt is None and prompt_embeds is None:
474
+ raise ValueError(
475
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
476
+ )
477
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
478
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
479
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
480
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
481
+
482
+ if negative_prompt is not None and negative_prompt_embeds is not None:
483
+ raise ValueError(
484
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
485
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
486
+ )
487
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
488
+ raise ValueError(
489
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
490
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
491
+ )
492
+
493
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
494
+ raise ValueError(
495
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
496
+ )
497
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
498
+ raise ValueError(
499
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
500
+ )
501
+
502
+ if max_sequence_length is not None and max_sequence_length > 512:
503
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
504
+
505
+ @staticmethod
506
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
507
+ latent_image_ids = torch.zeros(height, width, 3)
508
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
509
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
510
+
511
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
512
+
513
+ latent_image_ids = latent_image_ids.reshape(
514
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
515
+ )
516
+
517
+ return latent_image_ids.to(device=device, dtype=dtype)
518
+
519
+ @staticmethod
520
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
521
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
522
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
523
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
524
+
525
+ return latents
526
+
527
+ @staticmethod
528
+ def _unpack_latents(latents, height, width, vae_scale_factor):
529
+ batch_size, num_patches, channels = latents.shape
530
+
531
+ # VAE applies 8x compression on images but we must also account for packing which requires
532
+ # latent height and width to be divisible by 2.
533
+ height = 2 * (int(height) // (vae_scale_factor * 2))
534
+ width = 2 * (int(width) // (vae_scale_factor * 2))
535
+
536
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
537
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
538
+
539
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
540
+
541
+ return latents
542
+
543
+ def enable_vae_slicing(self):
544
+ r"""
545
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
546
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
547
+ """
548
+ self.vae.enable_slicing()
549
+
550
+ def disable_vae_slicing(self):
551
+ r"""
552
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
553
+ computing decoding in one step.
554
+ """
555
+ self.vae.disable_slicing()
556
+
557
+ def enable_vae_tiling(self):
558
+ r"""
559
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
560
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
561
+ processing larger images.
562
+ """
563
+ self.vae.enable_tiling()
564
+
565
+ def disable_vae_tiling(self):
566
+ r"""
567
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
568
+ computing decoding in one step.
569
+ """
570
+ self.vae.disable_tiling()
571
+
572
+ def prepare_latents(
573
+ self,
574
+ batch_size,
575
+ num_channels_latents,
576
+ height,
577
+ width,
578
+ dtype,
579
+ device,
580
+ generator,
581
+ latents=None,
582
+ ):
583
+ # VAE applies 8x compression on images but we must also account for packing which requires
584
+ # latent height and width to be divisible by 2.
585
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
586
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
587
+
588
+ shape = (batch_size, num_channels_latents, height, width)
589
+
590
+ if latents is not None:
591
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
592
+ return latents.to(device=device, dtype=dtype), latent_image_ids
593
+
594
+ if isinstance(generator, list) and len(generator) != batch_size:
595
+ raise ValueError(
596
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
597
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
598
+ )
599
+
600
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
601
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
602
+
603
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
604
+
605
+ return latents, latent_image_ids
606
+
607
+ @property
608
+ def guidance_scale(self):
609
+ return self._guidance_scale
610
+
611
+ @property
612
+ def joint_attention_kwargs(self):
613
+ return self._joint_attention_kwargs
614
+
615
+ @property
616
+ def num_timesteps(self):
617
+ return self._num_timesteps
618
+
619
+ @property
620
+ def current_timestep(self):
621
+ return self._current_timestep
622
+
623
+ @property
624
+ def interrupt(self):
625
+ return self._interrupt
626
+
627
+ @torch.no_grad()
628
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
629
+ def __call__(
630
+ self,
631
+ prompt: Union[str, List[str]] = None,
632
+ prompt_2: Optional[Union[str, List[str]]] = None,
633
+ negative_prompt: Union[str, List[str]] = None,
634
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
635
+ true_cfg_scale: float = 1.0,
636
+ height: Optional[int] = None,
637
+ width: Optional[int] = None,
638
+ num_inference_steps: int = 28,
639
+ sigmas: Optional[List[float]] = None,
640
+ guidance_scale: float = 3.5,
641
+ num_images_per_prompt: Optional[int] = 1,
642
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
643
+ latents: Optional[torch.FloatTensor] = None,
644
+ prompt_embeds: Optional[torch.FloatTensor] = None,
645
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
646
+ ip_adapter_image: Optional[PipelineImageInput] = None,
647
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
648
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
649
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
650
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
651
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
652
+ output_type: Optional[str] = "pil",
653
+ return_dict: bool = True,
654
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
655
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
656
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
657
+ max_sequence_length: int = 512,
658
+ ):
659
+ r"""
660
+ Function invoked when calling the pipeline for generation.
661
+
662
+ Args:
663
+ prompt (`str` or `List[str]`, *optional*):
664
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
665
+ instead.
666
+ prompt_2 (`str` or `List[str]`, *optional*):
667
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
668
+ will be used instead.
669
+ negative_prompt (`str` or `List[str]`, *optional*):
670
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
671
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
672
+ not greater than `1`).
673
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
674
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
675
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
676
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
677
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
678
+ `negative_prompt` is provided.
679
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
680
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
681
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
682
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
683
+ num_inference_steps (`int`, *optional*, defaults to 50):
684
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
685
+ expense of slower inference.
686
+ sigmas (`List[float]`, *optional*):
687
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
688
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
689
+ will be used.
690
+ guidance_scale (`float`, *optional*, defaults to 3.5):
691
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
692
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
693
+
694
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
695
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
696
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
697
+ The number of images to generate per prompt.
698
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
699
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
700
+ to make generation deterministic.
701
+ latents (`torch.FloatTensor`, *optional*):
702
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
703
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
704
+ tensor will be generated by sampling using the supplied random `generator`.
705
+ prompt_embeds (`torch.FloatTensor`, *optional*):
706
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
707
+ provided, text embeddings will be generated from `prompt` input argument.
708
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
709
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
710
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
711
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
712
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
713
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
714
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
715
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
716
+ negative_ip_adapter_image:
717
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
718
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
719
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
720
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
721
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
722
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
723
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
724
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
725
+ argument.
726
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
727
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
728
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
729
+ input argument.
730
+ output_type (`str`, *optional*, defaults to `"pil"`):
731
+ The output format of the generate image. Choose between
732
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
733
+ return_dict (`bool`, *optional*, defaults to `True`):
734
+ Whether or not to return a [`SiDPipelineOutput`] instead of a plain tuple.
735
+ joint_attention_kwargs (`dict`, *optional*):
736
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
737
+ `self.processor` in
738
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
739
+ callback_on_step_end (`Callable`, *optional*):
740
+ A function that calls at the end of each denoising steps during the inference. The function is called
741
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
742
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
743
+ `callback_on_step_end_tensor_inputs`.
744
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
745
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
746
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
747
+ `._callback_tensor_inputs` attribute of your pipeline class.
748
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
749
+
750
+ Examples:
751
+
752
+ Returns:
753
+ [`SiDPipelineOutput`] or `tuple`: [`SiDPipelineOutput`] if `return_dict`
754
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
755
+ images.
756
+ """
757
+
758
+ height = height or self.default_sample_size * self.vae_scale_factor
759
+ width = width or self.default_sample_size * self.vae_scale_factor
760
+
761
+ # 1. Check inputs. Raise error if not correct
762
+ self.check_inputs(
763
+ prompt,
764
+ prompt_2,
765
+ height,
766
+ width,
767
+ negative_prompt=negative_prompt,
768
+ negative_prompt_2=negative_prompt_2,
769
+ prompt_embeds=prompt_embeds,
770
+ negative_prompt_embeds=negative_prompt_embeds,
771
+ pooled_prompt_embeds=pooled_prompt_embeds,
772
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
773
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
774
+ max_sequence_length=max_sequence_length,
775
+ )
776
+
777
+ self._guidance_scale = guidance_scale
778
+ self._joint_attention_kwargs = joint_attention_kwargs
779
+ self._current_timestep = None
780
+ self._interrupt = False
781
+
782
+ # 2. Define call parameters
783
+ if prompt is not None and isinstance(prompt, str):
784
+ batch_size = 1
785
+ elif prompt is not None and isinstance(prompt, list):
786
+ batch_size = len(prompt)
787
+ else:
788
+ batch_size = prompt_embeds.shape[0]
789
+
790
+ device = self._execution_device
791
+
792
+ lora_scale = (
793
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
794
+ )
795
+ has_neg_prompt = negative_prompt is not None or (
796
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
797
+ )
798
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
799
+ (
800
+ prompt_embeds,
801
+ pooled_prompt_embeds,
802
+ text_ids,
803
+ ) = self.encode_prompt(
804
+ prompt=prompt,
805
+ prompt_2=prompt_2,
806
+ prompt_embeds=prompt_embeds,
807
+ pooled_prompt_embeds=pooled_prompt_embeds,
808
+ device=device,
809
+ num_images_per_prompt=num_images_per_prompt,
810
+ max_sequence_length=max_sequence_length,
811
+ lora_scale=lora_scale,
812
+ )
813
+ if do_true_cfg:
814
+ (
815
+ negative_prompt_embeds,
816
+ negative_pooled_prompt_embeds,
817
+ negative_text_ids,
818
+ ) = self.encode_prompt(
819
+ prompt=negative_prompt,
820
+ prompt_2=negative_prompt_2,
821
+ prompt_embeds=negative_prompt_embeds,
822
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
823
+ device=device,
824
+ num_images_per_prompt=num_images_per_prompt,
825
+ max_sequence_length=max_sequence_length,
826
+ lora_scale=lora_scale,
827
+ )
828
+
829
+ # 4. Prepare latent variables
830
+ num_channels_latents = self.transformer.config.in_channels // 4
831
+ latents, latent_image_ids = self.prepare_latents(
832
+ batch_size * num_images_per_prompt,
833
+ num_channels_latents,
834
+ height,
835
+ width,
836
+ prompt_embeds.dtype,
837
+ device,
838
+ generator,
839
+ latents,
840
+ )
841
+
842
+ # 5. Prepare timesteps
843
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
844
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
845
+ sigmas = None
846
+ image_seq_len = latents.shape[1]
847
+ mu = calculate_shift(
848
+ image_seq_len,
849
+ self.scheduler.config.get("base_image_seq_len", 256),
850
+ self.scheduler.config.get("max_image_seq_len", 4096),
851
+ self.scheduler.config.get("base_shift", 0.5),
852
+ self.scheduler.config.get("max_shift", 1.15),
853
+ )
854
+ timesteps, num_inference_steps = retrieve_timesteps(
855
+ self.scheduler,
856
+ num_inference_steps,
857
+ device,
858
+ sigmas=sigmas,
859
+ mu=mu,
860
+ )
861
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
862
+ self._num_timesteps = len(timesteps)
863
+
864
+ # handle guidance
865
+ if self.transformer.config.guidance_embeds:
866
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
867
+ guidance = guidance.expand(latents.shape[0])
868
+ else:
869
+ guidance = None
870
+
871
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
872
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
873
+ ):
874
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
875
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
876
+
877
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
878
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
879
+ ):
880
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
881
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
882
+
883
+ if self.joint_attention_kwargs is None:
884
+ self._joint_attention_kwargs = {}
885
+
886
+ image_embeds = None
887
+ negative_image_embeds = None
888
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
889
+ image_embeds = self.prepare_ip_adapter_image_embeds(
890
+ ip_adapter_image,
891
+ ip_adapter_image_embeds,
892
+ device,
893
+ batch_size * num_images_per_prompt,
894
+ )
895
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
896
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
897
+ negative_ip_adapter_image,
898
+ negative_ip_adapter_image_embeds,
899
+ device,
900
+ batch_size * num_images_per_prompt,
901
+ )
902
+
903
+ # 6. Denoising loop
904
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
905
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
906
+ self.scheduler.set_begin_index(0)
907
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
908
+ for i, t in enumerate(timesteps):
909
+ if self.interrupt:
910
+ continue
911
+
912
+ self._current_timestep = t
913
+ if image_embeds is not None:
914
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
915
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
916
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
917
+
918
+ with self.transformer.cache_context("cond"):
919
+ noise_pred = self.transformer(
920
+ hidden_states=latents,
921
+ timestep=timestep / 1000,
922
+ guidance=guidance,
923
+ pooled_projections=pooled_prompt_embeds,
924
+ encoder_hidden_states=prompt_embeds,
925
+ txt_ids=text_ids,
926
+ img_ids=latent_image_ids,
927
+ joint_attention_kwargs=self.joint_attention_kwargs,
928
+ return_dict=False,
929
+ )[0]
930
+
931
+ if do_true_cfg:
932
+ if negative_image_embeds is not None:
933
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
934
+
935
+ with self.transformer.cache_context("uncond"):
936
+ neg_noise_pred = self.transformer(
937
+ hidden_states=latents,
938
+ timestep=timestep / 1000,
939
+ guidance=guidance,
940
+ pooled_projections=negative_pooled_prompt_embeds,
941
+ encoder_hidden_states=negative_prompt_embeds,
942
+ txt_ids=negative_text_ids,
943
+ img_ids=latent_image_ids,
944
+ joint_attention_kwargs=self.joint_attention_kwargs,
945
+ return_dict=False,
946
+ )[0]
947
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
948
+
949
+ # compute the previous noisy sample x_t -> x_t-1
950
+ latents_dtype = latents.dtype
951
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
952
+
953
+ if latents.dtype != latents_dtype:
954
+ if torch.backends.mps.is_available():
955
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
956
+ latents = latents.to(latents_dtype)
957
+
958
+ if callback_on_step_end is not None:
959
+ callback_kwargs = {}
960
+ for k in callback_on_step_end_tensor_inputs:
961
+ callback_kwargs[k] = locals()[k]
962
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
963
+
964
+ latents = callback_outputs.pop("latents", latents)
965
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
966
+
967
+ # call the callback, if provided
968
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
969
+ progress_bar.update()
970
+
971
+ if XLA_AVAILABLE:
972
+ xm.mark_step()
973
+
974
+ self._current_timestep = None
975
+
976
+ if output_type == "latent":
977
+ image = latents
978
+ else:
979
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
980
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
981
+ image = self.vae.decode(latents, return_dict=False)[0]
982
+ image = self.image_processor.postprocess(image, output_type=output_type)
983
+
984
+ # Offload all models
985
+ self.maybe_free_model_hooks()
986
+
987
+ if not return_dict:
988
+ return (image,)
989
+
990
+ return SiDPipelineOutput(images=image)
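
For orientation, a minimal usage sketch for the `SiDFluxPipeline` added above. The checkpoint id is the one used in the file's own example docstring; the few-step settings and the `.images` attribute on `SiDPipelineOutput` are assumptions, not something this commit documents.

```py
>>> import torch
>>> from sid.pipeline_sid_flux import SiDFluxPipeline

>>> # Assumed checkpoint: the same Flux repo referenced in the file's example docstring.
>>> pipe = SiDFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> prompt = "A cat holding a sign that says hello world"
>>> # num_inference_steps / guidance_scale mirror the distilled-model example in the docstring;
>>> # SiDPipelineOutput is assumed to expose `.images` like the standard Flux pipeline output.
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
>>> image.save("sid_flux.png")
```
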
sid/pipeline_sid_sana.py CHANGED
@@ -0,0 +1,1011 @@
1
+ # Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import re
18
+ import urllib.parse as ul
19
+ import warnings
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
24
+
25
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from diffusers.image_processor import PixArtImageProcessor
27
+ from diffusers.loaders import SanaLoraLoaderMixin
28
+ from diffusers.models import AutoencoderDC, SanaTransformer2DModel
29
+ from diffusers.schedulers import DPMSolverMultistepScheduler
30
+ from diffusers.utils import (
31
+ BACKENDS_MAPPING,
32
+ USE_PEFT_BACKEND,
33
+ is_bs4_available,
34
+ is_ftfy_available,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import get_device, is_torch_version, randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+ from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
44
+ ASPECT_RATIO_512_BIN,
45
+ ASPECT_RATIO_1024_BIN,
46
+ )
47
+ from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
48
+ from .pipeline_output import SiDPipelineOutput
49
+
50
+
51
+ if is_torch_xla_available():
52
+ import torch_xla.core.xla_model as xm
53
+
54
+ XLA_AVAILABLE = True
55
+ else:
56
+ XLA_AVAILABLE = False
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ if is_bs4_available():
61
+ from bs4 import BeautifulSoup
62
+
63
+ if is_ftfy_available():
64
+ import ftfy
65
+
66
+
67
+ ASPECT_RATIO_4096_BIN = {
68
+ "0.25": [2048.0, 8192.0],
69
+ "0.26": [2048.0, 7936.0],
70
+ "0.27": [2048.0, 7680.0],
71
+ "0.28": [2048.0, 7424.0],
72
+ "0.32": [2304.0, 7168.0],
73
+ "0.33": [2304.0, 6912.0],
74
+ "0.35": [2304.0, 6656.0],
75
+ "0.4": [2560.0, 6400.0],
76
+ "0.42": [2560.0, 6144.0],
77
+ "0.48": [2816.0, 5888.0],
78
+ "0.5": [2816.0, 5632.0],
79
+ "0.52": [2816.0, 5376.0],
80
+ "0.57": [3072.0, 5376.0],
81
+ "0.6": [3072.0, 5120.0],
82
+ "0.68": [3328.0, 4864.0],
83
+ "0.72": [3328.0, 4608.0],
84
+ "0.78": [3584.0, 4608.0],
85
+ "0.82": [3584.0, 4352.0],
86
+ "0.88": [3840.0, 4352.0],
87
+ "0.94": [3840.0, 4096.0],
88
+ "1.0": [4096.0, 4096.0],
89
+ "1.07": [4096.0, 3840.0],
90
+ "1.13": [4352.0, 3840.0],
91
+ "1.21": [4352.0, 3584.0],
92
+ "1.29": [4608.0, 3584.0],
93
+ "1.38": [4608.0, 3328.0],
94
+ "1.46": [4864.0, 3328.0],
95
+ "1.67": [5120.0, 3072.0],
96
+ "1.75": [5376.0, 3072.0],
97
+ "2.0": [5632.0, 2816.0],
98
+ "2.09": [5888.0, 2816.0],
99
+ "2.4": [6144.0, 2560.0],
100
+ "2.5": [6400.0, 2560.0],
101
+ "2.89": [6656.0, 2304.0],
102
+ "3.0": [6912.0, 2304.0],
103
+ "3.11": [7168.0, 2304.0],
104
+ "3.62": [7424.0, 2048.0],
105
+ "3.75": [7680.0, 2048.0],
106
+ "3.88": [7936.0, 2048.0],
107
+ "4.0": [8192.0, 2048.0],
108
+ }
109
+
110
+ EXAMPLE_DOC_STRING = """
111
+ Examples:
112
+ ```py
113
+ >>> import torch
114
+ >>> from diffusers import SanaPipeline
115
+
116
+ >>> pipe = SanaPipeline.from_pretrained(
117
+ ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
118
+ ... )
119
+ >>> pipe.to("cuda")
120
+ >>> pipe.text_encoder.to(torch.bfloat16)
121
+ >>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
122
+
123
+ >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
124
+ >>> image[0].save("output.png")
125
+ ```
126
+ """
127
+
128
+
129
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
130
+ def retrieve_timesteps(
131
+ scheduler,
132
+ num_inference_steps: Optional[int] = None,
133
+ device: Optional[Union[str, torch.device]] = None,
134
+ timesteps: Optional[List[int]] = None,
135
+ sigmas: Optional[List[float]] = None,
136
+ **kwargs,
137
+ ):
138
+ r"""
139
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
140
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
141
+
142
+ Args:
143
+ scheduler (`SchedulerMixin`):
144
+ The scheduler to get timesteps from.
145
+ num_inference_steps (`int`):
146
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
147
+ must be `None`.
148
+ device (`str` or `torch.device`, *optional*):
149
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
150
+ timesteps (`List[int]`, *optional*):
151
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
152
+ `num_inference_steps` and `sigmas` must be `None`.
153
+ sigmas (`List[float]`, *optional*):
154
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
155
+ `num_inference_steps` and `timesteps` must be `None`.
156
+
157
+ Returns:
158
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
159
+ second element is the number of inference steps.
160
+ """
161
+ if timesteps is not None and sigmas is not None:
162
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
163
+ if timesteps is not None:
164
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
165
+ if not accepts_timesteps:
166
+ raise ValueError(
167
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
168
+ f" timestep schedules. Please check whether you are using the correct scheduler."
169
+ )
170
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
171
+ timesteps = scheduler.timesteps
172
+ num_inference_steps = len(timesteps)
173
+ elif sigmas is not None:
174
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
175
+ if not accept_sigmas:
176
+ raise ValueError(
177
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
178
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
179
+ )
180
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
181
+ timesteps = scheduler.timesteps
182
+ num_inference_steps = len(timesteps)
183
+ else:
184
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
185
+ timesteps = scheduler.timesteps
186
+ return timesteps, num_inference_steps
187
+
188
+
189
+ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
190
+ r"""
191
+ Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629).
192
+ """
193
+
194
+ # fmt: off
195
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
196
+ # fmt: on
197
+
198
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
199
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
200
+
201
+ def __init__(
202
+ self,
203
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
204
+ text_encoder: Gemma2PreTrainedModel,
205
+ vae: AutoencoderDC,
206
+ transformer: SanaTransformer2DModel,
207
+ scheduler: DPMSolverMultistepScheduler,
208
+ ):
209
+ super().__init__()
210
+
211
+ self.register_modules(
212
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
213
+ )
214
+
215
+ self.vae_scale_factor = (
216
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
217
+ if hasattr(self, "vae") and self.vae is not None
218
+ else 32
219
+ )
220
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
221
+
222
+ def enable_vae_slicing(self):
223
+ r"""
224
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
225
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
226
+ """
227
+ self.vae.enable_slicing()
228
+
229
+ def disable_vae_slicing(self):
230
+ r"""
231
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
232
+ computing decoding in one step.
233
+ """
234
+ self.vae.disable_slicing()
235
+
236
+ def enable_vae_tiling(self):
237
+ r"""
238
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
239
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
240
+ processing larger images.
241
+ """
242
+ self.vae.enable_tiling()
243
+
244
+ def disable_vae_tiling(self):
245
+ r"""
246
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
247
+ computing decoding in one step.
248
+ """
249
+ self.vae.disable_tiling()
250
+
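+ # Usage note (illustrative sketch, not executed): the four helpers above are thin pass-throughs
+ # to the VAE, so memory-saving decoding can be toggled on the pipeline directly, e.g.:
+ #
+ #     pipe.enable_vae_slicing()  # decode the batch one image at a time
+ #     pipe.enable_vae_tiling()   # decode large images tile by tile
+ #     image = pipe(prompt, height=2048, width=2048).images[0]
+ #     pipe.disable_vae_tiling()
+ #
+ # `pipe` and `prompt` are assumed to be defined by the caller.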
251
+ def _get_gemma_prompt_embeds(
252
+ self,
253
+ prompt: Union[str, List[str]],
254
+ device: torch.device,
255
+ dtype: torch.dtype,
256
+ clean_caption: bool = False,
257
+ max_sequence_length: int = 300,
258
+ complex_human_instruction: Optional[List[str]] = None,
259
+ ):
260
+ r"""
261
+ Encodes the prompt into text encoder hidden states.
262
+
263
+ Args:
264
+ prompt (`str` or `List[str]`, *optional*):
265
+ prompt to be encoded
266
+ device (`torch.device`, *optional*):
267
+ torch device to place the resulting embeddings on
268
+ clean_caption (`bool`, defaults to `False`):
269
+ If `True`, the function will preprocess and clean the provided caption before encoding.
270
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
271
+ complex_human_instruction (`List[str]`, *optional*):
272
+ If `complex_human_instruction` is not empty, these instruction strings are joined and prepended to the
273
+ prompt before encoding.
274
+ """
275
+ prompt = [prompt] if isinstance(prompt, str) else prompt
276
+
277
+ if getattr(self, "tokenizer", None) is not None:
278
+ self.tokenizer.padding_side = "right"
279
+
280
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
281
+
282
+ # prepare complex human instruction
283
+ if not complex_human_instruction:
284
+ max_length_all = max_sequence_length
285
+ else:
286
+ chi_prompt = "\n".join(complex_human_instruction)
287
+ prompt = [chi_prompt + p for p in prompt]
288
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
289
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
290
+
291
+ text_inputs = self.tokenizer(
292
+ prompt,
293
+ padding="max_length",
294
+ max_length=max_length_all,
295
+ truncation=True,
296
+ add_special_tokens=True,
297
+ return_tensors="pt",
298
+ )
299
+ text_input_ids = text_inputs.input_ids
300
+
301
+ prompt_attention_mask = text_inputs.attention_mask
302
+ prompt_attention_mask = prompt_attention_mask.to(device)
303
+
304
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
305
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
306
+
307
+ return prompt_embeds, prompt_attention_mask
308
+
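+ # Usage note (illustrative sketch, not executed): when a non-empty `complex_human_instruction`
+ # list is given, it is joined with newlines and prepended to every prompt, and the tokenizer
+ # max length is extended by the instruction's token count, e.g.:
+ #
+ #     embeds, mask = pipe._get_gemma_prompt_embeds(
+ #         ["a cat"], device="cuda", dtype=torch.bfloat16,
+ #         complex_human_instruction=["Enhance the prompt.", "User Prompt: "],
+ #     )
+ #
+ # The instruction strings here are placeholders, not the recommended template.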
309
+ def encode_prompt(
310
+ self,
311
+ prompt: Union[str, List[str]],
312
+ do_classifier_free_guidance: bool = True,
313
+ negative_prompt: str = "",
314
+ num_images_per_prompt: int = 1,
315
+ device: Optional[torch.device] = None,
316
+ prompt_embeds: Optional[torch.Tensor] = None,
317
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
318
+ prompt_attention_mask: Optional[torch.Tensor] = None,
319
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
320
+ clean_caption: bool = False,
321
+ max_sequence_length: int = 300,
322
+ complex_human_instruction: Optional[List[str]] = None,
323
+ lora_scale: Optional[float] = None,
324
+ ):
325
+ r"""
326
+ Encodes the prompt into text encoder hidden states.
327
+
328
+ Args:
329
+ prompt (`str` or `List[str]`, *optional*):
330
+ prompt to be encoded
331
+ negative_prompt (`str` or `List[str]`, *optional*):
332
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
333
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
334
+ Sana, this should be "".
335
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
336
+ whether to use classifier free guidance or not
337
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
338
+ number of images that should be generated per prompt
339
+ device (`torch.device`, *optional*):
340
+ torch device to place the resulting embeddings on
341
+ prompt_embeds (`torch.Tensor`, *optional*):
342
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
343
+ provided, text embeddings will be generated from `prompt` input argument.
344
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
345
+ Pre-generated negative text embeddings. For Sana, this should be the embeddings of the "" string.
346
+ clean_caption (`bool`, defaults to `False`):
347
+ If `True`, the function will preprocess and clean the provided caption before encoding.
348
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
349
+ complex_human_instruction (`List[str]`, *optional*):
350
+ If `complex_human_instruction` is not empty, these instruction strings are joined and prepended to the
351
+ prompt before encoding.
352
+ """
353
+
354
+ if device is None:
355
+ device = self._execution_device
356
+
357
+ if self.text_encoder is not None:
358
+ dtype = self.text_encoder.dtype
359
+ else:
360
+ dtype = None
361
+
362
+ # set lora scale so that monkey patched LoRA
363
+ # function of text encoder can correctly access it
364
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
365
+ self._lora_scale = lora_scale
366
+
367
+ # dynamically adjust the LoRA scale
368
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
369
+ scale_lora_layers(self.text_encoder, lora_scale)
370
+
371
+ if prompt is not None and isinstance(prompt, str):
372
+ batch_size = 1
373
+ elif prompt is not None and isinstance(prompt, list):
374
+ batch_size = len(prompt)
375
+ else:
376
+ batch_size = prompt_embeds.shape[0]
377
+
378
+ if getattr(self, "tokenizer", None) is not None:
379
+ self.tokenizer.padding_side = "right"
380
+
381
+ # See Section 3.1. of the paper.
382
+ max_length = max_sequence_length
383
+ select_index = [0] + list(range(-max_length + 1, 0))
384
+
385
+ if prompt_embeds is None:
386
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
387
+ prompt=prompt,
388
+ device=device,
389
+ dtype=dtype,
390
+ clean_caption=clean_caption,
391
+ max_sequence_length=max_sequence_length,
392
+ complex_human_instruction=complex_human_instruction,
393
+ )
394
+
395
+ prompt_embeds = prompt_embeds[:, select_index]
396
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
397
+
398
+ bs_embed, seq_len, _ = prompt_embeds.shape
399
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
400
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
401
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
402
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
403
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
404
+
405
+ # get unconditional embeddings for classifier free guidance
406
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
407
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
408
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
409
+ prompt=negative_prompt,
410
+ device=device,
411
+ dtype=dtype,
412
+ clean_caption=clean_caption,
413
+ max_sequence_length=max_sequence_length,
414
+ complex_human_instruction=False,
415
+ )
416
+
417
+ if do_classifier_free_guidance:
418
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
419
+ seq_len = negative_prompt_embeds.shape[1]
420
+
421
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
422
+
423
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
424
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
425
+
426
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
427
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
428
+ else:
429
+ negative_prompt_embeds = None
430
+ negative_prompt_attention_mask = None
431
+
432
+ if self.text_encoder is not None:
433
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
434
+ # Retrieve the original scale by scaling back the LoRA layers
435
+ unscale_lora_layers(self.text_encoder, lora_scale)
436
+
437
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
438
+
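+ # Usage note (illustrative sketch, not executed): embeddings can be computed once with
+ # `encode_prompt` and reused across calls, skipping the Gemma forward pass on subsequent
+ # generations:
+ #
+ #     embeds, mask, neg_embeds, neg_mask = pipe.encode_prompt(
+ #         "a cat", do_classifier_free_guidance=True, device="cuda"
+ #     )
+ #     image = pipe(
+ #         prompt_embeds=embeds, prompt_attention_mask=mask,
+ #         negative_prompt_embeds=neg_embeds,
+ #         negative_prompt_attention_mask=neg_mask,
+ #     ).images[0]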
439
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
440
+ def prepare_extra_step_kwargs(self, generator, eta):
441
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
442
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
443
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
444
+ # and should be between [0, 1]
445
+
446
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
447
+ extra_step_kwargs = {}
448
+ if accepts_eta:
449
+ extra_step_kwargs["eta"] = eta
450
+
451
+ # check if the scheduler accepts generator
452
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
453
+ if accepts_generator:
454
+ extra_step_kwargs["generator"] = generator
455
+ return extra_step_kwargs
456
+
457
+ def check_inputs(
458
+ self,
459
+ prompt,
460
+ height,
461
+ width,
462
+ callback_on_step_end_tensor_inputs=None,
463
+ negative_prompt=None,
464
+ prompt_embeds=None,
465
+ negative_prompt_embeds=None,
466
+ prompt_attention_mask=None,
467
+ negative_prompt_attention_mask=None,
468
+ ):
469
+ if height % 32 != 0 or width % 32 != 0:
470
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
471
+
472
+ if callback_on_step_end_tensor_inputs is not None and not all(
473
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
474
+ ):
475
+ raise ValueError(
476
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
477
+ )
478
+
479
+ if prompt is not None and prompt_embeds is not None:
480
+ raise ValueError(
481
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
482
+ " only forward one of the two."
483
+ )
484
+ elif prompt is None and prompt_embeds is None:
485
+ raise ValueError(
486
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
487
+ )
488
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
489
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
490
+
491
+ if prompt is not None and negative_prompt_embeds is not None:
492
+ raise ValueError(
493
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
494
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
495
+ )
496
+
497
+ if negative_prompt is not None and negative_prompt_embeds is not None:
498
+ raise ValueError(
499
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
500
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
501
+ )
502
+
503
+ if prompt_embeds is not None and prompt_attention_mask is None:
504
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
505
+
506
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
507
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
508
+
509
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
510
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
511
+ raise ValueError(
512
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
513
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
514
+ f" {negative_prompt_embeds.shape}."
515
+ )
516
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
517
+ raise ValueError(
518
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
519
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
520
+ f" {negative_prompt_attention_mask.shape}."
521
+ )
522
+
523
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
524
+ def _text_preprocessing(self, text, clean_caption=False):
525
+ if clean_caption and not is_bs4_available():
526
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
527
+ logger.warning("Setting `clean_caption` to False...")
528
+ clean_caption = False
529
+
530
+ if clean_caption and not is_ftfy_available():
531
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
532
+ logger.warning("Setting `clean_caption` to False...")
533
+ clean_caption = False
534
+
535
+ if not isinstance(text, (tuple, list)):
536
+ text = [text]
537
+
538
+ def process(text: str):
539
+ if clean_caption:
540
+ text = self._clean_caption(text)
541
+ text = self._clean_caption(text)
542
+ else:
543
+ text = text.lower().strip()
544
+ return text
545
+
546
+ return [process(t) for t in text]
547
+
548
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
549
+ def _clean_caption(self, caption):
550
+ caption = str(caption)
551
+ caption = ul.unquote_plus(caption)
552
+ caption = caption.strip().lower()
553
+ caption = re.sub("<person>", "person", caption)
554
+ # urls:
555
+ caption = re.sub(
556
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
557
+ "",
558
+ caption,
559
+ ) # regex for urls
560
+ caption = re.sub(
561
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
562
+ "",
563
+ caption,
564
+ ) # regex for urls
565
+ # html:
566
+ caption = BeautifulSoup(caption, features="html.parser").text
567
+
568
+ # @<nickname>
569
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
570
+
571
+ # 31C0—31EF CJK Strokes
572
+ # 31F0—31FF Katakana Phonetic Extensions
573
+ # 3200—32FF Enclosed CJK Letters and Months
574
+ # 3300—33FF CJK Compatibility
575
+ # 3400—4DBF CJK Unified Ideographs Extension A
576
+ # 4DC0—4DFF Yijing Hexagram Symbols
577
+ # 4E00—9FFF CJK Unified Ideographs
578
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
579
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
580
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
581
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
582
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
583
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
584
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
585
+ #######################################################
586
+
587
+ # все виды тире / all types of dash --> "-"
588
+ caption = re.sub(
589
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
590
+ "-",
591
+ caption,
592
+ )
593
+
594
+ # кавычки к одному стандарту
595
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
596
+ caption = re.sub(r"[‘’]", "'", caption)
597
+
598
+ # &quot;
599
+ caption = re.sub(r"&quot;?", "", caption)
600
+ # &amp
601
+ caption = re.sub(r"&amp", "", caption)
602
+
603
+ # ip addresses:
604
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
605
+
606
+ # article ids:
607
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
608
+
609
+ # \n
610
+ caption = re.sub(r"\\n", " ", caption)
611
+
612
+ # "#123"
613
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
614
+ # "#12345.."
615
+ caption = re.sub(r"#\d{5,}\b", "", caption)
616
+ # "123456.."
617
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
618
+ # filenames:
619
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
620
+
621
+ #
622
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
623
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
624
+
625
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
626
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
627
+
628
+ # this-is-my-cute-cat / this_is_my_cute_cat
629
+ regex2 = re.compile(r"(?:\-|\_)")
630
+ if len(re.findall(regex2, caption)) > 3:
631
+ caption = re.sub(regex2, " ", caption)
632
+
633
+ caption = ftfy.fix_text(caption)
634
+ caption = html.unescape(html.unescape(caption))
635
+
636
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
637
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
638
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
639
+
640
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
641
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
642
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
643
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
644
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
645
+
646
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
647
+
648
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
649
+
650
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
651
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
652
+ caption = re.sub(r"\s+", " ", caption)
653
+
654
+ caption = caption.strip()
655
+
656
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
657
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
658
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
659
+ caption = re.sub(r"^\.\S+$", "", caption)
660
+
661
+ return caption.strip()
662
+
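+ # Usage note (illustrative sketch, not executed): with `clean_caption=True` (requires
+ # `beautifulsoup4` and `ftfy`), prompts are lowercased and scrubbed of URLs, HTML tags,
+ # @handles and stray punctuation before tokenization:
+ #
+ #     cleaned = pipe._text_preprocessing(raw_prompts, clean_caption=True)
+ #
+ # Without the optional dependencies the prompts are only lowercased and stripped.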
663
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
664
+ if latents is not None:
665
+ return latents.to(device=device, dtype=dtype)
666
+
667
+ shape = (
668
+ batch_size,
669
+ num_channels_latents,
670
+ int(height) // self.vae_scale_factor,
671
+ int(width) // self.vae_scale_factor,
672
+ )
673
+ if isinstance(generator, list) and len(generator) != batch_size:
674
+ raise ValueError(
675
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
676
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
677
+ )
678
+
679
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
680
+ return latents
681
+
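+ # Shape note (illustrative, assuming the default AutoencoderDC with a VAE scale factor of 32):
+ # a 1024x1024 request yields latents of shape
+ # (batch_size * num_images_per_prompt, in_channels, 1024 // 32, 1024 // 32),
+ # i.e. a 32x32 latent grid.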
682
+ @property
683
+ def guidance_scale(self):
684
+ return self._guidance_scale
685
+
686
+ @property
687
+ def attention_kwargs(self):
688
+ return self._attention_kwargs
689
+
690
+ @property
691
+ def do_classifier_free_guidance(self):
692
+ return self._guidance_scale > 1.0
693
+
694
+ @property
695
+ def num_timesteps(self):
696
+ return self._num_timesteps
697
+
698
+ @property
699
+ def interrupt(self):
700
+ return self._interrupt
701
+
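+ # Usage note (illustrative sketch, not executed): `_interrupt` can be set from a
+ # `callback_on_step_end` to skip the remaining denoising steps, e.g.:
+ #
+ #     def stop_early(pipeline, step, timestep, callback_kwargs):
+ #         if step >= 10:
+ #             pipeline._interrupt = True
+ #         return callback_kwargs
+ #
+ #     image = pipe(prompt, callback_on_step_end=stop_early).images[0]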
702
+ @torch.no_grad()
703
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
704
+ def __call__(
705
+ self,
706
+ prompt: Union[str, List[str]] = None,
707
+ negative_prompt: str = "",
708
+ num_inference_steps: int = 20,
709
+ timesteps: List[int] = None,
710
+ sigmas: List[float] = None,
711
+ guidance_scale: float = 4.5,
712
+ num_images_per_prompt: Optional[int] = 1,
713
+ height: int = 1024,
714
+ width: int = 1024,
715
+ eta: float = 0.0,
716
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
717
+ latents: Optional[torch.Tensor] = None,
718
+ prompt_embeds: Optional[torch.Tensor] = None,
719
+ prompt_attention_mask: Optional[torch.Tensor] = None,
720
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
721
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
722
+ output_type: Optional[str] = "pil",
723
+ return_dict: bool = True,
724
+ clean_caption: bool = False,
725
+ use_resolution_binning: bool = True,
726
+ attention_kwargs: Optional[Dict[str, Any]] = None,
727
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
728
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
729
+ max_sequence_length: int = 300,
730
+ complex_human_instruction: List[str] = [
731
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
732
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
733
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
734
+ "Here are examples of how to transform or refine prompts:",
735
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
736
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
737
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
738
+ "User Prompt: ",
739
+ ],
740
+ ) -> Union[SiDPipelineOutput, Tuple]:
741
+ """
742
+ Function invoked when calling the pipeline for generation.
743
+
744
+ Args:
745
+ prompt (`str` or `List[str]`, *optional*):
746
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
747
+ instead.
748
+ negative_prompt (`str` or `List[str]`, *optional*):
749
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
750
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
751
+ less than `1`).
752
+ num_inference_steps (`int`, *optional*, defaults to 20):
753
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
754
+ expense of slower inference.
755
+ timesteps (`List[int]`, *optional*):
756
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
757
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
758
+ passed will be used. Must be in descending order.
759
+ sigmas (`List[float]`, *optional*):
760
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
761
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
762
+ will be used.
763
+ guidance_scale (`float`, *optional*, defaults to 4.5):
764
+ Guidance scale as defined in [Classifier-Free Diffusion
765
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
766
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
767
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
768
+ the text `prompt`, usually at the expense of lower image quality.
769
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
770
+ The number of images to generate per prompt.
771
+ height (`int`, *optional*, defaults to 1024):
772
+ The height in pixels of the generated image.
773
+ width (`int`, *optional*, defaults to 1024):
774
+ The width in pixels of the generated image.
775
+ eta (`float`, *optional*, defaults to 0.0):
776
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
777
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
778
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
779
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
780
+ to make generation deterministic.
781
+ latents (`torch.Tensor`, *optional*):
782
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
783
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
784
+ tensor will be generated by sampling using the supplied random `generator`.
785
+ prompt_embeds (`torch.Tensor`, *optional*):
786
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
787
+ provided, text embeddings will be generated from `prompt` input argument.
788
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
789
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
790
+ Pre-generated negative text embeddings. For Sana, this negative prompt should be "". If not
791
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
792
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
793
+ Pre-generated attention mask for negative text embeddings.
794
+ output_type (`str`, *optional*, defaults to `"pil"`):
795
+ The output format of the generated image. Choose between
796
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
797
+ return_dict (`bool`, *optional*, defaults to `True`):
798
+ Whether or not to return a [`SiDPipelineOutput`] instead of a plain tuple.
799
+ attention_kwargs:
800
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
801
+ `self.processor` in
802
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
803
+ clean_caption (`bool`, *optional*, defaults to `False`):
804
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
805
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
806
+ prompt.
807
+ use_resolution_binning (`bool` defaults to `True`):
808
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
809
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
810
+ the requested resolution. Useful for generating non-square images.
811
+ callback_on_step_end (`Callable`, *optional*):
812
+ A function that is called at the end of each denoising step during inference. The function is called
813
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
814
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
815
+ `callback_on_step_end_tensor_inputs`.
816
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
817
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
818
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
819
+ `._callback_tensor_inputs` attribute of your pipeline class.
820
+ max_sequence_length (`int` defaults to `300`):
821
+ Maximum sequence length to use with the `prompt`.
822
+ complex_human_instruction (`List[str]`, *optional*):
823
+ Instructions for complex human attention:
824
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
825
+
826
+ Examples:
827
+
828
+ Returns:
829
+ [`SiDPipelineOutput`] or `tuple`:
830
+ If `return_dict` is `True`, [`SiDPipelineOutput`] is returned,
831
+ otherwise a `tuple` is returned where the first element is a list with the generated images.
832
+ """
833
+
834
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
835
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
836
+
837
+ # 1. Check inputs. Raise error if not correct
838
+ if use_resolution_binning:
839
+ if self.transformer.config.sample_size == 128:
840
+ aspect_ratio_bin = ASPECT_RATIO_4096_BIN
841
+ elif self.transformer.config.sample_size == 64:
842
+ aspect_ratio_bin = ASPECT_RATIO_2048_BIN
843
+ elif self.transformer.config.sample_size == 32:
844
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
845
+ elif self.transformer.config.sample_size == 16:
846
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
847
+ else:
848
+ raise ValueError("Invalid sample size")
849
+ orig_height, orig_width = height, width
850
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
851
+
852
+ self.check_inputs(
853
+ prompt,
854
+ height,
855
+ width,
856
+ callback_on_step_end_tensor_inputs,
857
+ negative_prompt,
858
+ prompt_embeds,
859
+ negative_prompt_embeds,
860
+ prompt_attention_mask,
861
+ negative_prompt_attention_mask,
862
+ )
863
+
864
+ self._guidance_scale = guidance_scale
865
+ self._attention_kwargs = attention_kwargs
866
+ self._interrupt = False
867
+
868
+ # 2. Default height and width to transformer
869
+ if prompt is not None and isinstance(prompt, str):
870
+ batch_size = 1
871
+ elif prompt is not None and isinstance(prompt, list):
872
+ batch_size = len(prompt)
873
+ else:
874
+ batch_size = prompt_embeds.shape[0]
875
+
876
+ device = self._execution_device
877
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
878
+
879
+ # 3. Encode input prompt
880
+ (
881
+ prompt_embeds,
882
+ prompt_attention_mask,
883
+ negative_prompt_embeds,
884
+ negative_prompt_attention_mask,
885
+ ) = self.encode_prompt(
886
+ prompt,
887
+ self.do_classifier_free_guidance,
888
+ negative_prompt=negative_prompt,
889
+ num_images_per_prompt=num_images_per_prompt,
890
+ device=device,
891
+ prompt_embeds=prompt_embeds,
892
+ negative_prompt_embeds=negative_prompt_embeds,
893
+ prompt_attention_mask=prompt_attention_mask,
894
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
895
+ clean_caption=clean_caption,
896
+ max_sequence_length=max_sequence_length,
897
+ complex_human_instruction=complex_human_instruction,
898
+ lora_scale=lora_scale,
899
+ )
900
+ if self.do_classifier_free_guidance:
901
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
902
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
903
+
904
+ # 4. Prepare timesteps
905
+ timesteps, num_inference_steps = retrieve_timesteps(
906
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
907
+ )
908
+
909
+ # 5. Prepare latents.
910
+ latent_channels = self.transformer.config.in_channels
911
+ latents = self.prepare_latents(
912
+ batch_size * num_images_per_prompt,
913
+ latent_channels,
914
+ height,
915
+ width,
916
+ torch.float32,
917
+ device,
918
+ generator,
919
+ latents,
920
+ )
921
+
922
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
923
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
924
+
925
+ # 7. Denoising loop
926
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
927
+ self._num_timesteps = len(timesteps)
928
+
929
+ transformer_dtype = self.transformer.dtype
930
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
931
+ for i, t in enumerate(timesteps):
932
+ if self.interrupt:
933
+ continue
934
+
935
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
936
+
937
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
938
+ timestep = t.expand(latent_model_input.shape[0])
939
+ timestep = timestep * self.transformer.config.timestep_scale
940
+
941
+ # predict noise model_output
942
+ noise_pred = self.transformer(
943
+ latent_model_input.to(dtype=transformer_dtype),
944
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
945
+ encoder_attention_mask=prompt_attention_mask,
946
+ timestep=timestep,
947
+ return_dict=False,
948
+ attention_kwargs=self.attention_kwargs,
949
+ )[0]
950
+ noise_pred = noise_pred.float()
951
+
952
+ # perform guidance
953
+ if self.do_classifier_free_guidance:
954
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
955
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
956
+
957
+ # learned sigma
958
+ if self.transformer.config.out_channels // 2 == latent_channels:
959
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
960
+
961
+ # compute previous image: x_t -> x_t-1
962
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
963
+
964
+ if callback_on_step_end is not None:
965
+ callback_kwargs = {}
966
+ for k in callback_on_step_end_tensor_inputs:
967
+ callback_kwargs[k] = locals()[k]
968
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
969
+
970
+ latents = callback_outputs.pop("latents", latents)
971
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
972
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
973
+
974
+ # call the callback, if provided
975
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
976
+ progress_bar.update()
977
+
978
+ if XLA_AVAILABLE:
979
+ xm.mark_step()
980
+
981
+ if output_type == "latent":
982
+ image = latents
983
+ else:
984
+ latents = latents.to(self.vae.dtype)
985
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
986
+ oom_error = (
987
+ torch.OutOfMemoryError
988
+ if is_torch_version(">=", "2.5.0")
989
+ else torch_accelerator_module.OutOfMemoryError
990
+ )
991
+ try:
992
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
993
+ except oom_error as e:
994
+ warnings.warn(
995
+ f"{e}. \n"
996
+ f"Try to use VAE tiling for large images. For example: \n"
997
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
998
+ )
999
+ if use_resolution_binning:
1000
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
1001
+
1002
+ if not output_type == "latent":
1003
+ image = self.image_processor.postprocess(image, output_type=output_type)
1004
+
1005
+ # Offload all models
1006
+ self.maybe_free_model_hooks()
1007
+
1008
+ if not return_dict:
1009
+ return (image,)
1010
+
1011
+ return SiDPipelineOutput(images=image)
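+ # Usage note (illustrative sketch, not executed): latents can be requested and decoded manually,
+ # e.g. to decode afterwards with VAE tiling enabled:
+ #
+ #     latents = pipe(prompt, output_type="latent").images
+ #     pipe.vae.enable_tiling()
+ #     image = pipe.vae.decode(
+ #         latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor, return_dict=False
+ #     )[0]
+ #     image = pipe.image_processor.postprocess(image, output_type="pil")[0]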
sid/pipeline_sid_sd3.py CHANGED
@@ -54,23 +54,6 @@ else:
54
 
55
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
 
57
- EXAMPLE_DOC_STRING = """
58
- Examples:
59
- ```py
60
- >>> import torch
61
- >>> from diffusers import StableDiffusion3Pipeline
62
-
63
- >>> pipe = StableDiffusion3Pipeline.from_pretrained(
64
- ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
65
- ... )
66
- >>> pipe.to("cuda")
67
- >>> prompt = "A cat holding a sign that says hello world"
68
- >>> image = pipe(prompt).images[0]
69
- >>> image.save("sd3.png")
70
- ```
71
- """
72
-
73
-
74
  # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
75
  def calculate_shift(
76
  image_seq_len,
@@ -680,7 +663,6 @@ class SiDSD3Pipeline(
680
  super().enable_sequential_cpu_offload(*args, **kwargs)
681
 
682
  @torch.no_grad()
683
- @replace_example_docstring(EXAMPLE_DOC_STRING)
684
  def __call__(
685
  self,
686
  prompt: Union[str, List[str]] = None,
 