import logging
import math
from typing import Any, Dict, List, Optional

import torch
import torch as th
import torch.nn as nn

from modules.Utilities import util
from modules.AutoEncoders import ResBlock
from modules.NeuralNetwork import transformer
from modules.cond import cast
from modules.sample import sampling, sampling_util

UNET_MAP_ATTENTIONS = {
    "proj_in.weight",
    "proj_in.bias",
    "proj_out.weight",
    "proj_out.bias",
    "norm.weight",
    "norm.bias",
}

TRANSFORMER_BLOCKS = {
    "norm1.weight",
    "norm1.bias",
    "norm2.weight",
    "norm2.bias",
    "norm3.weight",
    "norm3.bias",
    "attn1.to_q.weight",
    "attn1.to_k.weight",
    "attn1.to_v.weight",
    "attn1.to_out.0.weight",
    "attn1.to_out.0.bias",
    "attn2.to_q.weight",
    "attn2.to_k.weight",
    "attn2.to_v.weight",
    "attn2.to_out.0.weight",
    "attn2.to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
}

UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight",
    "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight",
    "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight",
    "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight",
    "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight",
    "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight",
    "out_layers.0.bias": "norm2.bias",
}

UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
}
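
# Note on mapping directions: UNET_MAP_RESNET maps LDM-style resnet key suffixes
# (e.g. "in_layers.2.weight") to their diffusers equivalents (e.g. "conv1.weight"),
# while UNET_MAP_ATTENTIONS and TRANSFORMER_BLOCKS are plain sets of suffixes that are
# identical in both naming schemes. UNET_MAP_BASIC is a set of
# (ldm_key, diffusers_key) pairs for the non-block parameters. unet_to_diffusers()
# below combines all of these into a single dict keyed by diffusers parameter names,
# with the corresponding LDM parameter names as values.
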
# taken from https://github.com/TencentARC/T2I-Adapter
def unet_to_diffusers(unet_config: dict) -> dict:
    """#### Convert a UNet configuration to a diffusers configuration.

    #### Args:
        - `unet_config` (dict): The UNet configuration.

    #### Returns:
        - `dict`: The diffusers configuration.
    """
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)

    transformers_mid = unet_config.get("transformer_depth_middle", None)

    diffusers_unet_map = {}
    for x in range(num_blocks):
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map[
                    "down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])
                ] = "input_blocks.{}.0.{}".format(n, b)
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map[
                        "down_blocks.{}.attentions.{}.{}".format(x, i, b)
                    ] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map[
                            "down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(
                                x, i, t, b
                            )
                        ] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = (
                "input_blocks.{}.0.op.{}".format(n, k)
            )

    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = (
            "middle_block.1.{}".format(b)
        )
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map[
                "mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)
            ] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map[
                "mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])
            ] = "middle_block.{}.{}".format(n, b)

    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        length = num_res_blocks[x] + 1
        for i in range(length):
            c = 0
            for b in UNET_MAP_RESNET:
                diffusers_unet_map[
                    "up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])
                ] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map[
                        "up_blocks.{}.attentions.{}.{}".format(x, i, b)
                    ] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map[
                            "up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(
                                x, i, t, b
                            )
                        ] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(
                            n, t, b
                        )
            if i == length - 1:
                for k in ["weight", "bias"]:
                    diffusers_unet_map[
                        "up_blocks.{}.upsamplers.0.conv.{}".format(x, k)
                    ] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1

    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]

    return diffusers_unet_map
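
# Minimal usage sketch for unet_to_diffusers(). The config below is a made-up
# two-level example (not any real model's configuration); it only illustrates the
# direction of the returned mapping: diffusers parameter name -> LDM parameter name.
def _example_unet_to_diffusers_mapping() -> None:
    example_config = {
        "num_res_blocks": [2, 2],
        "channel_mult": [1, 2],
        "transformer_depth": [1, 1, 0, 0],
        "transformer_depth_output": [0, 0, 0, 1, 1, 1],
        "transformer_depth_middle": 1,
    }
    mapping = unet_to_diffusers(example_config)
    # First down-block resnet: diffusers "conv1" corresponds to LDM "in_layers.2".
    assert (
        mapping["down_blocks.0.resnets.0.conv1.weight"]
        == "input_blocks.1.0.in_layers.2.weight"
    )
    # Time embedding pairs come straight from UNET_MAP_BASIC.
    assert mapping["time_embedding.linear_1.weight"] == "time_embed.0.weight"
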
def apply_control1(h: th.Tensor, control: Any, name: str) -> th.Tensor:
    """#### Apply control to a tensor.

    In this trimmed-down module the function is a passthrough: no control residual
    is added and `h` is returned unchanged. The signature is kept so the forward
    pass below can call it uniformly for the input, middle and output blocks.

    #### Args:
        - `h` (torch.Tensor): The input tensor.
        - `control` (Any): The control to apply (unused here).
        - `name` (str): The control stage name ("input", "middle" or "output").

    #### Returns:
        - `torch.Tensor`: The (unchanged) tensor.
    """
    return h
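
# Illustrative sketch only: this is NOT what apply_control1 above does in this repo.
# It shows the usual ControlNet-style pattern, where `control` is assumed to be a
# dict mapping "input"/"middle"/"output" to lists of residual tensors that are
# popped and summed onto the hidden state. The names and dict layout are assumptions
# made for illustration.
def _apply_control_sketch(
    h: torch.Tensor, control: Optional[Dict[str, List[torch.Tensor]]], name: str
) -> torch.Tensor:
    if control is None or name not in control:
        return h
    ctrl_outputs = control[name]
    if len(ctrl_outputs) > 0:
        ctrl = ctrl_outputs.pop()
        if ctrl is not None:
            # Residuals are simply added to the hidden state.
            h = h + ctrl
    return h
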
oai_ops = cast.disable_weight_init


class UNetModel1(nn.Module):
    """#### UNet Model class."""

    def __init__(
        self,
        image_size: int,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        num_res_blocks: list,
        dropout: float = 0,
        channel_mult: tuple = (1, 2, 4, 8),
        conv_resample: bool = True,
        dims: int = 2,
        num_classes: int = None,
        use_checkpoint: bool = False,
        dtype: th.dtype = th.float32,
        num_heads: int = -1,
        num_head_channels: int = -1,
        num_heads_upsample: int = -1,
        use_scale_shift_norm: bool = False,
        resblock_updown: bool = False,
        use_new_attention_order: bool = False,
        use_spatial_transformer: bool = False,  # custom transformer support
        transformer_depth: int = 1,  # custom transformer support
        context_dim: int = None,  # custom transformer support
        n_embed: int = None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy: bool = True,
        disable_self_attentions: list = None,
        num_attention_blocks: list = None,
        disable_middle_self_attn: bool = False,
        use_linear_in_transformer: bool = False,
        adm_in_channels: int = None,
        transformer_depth_middle: int = None,
        transformer_depth_output: list = None,
        use_temporal_resblock: bool = False,
        use_temporal_attention: bool = False,
        time_context_dim: int = None,
        extra_ff_mix_layer: bool = False,
        use_spatial_context: bool = False,
        merge_strategy: any = None,
        merge_factor: float = 0.0,
        video_kernel_size: int = None,
        disable_temporal_crossattention: bool = False,
        max_ddpm_temb_period: int = 10000,
        device: th.device = None,
        operations: any = oai_ops,
    ):
        """#### Initialize the UNetModel1 class.

        #### Args:
            - `image_size` (int): The size of the input image.
            - `in_channels` (int): The number of input channels.
            - `model_channels` (int): The number of model channels.
            - `out_channels` (int): The number of output channels.
            - `num_res_blocks` (list): The number of residual blocks.
            - `dropout` (float, optional): The dropout rate. Defaults to 0.
            - `channel_mult` (tuple, optional): The channel multiplier. Defaults to (1, 2, 4, 8).
            - `conv_resample` (bool, optional): Whether to use convolutional resampling. Defaults to True.
            - `dims` (int, optional): The number of dimensions. Defaults to 2.
            - `num_classes` (int, optional): The number of classes. Defaults to None.
            - `use_checkpoint` (bool, optional): Whether to use checkpointing. Defaults to False.
            - `dtype` (torch.dtype, optional): The data type. Defaults to torch.float32.
            - `num_heads` (int, optional): The number of heads. Defaults to -1.
            - `num_head_channels` (int, optional): The number of head channels. Defaults to -1.
            - `num_heads_upsample` (int, optional): The number of heads for upsampling. Defaults to -1.
            - `use_scale_shift_norm` (bool, optional): Whether to use scale-shift normalization. Defaults to False.
            - `resblock_updown` (bool, optional): Whether to use residual blocks for up/down sampling. Defaults to False.
            - `use_new_attention_order` (bool, optional): Whether to use a new attention order. Defaults to False.
            - `use_spatial_transformer` (bool, optional): Whether to use a spatial transformer. Defaults to False.
            - `transformer_depth` (int, optional): The depth of the transformer. Defaults to 1.
            - `context_dim` (int, optional): The context dimension. Defaults to None.
            - `n_embed` (int, optional): The number of embeddings. Defaults to None.
            - `legacy` (bool, optional): Whether to use legacy mode. Defaults to True.
            - `disable_self_attentions` (list, optional): The list of self-attentions to disable. Defaults to None.
            - `num_attention_blocks` (list, optional): The number of attention blocks. Defaults to None.
            - `disable_middle_self_attn` (bool, optional): Whether to disable middle self-attention. Defaults to False.
            - `use_linear_in_transformer` (bool, optional): Whether to use linear in transformer. Defaults to False.
            - `adm_in_channels` (int, optional): The number of ADM input channels. Defaults to None.
            - `transformer_depth_middle` (int, optional): The depth of the middle transformer. Defaults to None.
            - `transformer_depth_output` (list, optional): The depth of the output transformer. Defaults to None.
            - `use_temporal_resblock` (bool, optional): Whether to use temporal residual blocks. Defaults to False.
            - `use_temporal_attention` (bool, optional): Whether to use temporal attention. Defaults to False.
            - `time_context_dim` (int, optional): The time context dimension. Defaults to None.
            - `extra_ff_mix_layer` (bool, optional): Whether to use an extra feed-forward mix layer. Defaults to False.
            - `use_spatial_context` (bool, optional): Whether to use spatial context. Defaults to False.
            - `merge_strategy` (any, optional): The merge strategy. Defaults to None.
            - `merge_factor` (float, optional): The merge factor. Defaults to 0.0.
            - `video_kernel_size` (int, optional): The video kernel size. Defaults to None.
            - `disable_temporal_crossattention` (bool, optional): Whether to disable temporal cross-attention. Defaults to False.
            - `max_ddpm_temb_period` (int, optional): The maximum DDPM temporal embedding period. Defaults to 10000.
            - `device` (torch.device, optional): The device to use. Defaults to None.
            - `operations` (any, optional): The operations to use. Defaults to oai_ops.
        """
        super().__init__()

        if context_dim is not None:
            self.context_dim = context_dim

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
        if num_head_channels == -1:
            assert num_heads != -1, (
                "Either num_heads or num_head_channels has to be set"
            )

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks

        transformer_depth = transformer_depth[:]
        transformer_depth_output = transformer_depth_output[:]

        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.use_temporal_resblocks = use_temporal_resblock
        self.predict_codebook_ids = n_embed is not None

        self.default_num_video_frames = None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(
                model_channels, time_embed_dim, dtype=self.dtype, device=device
            ),
            nn.SiLU(),
            operations.Linear(
                time_embed_dim, time_embed_dim, dtype=self.dtype, device=device
            ),
        )

        self.input_blocks = nn.ModuleList(
            [
                sampling.TimestepEmbedSequential1(
                    operations.conv_nd(
                        dims,
                        in_channels,
                        model_channels,
                        3,
                        padding=1,
                        dtype=self.dtype,
                        device=device,
                    )
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1

        def get_attention_layer(
            ch: int,
            num_heads: int,
            dim_head: int,
            depth: int = 1,
            context_dim: int = None,
            use_checkpoint: bool = False,
            disable_self_attn: bool = False,
        ) -> transformer.SpatialTransformer:
            """#### Get an attention layer.

            #### Args:
                - `ch` (int): The number of channels.
                - `num_heads` (int): The number of heads.
                - `dim_head` (int): The dimension of each head.
                - `depth` (int, optional): The depth of the transformer. Defaults to 1.
                - `context_dim` (int, optional): The context dimension. Defaults to None.
                - `use_checkpoint` (bool, optional): Whether to use checkpointing. Defaults to False.
                - `disable_self_attn` (bool, optional): Whether to disable self-attention. Defaults to False.

            #### Returns:
                - `transformer.SpatialTransformer`: The attention layer.
            """
            return transformer.SpatialTransformer(
                ch,
                num_heads,
                dim_head,
                depth=depth,
                context_dim=context_dim,
                disable_self_attn=disable_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
                dtype=self.dtype,
                device=device,
                operations=operations,
            )

        def get_resblock(
            merge_factor: float,
            merge_strategy: any,
            video_kernel_size: int,
            ch: int,
            time_embed_dim: int,
            dropout: float,
            out_channels: int,
            dims: int,
            use_checkpoint: bool,
            use_scale_shift_norm: bool,
            down: bool = False,
            up: bool = False,
            dtype: th.dtype = None,
            device: th.device = None,
            operations: any = oai_ops,
        ) -> ResBlock.ResBlock1:
            """#### Get a residual block.

            #### Args:
                - `merge_factor` (float): The merge factor.
                - `merge_strategy` (any): The merge strategy.
                - `video_kernel_size` (int): The video kernel size.
                - `ch` (int): The number of channels.
                - `time_embed_dim` (int): The time embedding dimension.
                - `dropout` (float): The dropout rate.
                - `out_channels` (int): The number of output channels.
                - `dims` (int): The number of dimensions.
                - `use_checkpoint` (bool): Whether to use checkpointing.
                - `use_scale_shift_norm` (bool): Whether to use scale-shift normalization.
                - `down` (bool, optional): Whether to use downsampling. Defaults to False.
                - `up` (bool, optional): Whether to use upsampling. Defaults to False.
                - `dtype` (torch.dtype, optional): The data type. Defaults to None.
                - `device` (torch.device, optional): The device. Defaults to None.
                - `operations` (any, optional): The operations to use. Defaults to oai_ops.

            #### Returns:
                - `ResBlock.ResBlock1`: The residual block.
            """
            return ResBlock.ResBlock1(
                channels=ch,
                emb_channels=time_embed_dim,
                dropout=dropout,
                out_channels=out_channels,
                use_checkpoint=use_checkpoint,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                down=down,
                up=up,
                dtype=dtype,
                device=device,
                operations=operations,
            )
        self.double_blocks = nn.ModuleList()
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = mult * model_channels
                num_transformers = transformer_depth.pop(0)
                if num_transformers > 0:
                    dim_head = ch // num_heads
                    disabled_sa = False

                    if (
                        not util.exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            get_attention_layer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=num_transformers,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                self.input_blocks.append(sampling.TimestepEmbedSequential1(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    sampling.TimestepEmbedSequential1(
                        get_resblock(
                            merge_factor=merge_factor,
                            merge_strategy=merge_strategy,
                            video_kernel_size=video_kernel_size,
                            ch=ch,
                            time_embed_dim=time_embed_dim,
                            dropout=dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations,
                        )
                        if resblock_updown
                        else ResBlock.Downsample1(
                            ch,
                            conv_resample,
                            dims=dims,
                            out_channels=out_ch,
                            dtype=self.dtype,
                            device=device,
                            operations=operations,
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        dim_head = ch // num_heads
        mid_block = [
            get_resblock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                ch=ch,
                time_embed_dim=time_embed_dim,
                dropout=dropout,
                out_channels=None,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations,
            )
        ]

        self.middle_block = None
        if transformer_depth_middle >= -1:
            if transformer_depth_middle >= 0:
                mid_block += [
                    get_attention_layer(  # always uses a self-attn
                        ch,
                        num_heads,
                        dim_head,
                        depth=transformer_depth_middle,
                        context_dim=context_dim,
                        disable_self_attn=disable_middle_self_attn,
                        use_checkpoint=use_checkpoint,
                    ),
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_channels=None,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    ),
                ]
            self.middle_block = sampling.TimestepEmbedSequential1(*mid_block)
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch + ich,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = model_channels * mult
                num_transformers = transformer_depth_output.pop()
                if num_transformers > 0:
                    dim_head = ch // num_heads
                    disabled_sa = False

                    if (
                        not util.exists(num_attention_blocks)
                        or i < num_attention_blocks[level]
                    ):
                        layers.append(
                            get_attention_layer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=num_transformers,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        get_resblock(
                            merge_factor=merge_factor,
                            merge_strategy=merge_strategy,
                            video_kernel_size=video_kernel_size,
                            ch=ch,
                            time_embed_dim=time_embed_dim,
                            dropout=dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations,
                        )
                        if resblock_updown
                        else ResBlock.Upsample1(
                            ch,
                            conv_resample,
                            dims=dims,
                            out_channels=out_ch,
                            dtype=self.dtype,
                            device=device,
                            operations=operations,
                        )
                    )
                    ds //= 2
                self.output_blocks.append(sampling.TimestepEmbedSequential1(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
            nn.SiLU(),
            util.zero_module(
                operations.conv_nd(
                    dims,
                    model_channels,
                    out_channels,
                    3,
                    padding=1,
                    dtype=self.dtype,
                    device=device,
                )
            ),
        )
    def forward(
        self,
        x: torch.Tensor,
        timesteps: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        control: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """#### Forward pass of the UNet model with optimized calculations."""
        # Setup transformer options (avoid a shared mutable default dict)
        if transformer_options is None:
            transformer_options = {}
        transformer_options["original_shape"] = list(x.shape)
        transformer_options["transformer_index"] = 0

        # Extract kwargs efficiently
        num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
        image_only_indicator = kwargs.get("image_only_indicator", None)
        time_context = kwargs.get("time_context", None)

        # Validation
        assert (y is not None) == (self.num_classes is not None), (
            "must specify y if and only if the model is class-conditional"
        )

        # Time embedding - optimize by computing with target dtype directly
        t_emb = sampling_util.timestep_embedding(timesteps, self.model_channels).to(
            x.dtype
        )
        emb = self.time_embed(t_emb)

        # Input blocks processing
        hs = []
        h = x
        for id, module in enumerate(self.input_blocks):
            transformer_options["block"] = ("input", id)
            h = ResBlock.forward_timestep_embed1(
                module,
                h,
                emb,
                context,
                transformer_options,
                time_context=time_context,
                num_video_frames=num_video_frames,
                image_only_indicator=image_only_indicator,
            )
            h = apply_control1(h, control, "input")
            hs.append(h)

        # Middle block processing
        transformer_options["block"] = ("middle", 0)
        if self.middle_block is not None:
            h = ResBlock.forward_timestep_embed1(
                self.middle_block,
                h,
                emb,
                context,
                transformer_options,
                time_context=time_context,
                num_video_frames=num_video_frames,
                image_only_indicator=image_only_indicator,
            )
        h = apply_control1(h, control, "middle")

        # Output blocks processing - optimize memory usage
        for id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", id)
            hsp = hs.pop()
            hsp = apply_control1(hsp, control, "output")

            # Concatenate the skip connection with the current hidden state
            h = torch.cat([h, hsp], dim=1)
            del hsp  # Free memory immediately

            # Only calculate output shape when needed
            output_shape = hs[-1].shape if hs else None
            h = ResBlock.forward_timestep_embed1(
                module,
                h,
                emb,
                context,
                transformer_options,
                output_shape,
                time_context=time_context,
                num_video_frames=num_video_frames,
                image_only_indicator=image_only_indicator,
            )

        # Ensure output has correct dtype
        h = h.type(x.dtype)
        return self.out(h)
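
# Hedged usage sketch for UNetModel1.forward(). The hyperparameters below are
# SD1.5-like values chosen purely for illustration; they are not taken from this
# repository's configs, and the sketch assumes the companion modules (ResBlock,
# transformer, sampling) behave like standard Stable Diffusion building blocks.
def _example_unet_forward() -> torch.Tensor:
    model = UNetModel1(
        image_size=32,
        in_channels=4,
        model_channels=320,
        out_channels=4,
        num_res_blocks=[2, 2, 2, 2],
        channel_mult=(1, 2, 4, 4),
        num_heads=8,
        transformer_depth=[1, 1, 1, 1, 1, 1, 0, 0],
        transformer_depth_output=[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        transformer_depth_middle=1,
        context_dim=768,
        use_spatial_transformer=True,
    )
    x = torch.randn(1, 4, 64, 64)       # latent image
    timesteps = torch.tensor([999])     # diffusion timestep
    context = torch.randn(1, 77, 768)   # text-conditioning tokens
    with torch.no_grad():
        eps = model(x, timesteps=timesteps, context=context)
    return eps  # same spatial shape as x: (1, 4, 64, 64)
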
def detect_unet_config(
    state_dict: Dict[str, torch.Tensor], key_prefix: str
) -> Dict[str, Any]:
    """#### Detect the UNet configuration from a state dictionary.

    #### Args:
        - `state_dict` (Dict[str, torch.Tensor]): The state dictionary.
        - `key_prefix` (str): The key prefix.

    #### Returns:
        - `Dict[str, Any]`: The detected UNet configuration.
    """
    state_dict_keys = list(state_dict.keys())

    if (
        "{}joint_blocks.0.context_block.attn.qkv.weight".format(key_prefix)
        in state_dict_keys
    ):  # mmdit model
        unet_config = {}
        unet_config["in_channels"] = state_dict[
            "{}x_embedder.proj.weight".format(key_prefix)
        ].shape[1]
        patch_size = state_dict["{}x_embedder.proj.weight".format(key_prefix)].shape[2]
        unet_config["patch_size"] = patch_size
        final_layer = "{}final_layer.linear.weight".format(key_prefix)
        if final_layer in state_dict:
            unet_config["out_channels"] = state_dict[final_layer].shape[0] // (
                patch_size * patch_size
            )

        unet_config["depth"] = (
            state_dict["{}x_embedder.proj.weight".format(key_prefix)].shape[0] // 64
        )
        unet_config["input_size"] = None
        y_key = "{}y_embedder.mlp.0.weight".format(key_prefix)
        if y_key in state_dict_keys:
            unet_config["adm_in_channels"] = state_dict[y_key].shape[1]

        context_key = "{}context_embedder.weight".format(key_prefix)
        if context_key in state_dict_keys:
            in_features = state_dict[context_key].shape[1]
            out_features = state_dict[context_key].shape[0]
            unet_config["context_embedder_config"] = {
                "target": "torch.nn.Linear",
                "params": {"in_features": in_features, "out_features": out_features},
            }
        num_patches_key = "{}pos_embed".format(key_prefix)
        if num_patches_key in state_dict_keys:
            num_patches = state_dict[num_patches_key].shape[1]
            unet_config["num_patches"] = num_patches
            unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))

        rms_qk = "{}joint_blocks.0.context_block.attn.ln_q.weight".format(key_prefix)
        if rms_qk in state_dict_keys:
            unet_config["qk_norm"] = "rms"

        unet_config["pos_embed_scaling_factor"] = None  # unused for inference
        context_processor = "{}context_processor.layers.0.attn.qkv.weight".format(
            key_prefix
        )
        if context_processor in state_dict_keys:
            unet_config["context_processor_layers"] = transformer.count_blocks(
                state_dict_keys,
                "{}context_processor.layers.".format(key_prefix) + "{}.",
            )
        return unet_config

    if "{}clf.1.weight".format(key_prefix) in state_dict_keys:  # stable cascade
        unet_config = {}
        text_mapper_name = "{}clip_txt_mapper.weight".format(key_prefix)
        if text_mapper_name in state_dict_keys:
            unet_config["stable_cascade_stage"] = "c"
            w = state_dict[text_mapper_name]
            if w.shape[0] == 1536:  # stage c lite
                unet_config["c_cond"] = 1536
                unet_config["c_hidden"] = [1536, 1536]
                unet_config["nhead"] = [24, 24]
                unet_config["blocks"] = [[4, 12], [12, 4]]
            elif w.shape[0] == 2048:  # stage c full
                unet_config["c_cond"] = 2048
        elif "{}clip_mapper.weight".format(key_prefix) in state_dict_keys:
            unet_config["stable_cascade_stage"] = "b"
            w = state_dict["{}down_blocks.1.0.channelwise.0.weight".format(key_prefix)]
            if w.shape[-1] == 640:
                unet_config["c_hidden"] = [320, 640, 1280, 1280]
                unet_config["nhead"] = [-1, -1, 20, 20]
                unet_config["blocks"] = [[2, 6, 28, 6], [6, 28, 6, 2]]
                unet_config["block_repeat"] = [[1, 1, 1, 1], [3, 3, 2, 2]]
            elif w.shape[-1] == 576:  # stage b lite
                unet_config["c_hidden"] = [320, 576, 1152, 1152]
                unet_config["nhead"] = [-1, 9, 18, 18]
                unet_config["blocks"] = [[2, 4, 14, 4], [4, 14, 4, 2]]
                unet_config["block_repeat"] = [[1, 1, 1, 1], [2, 2, 2, 2]]
        return unet_config

    if (
        "{}transformer.rotary_pos_emb.inv_freq".format(key_prefix) in state_dict_keys
    ):  # stable audio dit
        unet_config = {}
        unet_config["audio_model"] = "dit1.0"
        return unet_config

    if (
        "{}double_layers.0.attn.w1q.weight".format(key_prefix) in state_dict_keys
    ):  # aura flow dit
        unet_config = {}
        unet_config["max_seq"] = state_dict[
            "{}positional_encoding".format(key_prefix)
        ].shape[1]
        unet_config["cond_seq_dim"] = state_dict[
            "{}cond_seq_linear.weight".format(key_prefix)
        ].shape[1]
        double_layers = transformer.count_blocks(
            state_dict_keys, "{}double_layers.".format(key_prefix) + "{}."
        )
        single_layers = transformer.count_blocks(
            state_dict_keys, "{}single_layers.".format(key_prefix) + "{}."
        )
        unet_config["n_double_layers"] = double_layers
        unet_config["n_layers"] = double_layers + single_layers
        return unet_config

    if "{}mlp_t5.0.weight".format(key_prefix) in state_dict_keys:  # Hunyuan DiT
        unet_config = {}
        unet_config["image_model"] = "hydit"
        unet_config["depth"] = transformer.count_blocks(
            state_dict_keys, "{}blocks.".format(key_prefix) + "{}."
        )
        unet_config["hidden_size"] = state_dict[
            "{}x_embedder.proj.weight".format(key_prefix)
        ].shape[0]
        if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40:  # DiT-g/2
            unet_config["mlp_ratio"] = 4.3637
        if state_dict["{}extra_embedder.0.weight".format(key_prefix)].shape[1] == 3968:
            unet_config["size_cond"] = True
            unet_config["use_style_cond"] = True
            unet_config["image_model"] = "hydit1"
        return unet_config

    if (
        "{}double_blocks.0.img_attn.norm.key_norm.scale".format(key_prefix)
        in state_dict_keys
    ):  # Flux
        dit_config = {}
        dit_config["image_model"] = "flux"
        dit_config["in_channels"] = 16
        dit_config["vec_in_dim"] = 768
        dit_config["context_in_dim"] = 4096
        dit_config["hidden_size"] = 3072
        dit_config["mlp_ratio"] = 4.0
        dit_config["num_heads"] = 24
        dit_config["depth"] = transformer.count_blocks(
            state_dict_keys, "{}double_blocks.".format(key_prefix) + "{}."
        )
        dit_config["depth_single_blocks"] = transformer.count_blocks(
            state_dict_keys, "{}single_blocks.".format(key_prefix) + "{}."
        )
        dit_config["axes_dim"] = [16, 56, 56]
        dit_config["theta"] = 10000
        dit_config["qkv_bias"] = True
        dit_config["guidance_embed"] = (
            "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        )
        return dit_config

    if "{}input_blocks.0.0.weight".format(key_prefix) not in state_dict_keys:
        return None

    unet_config = {
        "use_checkpoint": False,
        "image_size": 32,
        "use_spatial_transformer": True,
        "legacy": False,
    }

    y_input = "{}label_emb.0.0.weight".format(key_prefix)
    if y_input in state_dict_keys:
        unet_config["num_classes"] = "sequential"
        unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
    else:
        unet_config["adm_in_channels"] = None

    model_channels = state_dict["{}input_blocks.0.0.weight".format(key_prefix)].shape[0]
    in_channels = state_dict["{}input_blocks.0.0.weight".format(key_prefix)].shape[1]

    out_key = "{}out.2.weight".format(key_prefix)
    if out_key in state_dict:
        out_channels = state_dict[out_key].shape[0]
    else:
        out_channels = 4

    num_res_blocks = []
    channel_mult = []
    transformer_depth = []
    transformer_depth_output = []
    context_dim = None
    use_linear_in_transformer = False

    video_model = False
    video_model_cross = False

    current_res = 1
    count = 0

    last_res_blocks = 0
    last_channel_mult = 0

    input_block_count = transformer.count_blocks(
        state_dict_keys, "{}input_blocks".format(key_prefix) + ".{}."
    )
    for count in range(input_block_count):
        prefix = "{}input_blocks.{}.".format(key_prefix, count)
        prefix_output = "{}output_blocks.{}.".format(
            key_prefix, input_block_count - count - 1
        )

        block_keys = sorted(
            list(filter(lambda a: a.startswith(prefix), state_dict_keys))
        )
        if len(block_keys) == 0:
            break

        block_keys_output = sorted(
            list(filter(lambda a: a.startswith(prefix_output), state_dict_keys))
        )

        if "{}0.op.weight".format(prefix) in block_keys:  # new layer
            num_res_blocks.append(last_res_blocks)
            channel_mult.append(last_channel_mult)

            current_res *= 2
            last_res_blocks = 0
            last_channel_mult = 0
            out = transformer.calculate_transformer_depth(
                prefix_output, state_dict_keys, state_dict
            )
            if out is not None:
                transformer_depth_output.append(out[0])
            else:
                transformer_depth_output.append(0)
        else:
            res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
            if res_block_prefix in block_keys:
                last_res_blocks += 1
                last_channel_mult = (
                    state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0]
                    // model_channels
                )

                out = transformer.calculate_transformer_depth(
                    prefix, state_dict_keys, state_dict
                )
                if out is not None:
                    transformer_depth.append(out[0])
                    if context_dim is None:
                        context_dim = out[1]
                        use_linear_in_transformer = out[2]
                else:
                    transformer_depth.append(0)

            res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
            if res_block_prefix in block_keys_output:
                out = transformer.calculate_transformer_depth(
                    prefix_output, state_dict_keys, state_dict
                )
                if out is not None:
                    transformer_depth_output.append(out[0])
                else:
                    transformer_depth_output.append(0)

    num_res_blocks.append(last_res_blocks)
    channel_mult.append(last_channel_mult)
    if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
        transformer_depth_middle = transformer.count_blocks(
            state_dict_keys,
            "{}middle_block.1.transformer_blocks.".format(key_prefix) + "{}",
        )
    elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
        transformer_depth_middle = -1
    else:
        transformer_depth_middle = -2

    unet_config["in_channels"] = in_channels
    unet_config["out_channels"] = out_channels
    unet_config["model_channels"] = model_channels
    unet_config["num_res_blocks"] = num_res_blocks
    unet_config["transformer_depth"] = transformer_depth
    unet_config["transformer_depth_output"] = transformer_depth_output
    unet_config["channel_mult"] = channel_mult
    unet_config["transformer_depth_middle"] = transformer_depth_middle
    unet_config["use_linear_in_transformer"] = use_linear_in_transformer
    unet_config["context_dim"] = context_dim

    if video_model:
        unet_config["extra_ff_mix_layer"] = True
        unet_config["use_spatial_context"] = True
        unet_config["merge_strategy"] = "learned_with_images"
        unet_config["merge_factor"] = 0.0
        unet_config["video_kernel_size"] = [3, 1, 1]
        unet_config["use_temporal_resblock"] = True
        unet_config["use_temporal_attention"] = True
        unet_config["disable_temporal_crossattention"] = not video_model_cross
    else:
        unet_config["use_temporal_resblock"] = False
        unet_config["use_temporal_attention"] = False

    return unet_config
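
# Hedged usage sketch for detect_unet_config(). The checkpoint path is hypothetical;
# "model.diffusion_model." is the usual key prefix for UNet weights inside a Stable
# Diffusion checkpoint and is assumed (not verified) to match how this repository
# stores its state dicts.
def _example_detect_unet_config(checkpoint_path: str) -> Optional[Dict[str, Any]]:
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    if "state_dict" in state_dict:  # some checkpoints nest the weights
        state_dict = state_dict["state_dict"]
    config = detect_unet_config(state_dict, "model.diffusion_model.")
    if config is not None:
        # For an SD1.5-style model this typically reports model_channels=320,
        # channel_mult=[1, 2, 4, 4] and context_dim=768.
        logging.info("detected UNet config: %s", config)
    return config
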
def model_config_from_unet_config(
    unet_config: Dict[str, Any], state_dict: Optional[Dict[str, torch.Tensor]] = None
) -> Any:
    """#### Get the model configuration from a UNet configuration.

    #### Args:
        - `unet_config` (Dict[str, Any]): The UNet configuration.
        - `state_dict` (Optional[Dict[str, torch.Tensor]], optional): The state dictionary. Defaults to None.

    #### Returns:
        - `Any`: The model configuration, or None if no known configuration matches.
    """
    from modules.SD15 import SD15

    for model_config in SD15.models:
        if model_config.matches(unet_config, state_dict):
            return model_config(unet_config)

    logging.error("no match {}".format(unet_config))
    return None


def model_config_from_unet(
    state_dict: Dict[str, torch.Tensor],
    unet_key_prefix: str,
    use_base_if_no_match: bool = False,
) -> Any:
    """#### Get the model configuration from a UNet state dictionary.

    #### Args:
        - `state_dict` (Dict[str, torch.Tensor]): The state dictionary.
        - `unet_key_prefix` (str): The UNet key prefix.
        - `use_base_if_no_match` (bool, optional): Whether to use the base configuration if no match is found.
          Accepted for API compatibility but currently unused. Defaults to False.

    #### Returns:
        - `Any`: The model configuration, or None if detection fails.
    """
    unet_config = detect_unet_config(state_dict, unet_key_prefix)
    if unet_config is None:
        return None
    model_config = model_config_from_unet_config(unet_config, state_dict)
    return model_config
def unet_dtype1(
    device: Optional[torch.device] = None,
    model_params: int = 0,
    supported_dtypes: List[torch.dtype] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
) -> torch.dtype:
    """#### Get the dtype for the UNet model.

    Currently this always returns `torch.float16`; the arguments are kept for
    API compatibility and are not inspected.

    #### Args:
        - `device` (Optional[torch.device], optional): The device. Defaults to None.
        - `model_params` (int, optional): The model parameters. Defaults to 0.
        - `supported_dtypes` (List[torch.dtype], optional): The supported dtypes. Defaults to [torch.float16, torch.bfloat16, torch.float32].

    #### Returns:
        - `torch.dtype`: The dtype for the UNet model.
    """
    return torch.float16
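
# Hedged alternative to unet_dtype1(): a defensive dtype pick that falls back to
# float32 on CPU and prefers bfloat16 when the GPU supports it. This illustrates a
# more adaptive policy; it is not the selection logic this repository actually uses.
def _select_unet_dtype_sketch(
    device: Optional[torch.device] = None,
    supported_dtypes: List[torch.dtype] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
) -> torch.dtype:
    if device is None or device.type != "cuda":
        return torch.float32  # half-precision math on CPU is usually slower than fp32
    if torch.cuda.is_bf16_supported() and torch.bfloat16 in supported_dtypes:
        return torch.bfloat16
    if torch.float16 in supported_dtypes:
        return torch.float16
    return torch.float32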