# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List, Optional, Tuple, Union import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch.nn.utils.rnn import pad_sequence from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin from diffusers.models.attention import FeedForward, JointTransformerBlock, _chunked_feed_forward from diffusers.models.attention_processor import ( Attention, AttentionProcessor, FusedJointAttnProcessor2_0, JointAttnProcessor2_0, ) from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from diffusers.utils.torch_utils import maybe_allow_in_graph from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from diffusers.models.modeling_outputs import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name class CustomJointAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, *args, **kwargs, ) -> torch.FloatTensor: residual = hidden_states batch_size = hidden_states.shape[0] # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # `context` projections. if encoder_hidden_states is not None: ctx_len = encoder_hidden_states.shape[1] encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) if attention_mask is not None: # import pdb; pdb.set_trace() encoder_attention_mask = torch.ones( batch_size, ctx_len, dtype=torch.bool, device=hidden_states.device) attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=1) # import pdb; pdb.set_trace() if attention_mask is not None: attention_mask = attention_mask[:, None] * attention_mask[..., None] # bsz, seqlen, seqlen indices = range(attention_mask.shape[1]) attention_mask[:, indices, indices] = True attention_mask = attention_mask[:, None] hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: # import pdb; pdb.set_trace() # Split the attention outputs. hidden_states, encoder_hidden_states = ( hidden_states[:, : residual.shape[1]], hidden_states[:, residual.shape[1] :], ) if not attn.context_pre_only: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: return hidden_states, encoder_hidden_states else: return hidden_states class CustomJointTransformerBlock(JointTransformerBlock): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.attn.set_processor(CustomJointAttnProcessor2_0()) if self.attn2 is not None: self.attn2.set_processor(CustomJointAttnProcessor2_0()) def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ): joint_attention_kwargs = joint_attention_kwargs or {} if self.use_dual_attention: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( hidden_states, emb=temb ) else: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) if self.context_pre_only: norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) else: norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) # Attention. attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, attention_mask=attention_mask, encoder_hidden_states=norm_encoder_hidden_states, **joint_attention_kwargs, ) # Process attention outputs for the `hidden_states`. attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output if self.use_dual_attention: attn_output2 = self.attn2(hidden_states=norm_hidden_states2, attention_mask=attention_mask, **joint_attention_kwargs) attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 hidden_states = hidden_states + attn_output2 norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output # Process attention outputs for the `encoder_hidden_states`. if self.context_pre_only: encoder_hidden_states = None else: context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory context_ff_output = _chunked_feed_forward( self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size ) else: context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output return encoder_hidden_states, hidden_states @maybe_allow_in_graph class SD3SingleTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, ): super().__init__() self.norm1 = AdaLayerNormZero(dim) self.attn = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, bias=True, processor=JointAttnProcessor2_0(), eps=1e-6, ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor): # 1. Attention norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None) attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output # 2. Feed Forward norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output return hidden_states class SD3Transformer2DModel( ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin ): """ The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). Parameters: sample_size (`int`, defaults to `128`): The width/height of the latents. This is fixed during training since it is used to learn a number of position embeddings. patch_size (`int`, defaults to `2`): Patch size to turn the input data into small patches. in_channels (`int`, defaults to `16`): The number of latent channels in the input. num_layers (`int`, defaults to `18`): The number of layers of transformer blocks to use. attention_head_dim (`int`, defaults to `64`): The number of channels in each head. num_attention_heads (`int`, defaults to `18`): The number of heads to use for multi-head attention. joint_attention_dim (`int`, defaults to `4096`): The embedding dimension to use for joint text-image attention. caption_projection_dim (`int`, defaults to `1152`): The embedding dimension of caption embeddings. pooled_projection_dim (`int`, defaults to `2048`): The embedding dimension of pooled text projections. out_channels (`int`, defaults to `16`): The number of latent channels in the output. pos_embed_max_size (`int`, defaults to `96`): The maximum latent height/width of positional embeddings. dual_attention_layers (`Tuple[int, ...]`, defaults to `()`): The number of dual-stream transformer blocks to use. qk_norm (`str`, *optional*, defaults to `None`): The normalization to use for query and key in the attention layer. If `None`, no normalization is used. """ _supports_gradient_checkpointing = True _no_split_modules = ["JointTransformerBlock", "CustomJointTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] @register_to_config def __init__( self, sample_size: int = 128, patch_size: int = 2, in_channels: int = 16, num_layers: int = 18, attention_head_dim: int = 64, num_attention_heads: int = 18, joint_attention_dim: int = 4096, caption_projection_dim: int = 1152, pooled_projection_dim: int = 2048, out_channels: int = 16, pos_embed_max_size: int = 96, dual_attention_layers: Tuple[ int, ... ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 qk_norm: Optional[str] = None, ): super().__init__() self.out_channels = out_channels if out_channels is not None else in_channels self.inner_dim = num_attention_heads * attention_head_dim self.pos_embed = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=self.inner_dim, pos_embed_max_size=pos_embed_max_size, # hard-code for now. ) self.time_text_embed = CombinedTimestepTextProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim ) self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim) self.transformer_blocks = nn.ModuleList( [ CustomJointTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, context_pre_only=i == num_layers - 1, qk_norm=qk_norm, use_dual_attention=True if i in dual_attention_layers else False, ) for i in range(num_layers) ] ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). Parameters: chunk_size (`int`, *optional*): The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually over each tensor of dim=`dim`. dim (`int`, *optional*, defaults to `0`): The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) or dim=1 (sequence length). """ if dim not in [0, 1]: raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") # By default chunk size is 1 chunk_size = chunk_size or 1 def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, None, 0) @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused. This API is 🧪 experimental. """ self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors for module in self.modules(): if isinstance(module, Attention): module.fuse_projections(fuse=True) self.set_attn_processor(FusedJointAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. This API is 🧪 experimental. """ if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, cond_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, block_controlnet_hidden_states: List = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, skip_layers: Optional[List[int]] = None, ) -> Union[torch.Tensor, Transformer2DModelOutput]: """ The [`SD3Transformer2DModel`] forward method. Args: hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. timestep (`torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. skip_layers (`list` of `int`, *optional*): A list of layer indices to skip during the forward pass. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) latent_sizes = [hs.shape[-2:] for hs in hidden_states] bsz = len(hidden_states) hidden_states_list = [] for idx in range(bsz): hidden_states_per_sample = self.pos_embed(hidden_states[idx][None])[0] if cond_hidden_states is not None: for ref in cond_hidden_states[idx]: hidden_states_per_sample = torch.cat( [hidden_states_per_sample, self.pos_embed(ref[None])[0]]) hidden_states_list.append(hidden_states_per_sample) max_len = max([len(hs) for hs in hidden_states_list]) attention_mask = torch.zeros(bsz, max_len, dtype=torch.bool, device=self.device) for i, hs in enumerate(hidden_states_list): attention_mask[i, :len(hs)] = True # right padding # import pdb; pdb.set_trace() hidden_states = pad_sequence(hidden_states_list, batch_first=True, padding_value=0.0, padding_side='right') temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep) joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb) for index_block, block in enumerate(self.transformer_blocks): # Skip specified layers is_skip = True if skip_layers is not None and index_block in skip_layers else False if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, temb, attention_mask, joint_attention_kwargs, ) elif not is_skip: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, attention_mask=attention_mask, joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual if block_controlnet_hidden_states is not None and block.context_pre_only is False: interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states) hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)] hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) patch_size = self.config.patch_size latent_sizes = [(latent_size[0] // patch_size, latent_size[1] // patch_size) for latent_size in latent_sizes] # import pdb; pdb.set_trace() # unpatchify output = [rearrange(hs[:math.prod(latent_size)], '(h w) (p q c) -> c (h p) (w q)', h=latent_size[0], w=latent_size[1], p=patch_size, q=patch_size) for hs, latent_size in zip(hidden_states, latent_sizes)] try: output = torch.stack(output) # can be staked if all have the save shape except: # cannot be stacked pass if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output)