import torch
import torch.nn as nn
import numpy as np
import math
from einops import rearrange
from pdb import set_trace as st

# from .dit_models import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer
from .dit_models_xformers import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer


def modulate2(x, shift, scale):
    return x * (1 + scale) + shift
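

# Illustrative sketch, not used by the model: unlike the reference DiT `modulate`,
# which typically unsqueezes per-sample (N, D) shift/scale before broadcasting over
# tokens, `modulate2` applies shift and scale as-is, so per-token (N, T, D)
# conditioning also works. `_modulate2_shape_sketch` is a hypothetical helper kept
# only as documentation; it is never called in this module.
def _modulate2_shape_sketch():
    x = torch.randn(2, 16, 8)      # (N, T, D) token features
    shift = torch.randn(2, 16, 8)  # per-token shift, same shape as x
    scale = torch.randn(2, 16, 8)  # per-token scale
    assert modulate2(x, shift, scale).shape == x.shape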


class DiTBlock2(DiTBlock):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4, **block_kwargs):
        super().__init__(hidden_size, num_heads, mlp_ratio, **block_kwargs)

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
            c).chunk(6, dim=-1)
        # st()
        x = x + gate_msa * self.attn(
            modulate2(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp * self.mlp(
            modulate2(self.norm2(x), shift_mlp, scale_mlp))
        return x
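

# Illustrative usage sketch (assumption: the parent DiTBlock's adaLN_modulation maps
# hidden_size to 6 * hidden_size, as in the reference DiT implementation, so chunking
# along the last dim yields six (N, T, D) tensors). Shown only to document the shapes
# DiTBlock2.forward expects when c is a per-token conditioning sequence; it is never
# called in this module.
def _ditblock2_usage_sketch():
    blk = DiTBlock2(hidden_size=384, num_heads=6)
    x = torch.randn(2, 16, 384)  # (N, T, D) input tokens
    c = torch.randn(2, 16, 384)  # (N, T, D) conditioning tokens
    assert blk(x, c).shape == x.shape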


class FinalLayer2(FinalLayer):
    """
    The final layer of DiT, basically the decoder_pred in MAE with adaLN.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__(hidden_size, patch_size, out_channels)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate2(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT2(DiT):
    # a conditional ViT

    def __init__(self,
                 input_size=32,
                 patch_size=2,
                 in_channels=4,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 mlp_ratio=4,
                 class_dropout_prob=0.1,
                 num_classes=1000,
                 learn_sigma=True,
                 mixing_logit_init=-3,
                 mixed_prediction=True,
                 context_dim=False,
                 roll_out=False,
                 plane_n=3,
                 return_all_layers=False,
                 vit_blk=...):  # unused; the block type is fixed to DiTBlock2 in the super() call
        super().__init__(input_size,
                         patch_size,
                         in_channels,
                         hidden_size,
                         depth,
                         num_heads,
                         mlp_ratio,
                         class_dropout_prob,
                         num_classes,
                         learn_sigma,
                         mixing_logit_init,
                         mixed_prediction,
                         context_dim,
                         roll_out,
                         vit_blk=DiTBlock2,
                         final_layer_blk=FinalLayer2)

        # no t and x embedder
        del self.x_embedder
        del self.t_embedder
        del self.final_layer
        torch.cuda.empty_cache()

        self.clip_text_proj = None
        self.plane_n = plane_n
        self.return_all_layers = return_all_layers

    def forward(self, c, *args, **kwargs):
        # return super().forward(x, timesteps, context, y, get_attr, **kwargs)
        """
        Forward pass of DiT2.
        c: (N, L, D) tensor of conditioning tokens, matching the token layout of
        self.pos_embed; the positional embedding is broadcast over the batch and
        used as the input tokens x, while c modulates every block through adaLN.
        """
        x = self.pos_embed.repeat(
            c.shape[0], 1, 1)  # (N, T, D), where T = H * W / patch_size ** 2

        if self.return_all_layers:
            all_layers = []

        # if context is not None:
        #     c = context  # B 3HW C

        for blk_idx, block in enumerate(self.blocks):
            if self.roll_out:
                if blk_idx % 2 == 0:  # within-plane self-attention
                    x = rearrange(x, 'b (n l) c -> (b n) l c', n=self.plane_n)
                    x = block(x,
                              rearrange(c,
                                        'b (n l) c -> (b n) l c',
                                        n=self.plane_n))  # (N, T, D)
                    # st()
                    if self.return_all_layers:
                        all_layers.append(x)

                else:  # global (cross-plane) attention
                    x = rearrange(x, '(b n) l c -> b (n l) c', n=self.plane_n)
                    x = block(x, c)  # (N, T, D)
                    # st()
                    if self.return_all_layers:
                        # all planes merged into the B dim
                        all_layers.append(
                            rearrange(x,
                                      'b (n l) c -> (b n) l c',
                                      n=self.plane_n))
            else:
                x = block(x, c)  # (N, T, D)

        # x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)

        # if self.roll_out:  # move n from L to B axis
        #     x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)

        # x = self.unpatchify(x)  # (N, out_channels, H, W)

        # if self.roll_out:  # move n from L to B axis
        #     x = rearrange(x, '(b n) c h w -> b (n c) h w', n=3)

        if self.return_all_layers:
            return all_layers
        else:
            return x
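

# Illustrative usage sketch (assumptions: the parent DiT registers `pos_embed` with
# shape (1, plane_n * L, D), and roll_out alternates per-plane and cross-plane
# attention as in DiT2.forward above). DiT2 consumes only the conditioning tokens c;
# the positional embedding plays the role of the input tokens. Hypothetical helper,
# not called anywhere in this module.
def _dit2_forward_sketch(model: DiT2, batch_size: int = 2):
    # c must match the token layout of model.pos_embed: (N, plane_n * L, D)
    c = torch.randn(batch_size, model.pos_embed.shape[1], model.pos_embed.shape[2])
    out = model(c)
    # returns the final token features, or a list of per-plane features shaped
    # (N * plane_n, L, D) when return_all_layers=True
    return out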


# class DiT2_DPT(DiT2):
#     def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1, num_classes=1000, learn_sigma=True, mixing_logit_init=-3, mixed_prediction=True, context_dim=False, roll_out=False, plane_n=3, vit_blk=...):
#         super().__init__(input_size, patch_size, in_channels, hidden_size, depth, num_heads, mlp_ratio, class_dropout_prob, num_classes, learn_sigma, mixing_logit_init, mixed_prediction, context_dim, roll_out, plane_n, vit_blk)
#         self.return_all_layers = True


#################################################################################
#                                 DiT2 Configs                                  #
#################################################################################


def DiT2_XL_2(**kwargs):
    return DiT2(depth=28,
                hidden_size=1152,
                patch_size=2,
                num_heads=16,
                **kwargs)


def DiT2_XL_2_half(**kwargs):
    return DiT2(depth=28 // 2,
                hidden_size=1152,
                patch_size=2,
                num_heads=16,
                **kwargs)


def DiT2_XL_4(**kwargs):
    return DiT2(depth=28,
                hidden_size=1152,
                patch_size=4,
                num_heads=16,
                **kwargs)


def DiT2_XL_8(**kwargs):
    return DiT2(depth=28,
                hidden_size=1152,
                patch_size=8,
                num_heads=16,
                **kwargs)


def DiT2_L_2(**kwargs):
    return DiT2(depth=24,
                hidden_size=1024,
                patch_size=2,
                num_heads=16,
                **kwargs)


def DiT2_L_2_half(**kwargs):
    return DiT2(depth=24 // 2,
                hidden_size=1024,
                patch_size=2,
                num_heads=16,
                **kwargs)


def DiT2_L_4(**kwargs):
    return DiT2(depth=24,
                hidden_size=1024,
                patch_size=4,
                num_heads=16,
                **kwargs)


def DiT2_L_8(**kwargs):
    return DiT2(depth=24,
                hidden_size=1024,
                patch_size=8,
                num_heads=16,
                **kwargs)


def DiT2_B_2(**kwargs):
    return DiT2(depth=12,
                hidden_size=768,
                patch_size=2,
                num_heads=12,
                **kwargs)


def DiT2_B_4(**kwargs):
    return DiT2(depth=12,
                hidden_size=768,
                patch_size=4,
                num_heads=12,
                **kwargs)


def DiT2_B_8(**kwargs):
    return DiT2(depth=12,
                hidden_size=768,
                patch_size=8,
                num_heads=12,
                **kwargs)


def DiT2_B_16(**kwargs):  # ours cfg
    return DiT2(depth=12,
                hidden_size=768,
                patch_size=16,
                num_heads=12,
                **kwargs)


def DiT2_S_2(**kwargs):
    return DiT2(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)


def DiT2_S_4(**kwargs):
    return DiT2(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)


def DiT2_S_8(**kwargs):
    return DiT2(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


DiT2_models = {
    'DiT2-XL/2': DiT2_XL_2,
    'DiT2-XL/2/half': DiT2_XL_2_half,
    'DiT2-XL/4': DiT2_XL_4,
    'DiT2-XL/8': DiT2_XL_8,
    'DiT2-L/2': DiT2_L_2,
    'DiT2-L/2/half': DiT2_L_2_half,
    'DiT2-L/4': DiT2_L_4,
    'DiT2-L/8': DiT2_L_8,
    'DiT2-B/2': DiT2_B_2,
    'DiT2-B/4': DiT2_B_4,
    'DiT2-B/8': DiT2_B_8,
    'DiT2-B/16': DiT2_B_16,
    'DiT2-S/2': DiT2_S_2,
    'DiT2-S/4': DiT2_S_4,
    'DiT2-S/8': DiT2_S_8,
}
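

# Minimal usage sketch (assumptions: the parent DiT accepts the keyword arguments
# forwarded below and registers `pos_embed` as a (1, T, D) tensor). Because this
# module uses relative imports, run it as a module, e.g. `python -m <package>.<this_module>`.
if __name__ == '__main__':
    model = DiT2_models['DiT2-B/16'](input_size=32, in_channels=4, roll_out=True)
    cond = torch.randn(2, model.pos_embed.shape[1], model.pos_embed.shape[2])
    feats = model(cond)
    print(type(feats), getattr(feats, 'shape', len(feats)))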