# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from the timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from copy import deepcopy
import math
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor, pixel_shuffle
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.nn.modules import GELU
from timm.models.layers import DropPath, Mlp  # used by the attention blocks below

# from vit.vision_transformer import Conv3DCrossAttentionBlock
from .utils import trunc_normal_

from pdb import set_trace as st

# import apex
try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    # from dit.norm import RMSNorm
    from dit.norm import RMSNorm
# from apex.normalization import FusedLayerNorm as LayerNorm

try:
    from xformers.ops import memory_efficient_attention, unbind, fmha
    from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
    # from xformers.ops import RMSNorm
    XFORMERS_AVAILABLE = True
except ImportError:
    # logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 enable_rmsnorm=False,
                 qk_norm=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # https://github.com/huggingface/pytorch-image-models/blob/5dce71010174ad6599653da4e8ba37fd5f9fa572/timm/models/vision_transformer.py#L79C1-L80C78
        self.q_norm = RMSNorm(head_dim, elementwise_affine=True,
                              eps=1e-5) if qk_norm else nn.Identity()  # sd-3
        self.k_norm = RMSNorm(head_dim, elementwise_affine=True,
                              eps=1e-5) if qk_norm else nn.Identity()
        # if qk_norm:
        #     self.q_norm = LayerNorm(dim, eps=1e-5)
        #     self.k_norm = LayerNorm(dim, eps=1e-5)
        self.qk_norm = qk_norm
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # apply q/k normalization over head_dim when qk_norm is enabled (identity otherwise)
        q, k = self.q_norm(q), self.k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # return x, attn
        return x
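

# A minimal shape sketch for Attention (illustrative, not part of the original
# module; assumes the default qk_norm=False so RMSNorm is never instantiated):
#
#   attn = Attention(dim=64, num_heads=8)
#   x = torch.randn(2, 16, 64)   # (B, N, C)
#   y = attn(x)                  # -> (2, 16, 64); qkv is one fused Linear split into 3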


class MemEffAttention(Attention):

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)
        q, k = self.q_norm(q), self.k_norm(k)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)  # if not bf16, no flash-attn here.
        # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)  # force flash attention
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
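

# Note on layout (a sketch, assuming xformers is installed):
# memory_efficient_attention expects (B, N, num_heads, head_dim) rather than the
# (B, num_heads, N, head_dim) layout of the vanilla forward, which is why no
# permute(2, 0, 3, 1, 4) is needed here:
#
#   attn = MemEffAttention(dim=64, num_heads=8)
#   y = attn(torch.randn(2, 16, 64))   # -> (2, 16, 64)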


class MemEffCrossAttention(MemEffAttention):
    # for cross attention, where context serves as k and v

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop,
                         proj_drop)
        del self.qkv  # replaced by separate q / kv projections
        self.q = nn.Linear(dim, dim * 1, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

    def forward(self, x: Tensor, context: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            # self.qkv has been deleted, so the parent fallback path is unavailable
            raise RuntimeError("MemEffCrossAttention requires xFormers")

        B, N, C = x.shape
        M = context.shape[1]  # context length may differ from N
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads)
        kv = self.kv(context).reshape(B, M, 2, self.num_heads,
                                      C // self.num_heads)
        k, v = unbind(kv, 2)

        # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias,
                                       op=MemoryEfficientAttentionFlashAttentionOp)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
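

# Usage sketch (illustrative, assuming xformers is available): queries come from
# x, keys/values from context, so the two sequences may have different lengths:
#
#   xattn = MemEffCrossAttention(dim=64, num_heads=8)
#   y = xattn(torch.randn(2, 16, 64), context=torch.randn(2, 32, 64))  # -> (2, 16, 64)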


# https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
class CrossAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads,
                                            C // self.num_heads).permute(
                                                0, 2, 1, 3)  # B1C -> B1H(C/H) -> BH1(C/H)
        k = self.wk(x).reshape(B, N, self.num_heads,
                               C // self.num_heads).permute(
                                   0, 2, 1, 3)  # BNC -> BNH(C/H) -> BHN(C/H)
        v = self.wv(x).reshape(B, N, self.num_heads,
                               C // self.num_heads).permute(
                                   0, 2, 1, 3)  # BNC -> BNH(C/H) -> BHN(C/H)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # BH1(C/H) @ BH(C/H)N -> BH1N
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(
            B, 1, C)  # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
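

# Here only the [cls] token (x[:, 0:1]) forms the query, so the output is a single
# token per batch element. A shape sketch (illustrative):
#
#   xattn = CrossAttention(dim=64, num_heads=8)
#   y = xattn(torch.randn(2, 17, 64))   # -> (2, 1, 64): one fused [cls] token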


class Conv3D_Aware_CrossAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, group_size, N, C = x.shape  # B 3 N C
        p = int(N**0.5)  # patch size
        assert p**2 == N, 'check input dim, no [cls] needed here'
        assert group_size == 3, 'designed for triplane here'

        x = x.reshape(B, group_size, p, p, C)  # expand patch token dim

        # * init qkv buffers
        # q = torch.empty(B * group_size * N, 1, self.num_heads, C // self.num_heads, device=x.device).permute(0, 2, 1, 3)
        # k = torch.empty(B * group_size * N, 2 * p, self.num_heads, C // self.num_heads, device=x.device).permute(0, 2, 1, 3)
        # v = torch.empty_like(k)
        q_x = torch.empty(B * group_size * N, 1, C, device=x.device)
        k_x = torch.empty(B * group_size * N, 2 * p, C, device=x.device)
        v_x = torch.empty_like(k_x)

        # ! refer to the following plane order
        # N, M, _ = coordinates.shape
        # xy_coords = coordinates[..., [0, 1]]
        # yz_coords = coordinates[..., [1, 2]]
        # zx_coords = coordinates[..., [2, 0]]
        # return torch.stack([xy_coords, yz_coords, zx_coords], dim=1).reshape(N * 3, M, 2)

        index_i, index_j = torch.meshgrid(torch.arange(0, p),
                                          torch.arange(0, p),
                                          indexing='ij')  # 16*16
        index_mesh_grid = torch.stack([index_i, index_j], 0).to(
            x.device).unsqueeze(0).repeat_interleave(B, 0).reshape(B, 2, p, p)  # B 2 p p

        for i in range(group_size):
            q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute(
                0, 2, 3, 1, 4).reshape(B * N, 1, C)  # B 1 p p C -> B*N, 1, C

            # TODO, how to batchify gather ops?
            plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1]  # B 1 p p C
            plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]

            assert plane_yz.shape == plane_zx.shape == (B, 1, p, p, C), 'check sub plane dimensions'

            pooling_plane_yz = torch.gather(
                plane_yz,
                dim=2,
                index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand(
                    -1, -1, -1, p, C)).permute(0, 2, 1, 3, 4)  # B 1 256 16 C => B 256 1 16 C
            pooling_plane_zx = torch.gather(
                plane_zx,
                dim=3,
                index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand(
                    -1, -1, p, -1, C)).permute(0, 3, 1, 2, 4)  # B 1 16 256 C => B 256 1 16 C

            k_x[B * i * N:B * (i + 1) * N] = v_x[B * i * N:B * (i + 1) * N] = torch.cat(
                [pooling_plane_yz, pooling_plane_zx],
                dim=2).reshape(B * N, 2 * p, C)  # B 256 2 16 C => (B*256) 2*16 C

            # q[B * i * N: B * (i+1) * N] = self.wq(q_x).reshape(B*N, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            # k[B * i * N: B * (i+1) * N] = self.wk(k_x).reshape(B*N, 2*p, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            # v[B * i * N: B * (i+1) * N] = self.wv(v_x).reshape(B*N, 2*p, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads,
                                 C // self.num_heads).permute(
                                     0, 2, 1, 3)  # merge num_heads into the batch dimension
        k = self.wk(k_x).reshape(B * group_size * N, 2 * p, self.num_heads,
                                 C // self.num_heads).permute(0, 2, 1, 3)
        v = self.wv(v_x).reshape(B * group_size * N, 2 * p, self.num_heads,
                                 C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # BH1(C/H) @ BH(C/H)N -> BH1N, N=2p here
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(
            B * 3 * N, 1, C)  # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x back
        x = x.reshape(B, 3, N, C)

        return x
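

# The gather above implements the cross-plane pooling: for a query token at grid
# position (i, j) on one plane, the keys/values are the full row i of the "yz"
# plane plus the full column j of the "zx" plane (2*p tokens total). A shape
# sketch (illustrative; N must be a perfect square):
#
#   xattn3d = Conv3D_Aware_CrossAttention(dim=64, num_heads=8)
#   x = torch.randn(2, 3, 16 * 16, 64)   # B, 3 planes, N = p*p tokens, C
#   y = xattn3d(x)                       # -> (2, 3, 256, 64)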


class xformer_Conv3D_Aware_CrossAttention(nn.Module):
    # https://github.dev/facebookresearch/dinov2

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        # https://pytorch.org/blog/accelerated-generative-diffusion-models/
        self.num_heads = num_heads

        self.wq = nn.Linear(dim, dim * 1, bias=qkv_bias)
        self.w_kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.index_mesh_grid = None

    def forward(self, x, attn_bias=None):
        B, group_size, N, C = x.shape  # B 3 N C
        p = int(N**0.5)  # patch size
        assert p**2 == N, 'check input dim, no [cls] needed here'
        assert group_size == 3, 'designed for triplane here'

        x = x.reshape(B, group_size, p, p, C)  # expand patch token dim

        q_x = torch.empty(B * group_size * N, 1, C, device=x.device)
        context = torch.empty(B * group_size * N, 2 * p, C,
                              device=x.device)  # k_x = v_x

        if self.index_mesh_grid is None:  # cache the grid to further accelerate
            index_i, index_j = torch.meshgrid(torch.arange(0, p),
                                              torch.arange(0, p),
                                              indexing='ij')  # 16*16
            index_mesh_grid = torch.stack([index_i, index_j], 0).to(
                x.device).unsqueeze(0).repeat_interleave(B, 0).reshape(
                    B, 2, p, p)  # B 2 p p
            self.index_mesh_grid = index_mesh_grid[0:1]
        else:
            index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave(B, 0)
        assert index_mesh_grid.shape == (B, 2, p, p), 'check index_mesh_grid dimension'

        for i in range(group_size):
            q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute(
                0, 2, 3, 1, 4).reshape(B * N, 1, C)  # B 1 p p C -> B*N, 1, C

            # TODO, how to batchify gather ops?
            plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1]  # B 1 p p C
            plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]

            assert plane_yz.shape == plane_zx.shape == (B, 1, p, p, C), 'check sub plane dimensions'

            pooling_plane_yz = torch.gather(
                plane_yz,
                dim=2,
                index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand(
                    -1, -1, -1, p, C)).permute(0, 2, 1, 3, 4)  # B 1 256 16 C => B 256 1 16 C
            pooling_plane_zx = torch.gather(
                plane_zx,
                dim=3,
                index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand(
                    -1, -1, p, -1, C)).permute(0, 3, 1, 2, 4)  # B 1 16 256 C => B 256 1 16 C

            context[B * i * N:B * (i + 1) * N] = torch.cat(
                [pooling_plane_yz, pooling_plane_zx],
                dim=2).reshape(B * N, 2 * p, C)  # B 256 2 16 C => (B*256) 2*16 C

        q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads,
                                 C // self.num_heads)
        kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2,
                                        self.num_heads, C // self.num_heads)
        k, v = unbind(kv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)
        x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class xformer_Conv3D_Aware_CrossAttention_xygrid(
        xformer_Conv3D_Aware_CrossAttention):
    """implementation-wise clearer, but yields results identical to xformer_Conv3D_Aware_CrossAttention
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.0,
                 proj_drop=0.0):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop,
                         proj_drop)

    def forward(self, x, attn_bias=None):
        B, group_size, N, C = x.shape  # B 3 N C
        p = int(N**0.5)  # patch size
        assert p**2 == N, 'check input dim, no [cls] needed here'
        assert group_size == 3, 'designed for triplane here'

        x = x.reshape(B, group_size, p, p, C)  # expand patch token dim

        q_x = torch.empty(B * group_size * N, 1, C, device=x.device)
        context = torch.empty(B * group_size * N, 2 * p, C,
                              device=x.device)  # k_x = v_x

        if self.index_mesh_grid is None:  # cache the grid to further accelerate
            index_u, index_v = torch.meshgrid(
                torch.arange(0, p), torch.arange(0, p),
                indexing='xy')  # ! switch to 'xy' here to match uv coordinate
            index_mesh_grid = torch.stack([index_u, index_v], 0).to(
                x.device).unsqueeze(0).repeat_interleave(B, 0).reshape(
                    B, 2, p, p)  # B 2 p p
            self.index_mesh_grid = index_mesh_grid[0:1]
        else:
            index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave(B, 0)
        assert index_mesh_grid.shape == (B, 2, p, p), 'check index_mesh_grid dimension'

        for i in range(group_size):
            q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute(
                0, 2, 3, 1, 4).reshape(B * N, 1, C)  # B 1 p p C -> B*N, 1, C

            # TODO, how to batchify gather ops?
            plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + 1]  # B 1 p p C
            plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1]

            assert plane_yz.shape == plane_zx.shape == (B, 1, p, p, C), 'check sub plane dimensions'

            pooling_plane_yz = torch.gather(
                plane_yz,
                dim=2,
                index=index_mesh_grid[:, 1:2].reshape(B, 1, N, 1, 1).expand(
                    -1, -1, -1, p, C)).permute(0, 2, 1, 3, 4)  # B 1 256 16 C => B 256 1 16 C
            pooling_plane_zx = torch.gather(
                plane_zx,
                dim=3,
                index=index_mesh_grid[:, 0:1].reshape(B, 1, 1, N, 1).expand(
                    -1, -1, p, -1, C)).permute(0, 3, 1, 2, 4)  # B 1 16 256 C => B 256 1 16 C

            context[B * i * N:B * (i + 1) * N] = torch.cat(
                [pooling_plane_yz, pooling_plane_zx],
                dim=2).reshape(B * N, 2 * p, C)  # B 256 2 16 C => (B*256) 2*16 C

        q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads,
                                 C // self.num_heads)
        kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2,
                                        self.num_heads, C // self.num_heads)
        k, v = unbind(kv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        # x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=MemoryEfficientAttentionFlashAttentionOp)
        x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(
        xformer_Conv3D_Aware_CrossAttention_xygrid):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop,
                         proj_drop)

    def forward(self, x, attn_bias=None):
        # ! split x: B N C into B 3 N C//3
        B, N, C = x.shape
        x = x.reshape(B, N, C // 3, 3).permute(0, 3, 1, 2)  # B N C//3 3 -> B 3 N C//3
        x_out = super().forward(x, attn_bias)  # B 3 N C//3
        x_out = x_out.permute(0, 2, 3, 1)  # B 3 N C//3 -> B N C//3 3
        x_out = x_out.reshape(*x_out.shape[:2], -1)  # B N C//3 3 -> B N C
        return x_out.contiguous()
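

# This "withinC" variant treats one (B, N, C) token sequence as three interleaved
# planes of C//3 channels each, runs the 3D-aware cross attention, then folds the
# planes back. A shape sketch (illustrative; assumes xformers is available, C is
# divisible by 3, and dim equals the per-plane channel count C//3):
#
#   xattn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(dim=64, num_heads=8)
#   y = xattn(torch.randn(2, 16 * 16, 192))   # -> (2, 256, 192)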


class self_cross_attn(nn.Module):

    def __init__(self, dino_attn, cross_attn, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dino_attn = dino_attn
        self.cross_attn = cross_attn

    def forward(self, x_norm):
        y = self.dino_attn(x_norm) + x_norm
        return self.cross_attn(y)  # the caller adds x back, as in the original code


# class RodinRollOutConv(nn.Module):
#     """implementation-wise clearer, but yields results identical to xformer_Conv3D_Aware_CrossAttention
#     Use Group Conv
#     """
#     def __init__(self, in_chans, out_chans=None):
#         super().__init__()
#         # input: B 3C H W
#         if out_chans is None:
#             out_chans = in_chans
#         self.roll_out_convs = nn.Conv2d(in_chans,
#                                         out_chans,
#                                         kernel_size=3,
#                                         groups=3,
#                                         padding=1)
#     def forward(self, x):
#         return self.roll_out_convs(x)


class RodinRollOutConv3D(nn.Module):
    """implementation-wise clearer, but yields results identical to xformer_Conv3D_Aware_CrossAttention
    """

    def __init__(self, in_chans, out_chans=None):
        super().__init__()
        if out_chans is None:
            out_chans = in_chans
        self.out_chans = out_chans // 3
        self.roll_out_convs = nn.Conv2d(in_chans,
                                        self.out_chans,
                                        kernel_size=3,
                                        padding=1)

    def forward(self, x):
        # todo, reshape before input?
        B, C3, p, p = x.shape  # B 3C H W
        C = C3 // 3
        group_size = C3 // C
        assert group_size == 3

        x = x.reshape(B, 3, C, p, p)

        roll_out_x = torch.empty(B, group_size * C, p, 3 * p,
                                 device=x.device)  # B, 3C, H, 3W

        for i in range(group_size):
            plane_xy = x[:, i]  # B C H W

            # TODO, simply do the average pooling?
            plane_yz_pooling = x[:, (i + 1) % group_size].mean(
                dim=-1, keepdim=True).repeat_interleave(
                    p, dim=-1)  # B C H W -> B C H 1 -> B C H W, reduce z dim
            plane_zx_pooling = x[:, (i + 2) % group_size].mean(
                dim=-2, keepdim=True).repeat_interleave(
                    p, dim=-2)  # B C H W -> B C 1 W -> B C H W, reduce z dim

            roll_out_x[..., i * p:(i + 1) * p] = torch.cat(
                [plane_xy, plane_yz_pooling, plane_zx_pooling],
                1)  # fill in the 3W dim

        x = self.roll_out_convs(roll_out_x)  # B C H 3W

        x = x.reshape(B, self.out_chans, p, 3, p)
        x = x.permute(0, 3, 1, 2, 4).reshape(B, 3 * self.out_chans, p, p)  # B 3C H W

        return x


class RodinRollOutConv3D_GroupConv(nn.Module):
    """implementation-wise clearer, but yields results identical to xformer_Conv3D_Aware_CrossAttention
    """

    def __init__(self,
                 in_chans,
                 out_chans=None,
                 kernel_size=3,
                 stride=1,
                 padding=1):
        super().__init__()
        if out_chans is None:
            out_chans = in_chans

        self.roll_out_convs = nn.Conv2d(
            in_chans * 3,
            out_chans,
            kernel_size=kernel_size,
            groups=3,  # B 9C H W
            stride=stride,
            padding=padding)

    # @torch.autocast(device_type='cuda')
    def forward(self, x):
        # todo, reshape before input?
        B, C3, p, p = x.shape  # B 3C H W
        C = C3 // 3
        group_size = C3 // C
        assert group_size == 3

        x = x.reshape(B, 3, C, p, p)

        roll_out_x = torch.empty(B, group_size * C * 3, p, p,
                                 device=x.device)  # B, 9C, H, W

        for i in range(group_size):
            plane_xy = x[:, i]  # B C H W

            # TODO, simply do the average pooling?
            plane_yz_pooling = x[:, (i + 1) % group_size].mean(
                dim=-1, keepdim=True).repeat_interleave(
                    p, dim=-1)  # B C H W -> B C H 1 -> B C H W, reduce z dim
            plane_zx_pooling = x[:, (i + 2) % group_size].mean(
                dim=-2, keepdim=True).repeat_interleave(
                    p, dim=-2)  # B C H W -> B C 1 W -> B C H W, reduce z dim

            roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat(
                [plane_xy, plane_yz_pooling, plane_zx_pooling],
                1)  # fill in the 3C dim

            # ! directly cat, avoid intermediate vars
            # ? why OOM
            # roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat(
            #     [
            #         x[:, i],
            #         x[:, (i + 1) % group_size].mean(
            #             dim=-1, keepdim=True).repeat_interleave(p, dim=-1),
            #         x[:, (i + 2) % group_size].mean(
            #             dim=-2, keepdim=True).repeat_interleave(
            #                 p, dim=-2)  # B C H W -> B C 1 W -> B C H W, reduce z dim
            #     ],
            #     1)  # fill in the 3C dim

        x = self.roll_out_convs(roll_out_x)  # B 3C H W

        return x
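

# The roll-out trick: for each of the 3 planes, concatenate the plane itself with
# axis-pooled (mean) versions of the other two planes, then let one grouped conv
# process all three stacks in parallel. A shape sketch (illustrative):
#
#   conv3d = RodinRollOutConv3D_GroupConv(in_chans=96)   # 96 = 3 planes * 32 channels
#   y = conv3d(torch.randn(2, 96, 16, 16))               # -> (2, 96, 16, 16)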


class RodinRollOut_GroupConv_noConv3D(nn.Module):
    """only roll out and do Conv on individual planes
    """

    def __init__(self,
                 in_chans,
                 out_chans=None,
                 kernel_size=3,
                 stride=1,
                 padding=1):
        super().__init__()
        if out_chans is None:
            out_chans = in_chans

        self.roll_out_inplane_conv = nn.Conv2d(
            in_chans,
            out_chans,
            kernel_size=kernel_size,
            groups=3,  # B 3C H W
            stride=stride,
            padding=padding)

    def forward(self, x):
        x = self.roll_out_inplane_conv(x)  # B 3C H W
        return x


# class RodinConv3D_SynthesisLayer_withact(nn.Module):
#     def __init__(self, in_chans, out_chans) -> None:
#         super().__init__()
#         self.act = nn.LeakyReLU(inplace=True)
#         self.conv = nn.Sequential(
#             RodinRollOutConv3D_GroupConv(in_chans, out_chans),
#             nn.LeakyReLU(inplace=True),
#         )
#         if in_chans != out_chans:
#             self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans)  # PSNR 13 first iteration.
#         else:
#             self.short_cut = None
#     def forward(self, feats):
#         if self.short_cut is not None:
#             res_feats = self.short_cut(feats)
#         else:
#             res_feats = feats
#         # return res_feats + self.conv(feats)
#         feats = res_feats + self.conv(feats)
#         return self.act(feats)  # as in resnet, add an act before return


class RodinConv3D_SynthesisLayer_mlp_unshuffle_as_residual(nn.Module):

    def __init__(self, in_chans, out_chans) -> None:
        super().__init__()
        self.act = nn.LeakyReLU(inplace=True)
        self.conv = nn.Sequential(
            RodinRollOutConv3D_GroupConv(in_chans, out_chans),
            nn.LeakyReLU(inplace=True),
        )

        self.out_chans = out_chans
        if in_chans != out_chans:
            # self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans)  # PSNR 13 first iteration.
            self.short_cut = nn.Linear(  # B 3C H W -> B 3C 4H 4W
                in_chans // 3,  # 144 / 3 = 48
                out_chans // 3 * 4 * 4,  # 32 * 16
                bias=True)  # decoder to patch
        else:
            self.short_cut = None

    def shortcut_unpatchify_triplane(self, x, p=None, unpatchify_out_chans=None):
        """separate triplane version; x shape: B (3*C) h w
        """
        assert self.short_cut is not None

        B, C3, h, w = x.shape
        assert h == w
        L = h * w

        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.out_chans // 3  # per-plane channels
        p = 4  # spatial upsample factor baked into the shortcut MLP (out_chans // 3 * 4 * 4)

        x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, 1)  # (B, 3, L, C_in // 3)
        x = self.short_cut(x)  # (B, 3, L, unpatchify_out_chans * p * p)

        x = x.reshape(shape=(B, 3, h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('ndhwpqc->ndchpwq', x)  # nplanes, C order in the renderer.py
        x = x.reshape(shape=(B, 3 * unpatchify_out_chans, h * p, w * p))
        return x

    def forward(self, feats):
        if self.short_cut is not None:
            res_feats = self.shortcut_unpatchify_triplane(feats)
        else:
            res_feats = feats

        # return res_feats + self.conv(feats)
        feats = res_feats + self.conv(feats)
        return self.act(feats)  # as in resnet, add an act before return


# class RodinConv3D_SynthesisLayer(nn.Module):
#     def __init__(self, in_chans, out_chans) -> None:
#         super().__init__()
#         self.act = nn.LeakyReLU(inplace=True)
#         self.conv = nn.Sequential(
#             RodinRollOutConv3D_GroupConv(in_chans, out_chans),
#             nn.LeakyReLU(inplace=True),
#         )
#         if in_chans != out_chans:
#             self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans)  # PSNR 13 first iteration.
#         else:
#             self.short_cut = None
#     def forward(self, feats):
#         if self.short_cut is not None:
#             res_feats = self.short_cut(feats)
#         else:
#             res_feats = feats
#         # return res_feats + self.conv(feats)
#         feats = res_feats + self.conv(feats)
#         # return self.act(feats)  # as in resnet, add an act before return
#         return feats  # ! old behaviour, no act


# previous worked version
class RodinConv3D_SynthesisLayer(nn.Module):

    def __init__(self, in_chans, out_chans) -> None:
        super().__init__()
        # x2 SR + 1x1 Conv Residual BLK
        # self.conv3D = RodinRollOutConv3D(in_chans, out_chans)
        self.act = nn.LeakyReLU(inplace=True)
        self.conv = nn.Sequential(
            RodinRollOutConv3D_GroupConv(in_chans, out_chans),
            nn.LeakyReLU(inplace=True),
        )
        if in_chans != out_chans:
            self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans)
        else:
            self.short_cut = None

    def forward(self, feats):
        feats_out = self.conv(feats)
        if self.short_cut is not None:
            # ! failed below
            feats_out = self.short_cut(feats) + feats_out  # ! only difference here, no act() compared with baseline
            # feats_out = self.act(self.short_cut(feats)) + feats_out  # ! only difference here, no act() compared with baseline
        else:
            feats_out = feats_out + feats
        return feats_out


class RodinRollOutConv3DSR2X(nn.Module):

    def __init__(self, in_chans, **kwargs) -> None:
        super().__init__()
        self.conv3D = RodinRollOutConv3D_GroupConv(in_chans)
        # self.conv3D = RodinRollOutConv3D(in_chans)
        self.act = nn.LeakyReLU(inplace=True)
        self.input_resolution = 224

    def forward(self, x):
        # x: B 3 112*112 C
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3
        group_size = C3 // C
        assert group_size == 3, 'designed for triplane here'
        # p = int(N**0.5)  # patch size
        # assert p**2 == N, 'check input dim, no [cls] needed here'

        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x,
                                                size=(self.input_resolution,
                                                      self.input_resolution),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=True)

        x = x + self.conv3D(x)
        return x


class RodinRollOutConv3DSR4X_lite(nn.Module):

    def __init__(self, in_chans, input_resolution=256, **kwargs) -> None:
        super().__init__()
        self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans)
        self.conv3D_1 = RodinRollOutConv3D_GroupConv(in_chans)
        self.act = nn.LeakyReLU(inplace=True)
        self.input_resolution = input_resolution

    def forward(self, x):
        # x: B 3 112*112 C
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3
        group_size = C3 // C
        assert group_size == 3, 'designed for triplane here'

        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x,
                                                size=(self.input_resolution,
                                                      self.input_resolution),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=True)

        # ! still not converging; no bug here?
        # x = x + self.conv3D_0(x)
        # x = x + self.conv3D_1(x)
        x = x + self.act(self.conv3D_0(x))
        x = x + self.act(self.conv3D_1(x))

        # TODO: which is better, bilinear + conv or PixelUnshuffle?
        return x


# class RodinConv3D2X_lite_mlp_as_residual(nn.Module):
#     """lite 2X version, with an MLP shortcut to change the dimension
#     """
#     def __init__(self, in_chans, out_chans, input_resolution=256) -> None:
#         super().__init__()
#         self.act = nn.LeakyReLU(inplace=True)
#         self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans)
#         self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans)
#         self.act = nn.LeakyReLU(inplace=True)
#         self.input_resolution = input_resolution
#         self.out_chans = out_chans
#         if in_chans != out_chans:  # ! only change the dimension
#             self.short_cut = nn.Linear(  # B 3C H W -> B 3C 4H 4W
#                 in_chans // 3,  # 144 / 3 = 48
#                 out_chans // 3,
#                 bias=True)  # decoder to patch
#         else:
#             self.short_cut = None
#     def shortcut_unpatchify_triplane(self, x, p=None):
#         """separate triplane version; x shape: B (3*C) h w
#         """
#         assert self.short_cut is not None
#         B, C3, h, w = x.shape
#         assert h == w
#         L = h * w
#         x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, 1)  # (B, 3, L, C_in // 3)
#         x = self.short_cut(x)  # B 3 L C_out//3
#         x = x.permute(0, 1, 3, 2)  # B 3 C_out//3 L
#         x = x.reshape(shape=(B, self.out_chans, h, w))
#         # directly resize to the target; no unpatchify here since no 3D ViT is involved
#         if w != self.input_resolution:
#             x = torch.nn.functional.interpolate(x,  # 4X SR
#                                                 size=(self.input_resolution,
#                                                       self.input_resolution),
#                                                 mode='bilinear',
#                                                 align_corners=False,
#                                                 antialias=True)
#         return x
#     def forward(self, x):
#         # x: B 3 112*112 C
#         B, C3, p, p = x.shape  # after unpatchify triplane
#         C = C3 // 3
#         if self.short_cut is not None:
#             res_feats = self.shortcut_unpatchify_triplane(x)
#         else:
#             res_feats = x
#         # the following forward code is copied from the lite 4x version
#         x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W
#         if x.shape[-1] != self.input_resolution:
#             x = torch.nn.functional.interpolate(x,  # 4X SR
#                                                 size=(self.input_resolution,
#                                                       self.input_resolution),
#                                                 mode='bilinear',
#                                                 align_corners=False,
#                                                 antialias=True)
#         x = res_feats + self.act(self.conv3D_0(x))
#         x = x + self.act(self.conv3D_1(x))
#         return x


class RodinConv3D4X_lite_mlp_as_residual(nn.Module):
    """lite 4X version, with an MLP shortcut to change the dimension
    """

    def __init__(self,
                 in_chans,
                 out_chans,
                 input_resolution=256,
                 interp_mode='bilinear',
                 bcg_triplane=False) -> None:
        super().__init__()
        self.interp_mode = interp_mode
        self.act = nn.LeakyReLU(inplace=True)
        self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans)
        self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans)
        self.bcg_triplane = bcg_triplane
        if bcg_triplane:
            self.conv3D_1_bg = RodinRollOutConv3D_GroupConv(out_chans, out_chans)

        self.input_resolution = input_resolution
        self.out_chans = out_chans

        if in_chans != out_chans:  # ! only change the dimension
            self.short_cut = nn.Linear(  # B 3C H W -> B 3C 4H 4W
                in_chans // 3,  # 144 / 3 = 48
                out_chans // 3,  # e.g. 96 / 3 = 32
                bias=True)  # decoder to patch
        else:
            self.short_cut = None

    def shortcut_unpatchify_triplane(self, x, p=None):
        """separate triplane version; x shape: B (3*C) h w
        """
        assert self.short_cut is not None

        B, C3, h, w = x.shape
        assert h == w
        L = h * w

        x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, 1)  # (B, 3, L, C_in // 3)
        x = self.short_cut(x)  # B 3 L C_out//3
        x = x.permute(0, 1, 3, 2)  # B 3 C_out//3 L
        x = x.reshape(shape=(B, self.out_chans, h, w))

        # directly resize to the target; no unpatchify here since no 3D ViT is involved
        if w != self.input_resolution:
            x = torch.nn.functional.interpolate(
                x,  # 4X SR
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=True)

        return x

    def interpolate(self, feats):
        if self.interp_mode == 'bilinear':
            return torch.nn.functional.interpolate(
                feats,  # 4X SR
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=True)
        else:
            return torch.nn.functional.interpolate(
                feats,  # 4X SR
                size=(self.input_resolution, self.input_resolution),
                mode='nearest',
            )

    def forward(self, x):
        # x: B 3 112*112 C
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3

        if self.short_cut is not None:
            res_feats = self.shortcut_unpatchify_triplane(x)
        else:
            res_feats = x

        if res_feats.shape[-1] != self.input_resolution:
            res_feats = self.interpolate(res_feats)

        # the following forward code is copied from the lite 4x version
        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        if x.shape[-1] != self.input_resolution:
            x = self.interpolate(x)

        x0 = res_feats + self.act(self.conv3D_0(x))  # the base feature
        x = x0 + self.act(self.conv3D_1(x0))

        if self.bcg_triplane:
            x_bcg = x0 + self.act(self.conv3D_1_bg(x0))
            return torch.cat([x, x_bcg], 1)
        else:
            return x
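

# End-to-end sketch of the 4X path (illustrative): tokens for each plane are first
# reshaped to a 3C-channel feature map, resized to input_resolution, and refined
# by two residual roll-out convs; the MLP shortcut only changes the channel count:
#
#   sr = RodinConv3D4X_lite_mlp_as_residual(in_chans=144, out_chans=96,
#                                           input_resolution=256)
#   x = torch.randn(2, 144, 64, 64)   # B, 3*48 channels, 64x64 triplane
#   y = sr(x)                          # -> (2, 96, 256, 256)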


class RodinConv3D4X_lite_mlp_as_residual_litev2(
        RodinConv3D4X_lite_mlp_as_residual):

    def __init__(self,
                 in_chans,
                 out_chans,
                 num_feat=128,
                 input_resolution=256,
                 interp_mode='bilinear',
                 bcg_triplane=False) -> None:
        super().__init__(in_chans, out_chans, input_resolution, interp_mode,
                         bcg_triplane)

        self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, in_chans)
        self.conv_before_upsample = RodinRollOut_GroupConv_noConv3D(
            in_chans, num_feat * 3)
        self.conv3D_1 = RodinRollOut_GroupConv_noConv3D(num_feat * 3, num_feat * 3)
        self.conv_last = RodinRollOut_GroupConv_noConv3D(num_feat * 3, out_chans)
        self.short_cut = None

    def forward(self, x):
        # x: B 3 112*112 C
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3

        # if self.short_cut is not None:
        #     res_feats = self.shortcut_unpatchify_triplane(x)
        # else:
        #     res_feats = x
        # if res_feats.shape[-1] != self.input_resolution:
        #     res_feats = self.interpolate(res_feats)

        # the following forward code is copied from the lite 4x version
        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        x = x + self.conv3D_0(x)  # the base feature
        x = self.act(self.conv_before_upsample(x))
        # if x.shape[-1] != self.input_resolution:
        x = self.conv_last(self.act(self.conv3D_1(self.interpolate(x))))

        return x


class RodinConv3D4X_lite_mlp_as_residual_lite(RodinConv3D4X_lite_mlp_as_residual):
    """replace the first Rodin Conv3D with an ordinary roll-out conv to save memory
    """

    def __init__(self,
                 in_chans,
                 out_chans,
                 input_resolution=256,
                 interp_mode='bilinear') -> None:
        super().__init__(in_chans, out_chans, input_resolution, interp_mode)
        self.conv3D_0 = RodinRollOut_GroupConv_noConv3D(in_chans, out_chans)


class SR3D(nn.Module):
    # https://github.com/SeanChenxy/Mimic3D/blob/77d313656df3cd5536d2c4c5766db3a56208eea6/training/networks_stylegan2.py#L629
    # roll-out and apply two deconv/pixelUnshuffle layers

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)


class RodinConv3D4X_lite_mlp_as_residual_improved(nn.Module):

    def __init__(self, in_chans, num_feat, out_chans, input_resolution=256) -> None:
        super().__init__()
        assert in_chans == 4 * out_chans
        assert num_feat == 2 * out_chans
        self.input_resolution = input_resolution

        # refer to https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750
        self.upscale = 4
        self.conv_after_body = RodinRollOutConv3D_GroupConv(in_chans, in_chans, 3, 1, 1)
        self.conv_before_upsample = nn.Sequential(
            RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1),
            nn.LeakyReLU(inplace=True))
        self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1)
        if self.upscale == 4:
            self.conv_up2 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1)
        self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # x: B 3 112*112 C
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3

        # the following forward code is copied from the lite 4x version
        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        # ? nearest or bilinear
        x = self.conv_after_body(x) + x
        x = self.conv_before_upsample(x)
        x = self.lrelu(
            self.conv_up1(
                torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
        if self.upscale == 4:
            x = self.lrelu(
                self.conv_up2(
                    torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
        x = self.conv_last(self.lrelu(self.conv_hr(x)))

        assert x.shape[-1] == self.input_resolution
        return x


class RodinConv3D4X_lite_improved_lint_withresidual(nn.Module):

    def __init__(self, in_chans, num_feat, out_chans, input_resolution=256) -> None:
        super().__init__()
        assert in_chans == 4 * out_chans
        assert num_feat == 2 * out_chans
        self.input_resolution = input_resolution

        # refer to https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750
        self.upscale = 4
        self.conv_after_body = RodinRollOutConv3D_GroupConv(in_chans, in_chans, 3, 1, 1)
        self.conv_before_upsample = nn.Sequential(
            RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1),
            nn.LeakyReLU(inplace=True))
        self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1)
        if self.upscale == 4:
            self.conv_up2 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, 1)
        self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # x: B 3 112*112 C
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3

        # the following forward code is copied from the lite 4x version
        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        # ? nearest or bilinear
        x = self.conv_after_body(x) + x
        x = self.conv_before_upsample(x)
        x = self.lrelu(
            self.conv_up1(
                torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
        if self.upscale == 4:
            x = self.lrelu(
                self.conv_up2(
                    torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
        x = self.conv_last(self.lrelu(self.conv_hr(x) + x))

        assert x.shape[-1] == self.input_resolution
        return x


class RodinRollOutConv3DSR_FlexibleChannels(nn.Module):

    def __init__(self, in_chans, num_out_ch=96, input_resolution=256, **kwargs) -> None:
        super().__init__()
        self.block0 = RodinConv3D_SynthesisLayer(in_chans, num_out_ch)  # in_chans=48
        self.block1 = RodinConv3D_SynthesisLayer(num_out_ch, num_out_ch)

        self.input_resolution = input_resolution  # 64 -> 256 SR

    def forward(self, x):
        # x: B 3 112*112 C
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3
        # group_size = C3 // C

        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x,
                                                size=(self.input_resolution,
                                                      self.input_resolution),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=True)

        x = self.block0(x)
        x = self.block1(x)
        return x


# previous worked version
class RodinRollOutConv3DSR4X(nn.Module):
    # follows PixelUnshuffleUpsample

    def __init__(self, in_chans, **kwargs) -> None:
        super().__init__()
        # self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96 * 2)  # TODO, match the old behaviour now.
        # self.block1 = RodinConv3D_SynthesisLayer(96 * 2, 96)
        self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96)
        self.block1 = RodinConv3D_SynthesisLayer(96, 96)  # baseline choice, validate with no LPIPS loss here

        self.input_resolution = 64  # 64 -> 256

    def forward(self, x):
        # x: B 3 112*112 C
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3
        # group_size = C3 // C

        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(x,
                                                size=(self.input_resolution,
                                                      self.input_resolution),
                                                mode='bilinear',
                                                align_corners=False,
                                                antialias=True)

        x = self.block0(x)
        x = self.block1(x)
        return x


class Upsample3D(nn.Module):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n (the 3x branch of the
            original SwinIR upsampler is not implemented here).
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        super().__init__()
        m_convs = []
        m_pixelshuffle = []

        assert (scale & (scale - 1)) == 0, 'scale = 2^n'
        self.scale = scale

        for _ in range(int(math.log(scale, 2))):
            m_convs.append(
                RodinRollOutConv3D_GroupConv(num_feat, 4 * num_feat, 3, 1, 1))
            m_pixelshuffle.append(nn.PixelShuffle(2))

        self.m_convs = nn.ModuleList(m_convs)
        self.m_pixelshuffle = nn.ModuleList(m_pixelshuffle)

    # @torch.autocast(device_type='cuda')
    def forward(self, x):
        for scale_idx in range(int(math.log(self.scale, 2))):
            x = self.m_convs[scale_idx](x)  # B 3C H W
            # fold the plane axis into the batch so PixelShuffle acts per plane
            x = x.reshape(x.shape[0] * 3, x.shape[1] // 3, *x.shape[2:])
            x = self.m_pixelshuffle[scale_idx](x)
            x = x.reshape(x.shape[0] // 3, x.shape[1] * 3, *x.shape[2:])

        return x
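

# Each 2x stage quadruples the per-plane channels with a roll-out conv and then
# trades them for spatial resolution via PixelShuffle(2), applied per plane.
# A shape sketch (illustrative):
#
#   up = Upsample3D(scale=4, num_feat=96)    # 96 = 3 planes * 32 channels
#   y = up(torch.randn(2, 96, 32, 32))       # -> (2, 96, 128, 128)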


class RodinConv3DPixelUnshuffleUpsample(nn.Module):

    def __init__(self,
                 output_dim,
                 num_feat=32 * 6,
                 num_out_ch=32 * 3,
                 sr_ratio=4,
                 *args,
                 **kwargs) -> None:
        super().__init__()

        self.conv_after_body = RodinRollOutConv3D_GroupConv(output_dim, output_dim, 3, 1, 1)
        self.conv_before_upsample = nn.Sequential(
            RodinRollOutConv3D_GroupConv(output_dim, num_feat, 3, 1, 1),
            nn.LeakyReLU(inplace=True))
        self.upsample = Upsample3D(sr_ratio, num_feat)  # 4x SR
        self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, num_out_ch, 3, 1, 1)

    # @torch.autocast(device_type='cuda')
    def forward(self, x, input_skip_connection=True, *args, **kwargs):
        # x = self.conv_first(x)
        if input_skip_connection:
            x = self.conv_after_body(x) + x
        else:
            x = self.conv_after_body(x)

        x = self.conv_before_upsample(x)
        x = self.upsample(x)
        x = self.conv_last(x)
        return x
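

# Usage sketch (illustrative), mirroring the SwinIR-style head: a residual conv,
# a channel-expanding conv, pixel-shuffle upsampling, and a final projection:
#
#   sr = RodinConv3DPixelUnshuffleUpsample(output_dim=96, num_feat=32 * 6,
#                                          num_out_ch=32 * 3, sr_ratio=4)
#   y = sr(torch.randn(2, 96, 64, 64))   # -> (2, 96, 256, 256)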


class RodinConv3DPixelUnshuffleUpsample_improvedVersion(nn.Module):

    def __init__(self,
                 output_dim,
                 num_out_ch=32 * 3,
                 sr_ratio=4,
                 input_resolution=256) -> None:
        super().__init__()
        self.input_resolution = input_resolution

        # self.conv_first = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, 3, 1, 1)
        self.upsample = Upsample3D(sr_ratio, output_dim)  # 4x SR
        self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, 3, 1, 1)

    def forward(self, x, bilinear_upsample=True):
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3
        group_size = C3 // C
        assert group_size == 3, 'designed for triplane here'

        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        if bilinear_upsample and x.shape[-1] != self.input_resolution:
            x_bilinear_upsample = torch.nn.functional.interpolate(
                x,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=True)
            x = self.upsample(x) + x_bilinear_upsample
        else:
            # x_bilinear_upsample = x
            x = self.upsample(x)

        x = self.conv_last(x)
        return x


class RodinConv3DPixelUnshuffleUpsample_improvedVersion2(nn.Module):
    """removed the nearest-neighbour residual connection; adds a conv-layer residual connection
    """

    def __init__(self,
                 output_dim,
                 num_out_ch=32 * 3,
                 sr_ratio=4,
                 input_resolution=256) -> None:
        super().__init__()
        self.input_resolution = input_resolution

        self.conv_after_body = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, 3, 1, 1)
        self.upsample = Upsample3D(sr_ratio, output_dim)  # 4x SR
        self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, 3, 1, 1)

    def forward(self, x, input_skip_connection=True):
        B, C3, p, p = x.shape  # after unpatchify triplane
        C = C3 // 3
        group_size = C3 // C
        assert group_size == 3, 'designed for triplane here'

        x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, p)  # B 3 C N -> B 3C h W

        if input_skip_connection:
            x = self.conv_after_body(x) + x
        else:
            x = self.conv_after_body(x)

        x = self.upsample(x)
        x = self.conv_last(x)
        return x


class CLSCrossAttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = CrossAttention(dim,
                                   num_heads=num_heads,
                                   qkv_bias=qkv_bias,
                                   qk_scale=qk_scale,
                                   attn_drop=attn_drop,
                                   proj_drop=drop)
        # NOTE: drop path for stochastic depth; we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.has_mlp = has_mlp
        if has_mlp:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim,
                           hidden_features=mlp_hidden_dim,
                           act_layer=act_layer,
                           drop=drop)

    def forward(self, x):
        x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
        if self.has_mlp:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Conv3DCrossAttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Conv3D_Aware_CrossAttention(dim,
                                                num_heads=num_heads,
                                                qkv_bias=qkv_bias,
                                                qk_scale=qk_scale,
                                                attn_drop=attn_drop,
                                                proj_drop=drop)
        # NOTE: drop path for stochastic depth; we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.has_mlp = has_mlp
        if has_mlp:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim,
                           hidden_features=mlp_hidden_dim,
                           act_layer=act_layer,
                           drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        if self.has_mlp:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Conv3DCrossAttentionBlockXformerMHA(Conv3DCrossAttentionBlock):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                         attn_drop, drop_path, act_layer, norm_layer, has_mlp)

        # self.attn = xformer_Conv3D_Aware_CrossAttention(dim,
        self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)


class Conv3DCrossAttentionBlockXformerMHANested(
        Conv3DCrossAttentionBlockXformerMHA):
    """for in-place replacing the internal attn in the DINO ViT.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                         attn_drop, drop_path, act_layer, norm_layer, has_mlp)

    def forward(self, x):
        Bx3, N, C = x.shape
        B, group_size = Bx3 // 3, 3
        x = x.reshape(B, group_size, N, C)  # in-plane ViT tokens -> B 3 N C
        x = super().forward(x)
        return x.reshape(B * group_size, N, C)  # to match the original attn size


class Conv3DCrossAttentionBlockXformerMHANested_withinC(
        Conv3DCrossAttentionBlockXformerMHANested):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 has_mlp=False):
        super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                         attn_drop, drop_path, act_layer, norm_layer, has_mlp)
        self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

    def forward(self, x):
        # basic transformer attention forward function
        x = x + self.drop_path(self.attn(self.norm1(x)))
        if self.has_mlp:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class TriplaneFusionBlock(nn.Module):
    """4 ViT blocks + 1 CrossAttentionBlock
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 cross_attention_blk=CLSCrossAttentionBlock,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks

        if use_fusion_blk:
            self.fusion = nn.ModuleList()

            # copied ViT settings from https://github.dev/facebookresearch/dinov2
            nh = num_heads
            dim = embed_dim

            mlp_ratio = 4  # defined for all dino2 models
            qkv_bias = True
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
            drop_path_rate = 0.3  # default setting
            attn_drop = proj_drop = 0.0
            qk_scale = None  # TODO, double check

            for d in range(self.num_branches):
                self.fusion.append(
                    cross_attention_blk(
                        dim=dim,
                        num_heads=nh,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        # drop=drop,
                        drop=proj_drop,
                        attn_drop=attn_drop,
                        drop_path=drop_path_rate,
                        norm_layer=norm_layer,  # type: ignore
                        has_mlp=False))
        else:
            self.fusion = None

    def forward(self, x):
        # modified from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132
        """x: B 3 N C, where N = H*W tokens
        """

        # self attention, by merging the triplane channel into B for parallel computation
        # ! move the below to the front of the first call
        B, group_size, N, C = x.shape  # has [cls] token in N
        assert group_size == 3, 'triplane'
        x = x.view(B * group_size, N, C)

        for blk in self.vit_blks:
            x = blk(x)  # (B*3) N C

        if self.fusion is None:
            return x.view(B, group_size, N, C)

        # outs_b = x.view(B, group_size, N, C).chunk(chunks=3, dim=1)  # 3 * [B, 1, N, C] Tensors, for fusion
        outs_b = x.chunk(chunks=3, dim=0)  # 3 * [B, N, C] Tensors, for fusion

        # only take the cls token out
        proj_cls_token = [x[:, 0:1] for x in outs_b]

        # cross attention
        outs = []
        for i in range(self.num_branches):
            tmp = torch.cat(
                (proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]),
                dim=1)
            tmp = self.fusion[i](tmp)
            # reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...])
            reverted_proj_cls_token = tmp[:, 0:1, ...]
            tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
            outs.append(tmp)

        # merge the branches back
        outs = torch.stack(outs, 1)  # B 3 N C
        return outs
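

# Fusion sketch (illustrative): the three planes are flattened into the batch for
# the shared ViT blocks, then each plane's [cls] token cross-attends to the patch
# tokens of the next plane, which is the only place information crosses planes.
# nn.Identity modules stand in for real ViT blocks here:
#
#   vit_blks = nn.ModuleList([nn.Identity(), nn.Identity()])
#   fusion = TriplaneFusionBlock(vit_blks, num_heads=8, embed_dim=64)
#   y = fusion(torch.randn(2, 3, 1 + 16 * 16, 64))   # -> (2, 3, 257, 64)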


class TriplaneFusionBlockv2(nn.Module):
    """4 ViT blocks + 1 CrossAttentionBlock
    """

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlock,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks

        if use_fusion_blk:
            # ViT settings copied from https://github.dev/facebookresearch/dinov2
            nh = num_heads
            dim = embed_dim
            mlp_ratio = 4  # shared by all DINOv2 models
            qkv_bias = True
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
            drop_path_rate = 0.3  # default setting
            attn_drop = proj_drop = 0.0
            qk_scale = None  # TODO, double check

            self.fusion = fusion_ca_blk(  # one fusion block is enough
                dim=dim,
                num_heads=nh,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=proj_drop,
                attn_drop=attn_drop,
                drop_path=drop_path_rate,
                norm_layer=norm_layer,  # type: ignore
                has_mlp=False)
        else:
            self.fusion = None

    def forward(self, x):
        # modified from https://github.com/IBM/CrossViT/blob/main/models/crossvit.py#L132
        """x: B 3 N C, where N = H*W tokens
        """

        # self attention: merge the triplane axis into B for parallel computation
        B, group_size, N, C = x.shape  # N includes the [cls] token
        assert group_size == 3, 'triplane'
        x = x.reshape(B * group_size, N, C)

        for blk in self.vit_blks:
            x = blk(x)  # (B*3) N C

        if self.fusion is None:
            return x.reshape(B, group_size, N, C)

        x = x.reshape(B, group_size, N, C)
        return self.fusion(x)


class TriplaneFusionBlockv3(TriplaneFusionBlockv2):

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA,
                 *args,
                 **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, *args, **kwargs)


class TriplaneFusionBlockv4(TriplaneFusionBlockv3):

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA,
                 *args,
                 **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, *args, **kwargs)
        # to avoid OOM, directly replace the attention here: drop the second
        # block's self-attention (plus its layer scale and norm) and keep only
        # its FFN, which forward() reuses below
        assert len(vit_blks) == 2
        del self.vit_blks[1].attn, self.vit_blks[1].ls1, self.vit_blks[1].norm1

    def ffn_residual_func(self, tx_blk, x: Tensor) -> Tensor:
        # https://github.com/facebookresearch/dinov2/blob/c3c2683a13cde94d4d99f523cf4170384b00c34c/dinov2/layers/block.py#L86C1-L87C53
        return tx_blk.ls2(tx_blk.mlp(tx_blk.norm2(x)))

    def forward(self, x):
        """x: B 3 N C, where N = H*W tokens
        """
        assert self.fusion is not None

        B, group_size, N, C = x.shape  # N includes the [cls] token
        x = x.reshape(B * group_size, N, C)

        # in-plane self attention
        x = self.vit_blks[0](x)

        # 3D cross attention block + FFN of the second block
        x = x + self.fusion(x.reshape(B, group_size, N, C)).reshape(
            B * group_size, N, C)
        x = x + self.ffn_residual_func(self.vit_blks[1], x)

        return x.reshape(B, group_size, N, C)


class TriplaneFusionBlockv4_nested(nn.Module):

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2

        # ! replace the attn layer of vit_blks[1] with 3D-aware attention
        del self.vit_blks[1].attn

        # ViT settings copied from https://github.dev/facebookresearch/dinov2
        nh = num_heads
        dim = embed_dim
        mlp_ratio = 4  # shared by all DINOv2 models
        qkv_bias = True
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        drop_path_rate = 0.3  # default setting
        attn_drop = proj_drop = 0.0
        qk_scale = None  # TODO, double check

        self.vit_blks[1].attn = fusion_ca_blk(  # one fusion block is enough
            dim=dim,
            num_heads=nh,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=proj_drop,
            attn_drop=attn_drop,
            drop_path=drop_path_rate,
            norm_layer=norm_layer,  # type: ignore
            has_mlp=False)

    def forward(self, x):
        """x: B 3 N C, where N = H*W tokens
        """
        # merge the triplane axis into B for parallel computation
        B, group_size, N, C = x.shape  # N includes the [cls] token
        assert group_size == 3, 'triplane'
        x = x.reshape(B * group_size, N, C)

        for blk in self.vit_blks:
            x = blk(x)  # (B*3) N C

        # TODO, avoid the reshape overhead?
        return x.reshape(B, group_size, N, C)


class TriplaneFusionBlockv4_nested_init_from_dino(nn.Module):

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 init_from_dino=True,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2

        # ViT settings copied from https://github.dev/facebookresearch/dinov2
        nh = num_heads
        dim = embed_dim
        mlp_ratio = 4  # shared by all DINOv2 models
        qkv_bias = True
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        drop_path_rate = 0.3  # default setting
        attn_drop = proj_drop = 0.0
        qk_scale = None  # TODO, double check

        attn_3d = fusion_ca_blk(  # one fusion block is enough
            dim=dim,
            num_heads=nh,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=proj_drop,
            attn_drop=attn_drop,
            drop_path=drop_path_rate,
            norm_layer=norm_layer,  # type: ignore
            has_mlp=False)

        # ! initialize the 3D attention from DINO's pretrained 2D attention
        if init_from_dino:
            merged_qkv_linear = self.vit_blks[1].attn.qkv
            attn_3d.attn.proj.load_state_dict(
                self.vit_blks[1].attn.proj.state_dict())

            # split the merged QKV projection: the first `dim` output rows are
            # Q, the remaining 2*dim rows are the fused K and V
            attn_3d.attn.wq.weight.data = merged_qkv_linear.weight.data[:dim, :]
            attn_3d.attn.w_kv.weight.data = merged_qkv_linear.weight.data[dim:, :]

            # split the biases the same way (the QKV layer has them here)
            if qkv_bias:
                attn_3d.attn.wq.bias.data = merged_qkv_linear.bias.data[:dim]
                attn_3d.attn.w_kv.bias.data = merged_qkv_linear.bias.data[dim:]

        del self.vit_blks[1].attn
        # ! assign
        self.vit_blks[1].attn = attn_3d

    def forward(self, x):
        """x: B 3 N C, where N = H*W tokens
        """
        # merge the triplane axis into B for parallel computation
        B, group_size, N, C = x.shape  # N includes the [cls] token
        assert group_size == 3, 'triplane'
        x = x.reshape(B * group_size, N, C)

        for blk in self.vit_blks:
            x = blk(x)  # (B*3) N C

        # TODO, avoid the reshape overhead?
        return x.reshape(B, group_size, N, C)
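

def _demo_qkv_split_init():
    # Hedged sanity check, not part of the original module: splitting a fused
    # nn.Linear(dim, 3*dim) QKV projection into separate Q and KV projections,
    # as done in TriplaneFusionBlockv4_nested_init_from_dino above. The fused
    # layer is assumed to stack its outputs as [q; k; v] along dim 0.
    dim = 8
    qkv = nn.Linear(dim, dim * 3, bias=True)
    wq = nn.Linear(dim, dim, bias=True)
    w_kv = nn.Linear(dim, dim * 2, bias=True)
    wq.weight.data = qkv.weight.data[:dim, :]
    wq.bias.data = qkv.bias.data[:dim]
    w_kv.weight.data = qkv.weight.data[dim:, :]
    w_kv.bias.data = qkv.bias.data[dim:]

    x = torch.randn(2, dim)
    q, kv = qkv(x).split([dim, 2 * dim], dim=-1)
    # the split projections reproduce the fused layer exactly
    assert torch.allclose(q, wq(x)) and torch.allclose(kv, w_kv(x))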


class TriplaneFusionBlockv4_nested_init_from_dino_lite(nn.Module):

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=None,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2

        # ViT settings copied from https://github.dev/facebookresearch/dinov2
        nh = num_heads
        dim = embed_dim
        mlp_ratio = 4  # shared by all DINOv2 models
        qkv_bias = True
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        drop_path_rate = 0.3  # default setting
        attn_drop = proj_drop = 0.0
        qk_scale = None  # TODO, double check

        attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(  # ! raw 3D attn layer
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop)

        del self.vit_blks[1].attn
        # ! assign
        self.vit_blks[1].attn = attn_3d

    def forward(self, x):
        """x: B N C, where N = H*W tokens; just a raw ViT forward pass
        """
        B, N, C = x.shape  # N includes the [cls] token
        for blk in self.vit_blks:
            x = blk(x)  # B N C
        return x


class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge(nn.Module):

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=None,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2

        # ViT settings copied from https://github.dev/facebookresearch/dinov2
        nh = num_heads
        dim = embed_dim
        qkv_bias = True
        attn_drop = proj_drop = 0.0
        qk_scale = None  # TODO, double check

        if False:  # disabled ablation branch
            for blk in self.vit_blks:
                attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(  # ! raw 3D attn layer
                    dim,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop)
                blk.attn = self_cross_attn(blk.attn, attn_3d)

    def forward(self, x):
        """x: B N C, where N = H*W tokens; just a raw ViT forward pass
        """
        B, N, C = x.shape  # N includes the [cls] token
        for blk in self.vit_blks:
            x = blk(x)  # B N C
        return x


class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C(
        TriplaneFusionBlockv4_nested_init_from_dino_lite_merge):
    # no roll-out: all three planes are merged into one B 3L C sequence

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 init_from_dino=True,
                 *args,
                 **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, init_from_dino, *args, **kwargs)

    def forward(self, x):
        """x: B 3 N C, where N = H*W tokens
        """
        B, group_size, N, C = x.shape  # N includes the [cls] token
        x = x.reshape(B, group_size * N, C)

        for blk in self.vit_blks:
            x = blk(x)  # B (3*N) C

        x = x.reshape(B, group_size, N, C)  # back to the outer-loop convention
        return x


class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout(
        TriplaneFusionBlockv4_nested_init_from_dino_lite_merge):
    # roll-out first, then the merged B 3L C sequence

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 init_from_dino=True,
                 *args,
                 **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, init_from_dino, *args, **kwargs)

    def forward(self, x):
        """x: B 3 N C, where N = H*W tokens
        """
        B, group_size, N, C = x.shape  # N includes the [cls] token

        # block 0: per-plane (rolled-out) self attention
        x = x.reshape(B * group_size, N, C)
        x = self.vit_blks[0](x)

        # block 1: all planes merged into a single sequence
        x = x.reshape(B, group_size * N, C)
        x = self.vit_blks[1](x)

        x = x.reshape(B, group_size, N, C)  # back to the outer-loop convention
        return x


class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_add3DAttn(
        TriplaneFusionBlockv4_nested_init_from_dino):
    # no roll-out for block 0 + 3D attention in block 1

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 init_from_dino=True,
                 *args,
                 **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, init_from_dino, *args, **kwargs)

    def forward(self, x):
        """x: B 3 N C, where N = H*W tokens
        """
        B, group_size, N, C = x.shape  # N includes the [cls] token

        # block 0: all planes merged into a single sequence
        x = x.reshape(B, group_size * N, C)
        x = self.vit_blks[0](x)  # B (3*N) C

        # block 1: per-plane layout; its attention is the 3D-aware one
        x = x.reshape(B * group_size, N, C)
        x = self.vit_blks[1](x)

        return x.reshape(B, group_size, N, C)
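

def _demo_rollout_vs_merged_layouts():
    # Hedged sketch, not part of the original module: the two token layouts
    # used by the fusion variants above. "Rolled-out" treats each plane as its
    # own batch element, so attention stays within a plane; "merged"
    # concatenates the three planes into one long sequence so attention can
    # mix tokens across planes.
    B, N, C = 2, 5, 4
    x = torch.randn(B, 3, N, C)

    rolled = x.reshape(B * 3, N, C)  # per-plane attention
    merged = x.reshape(B, 3 * N, C)  # cross-plane attention over 3N tokens

    # both are pure reshapes; reshaping back recovers the original tensor
    assert torch.equal(rolled.reshape(B, 3, N, C), x)
    assert torch.equal(merged.reshape(B, 3, N, C), x)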


class TriplaneFusionBlockv5_ldm_addCA(nn.Module):

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2

        # ! rather than replacing the attention, add a 3D attention block after it
        self.norm_for_atten_3d = deepcopy(self.vit_blks[1].norm1)

        # ViT settings copied from https://github.dev/facebookresearch/dinov2
        nh = num_heads
        dim = embed_dim
        mlp_ratio = 4  # shared by all DINOv2 models
        qkv_bias = True
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        drop_path_rate = 0.3  # default setting
        attn_drop = proj_drop = 0.0
        qk_scale = None  # TODO, double check

        self.attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop)

    def forward(self, x):
        """x: B 3 N C, where N = H*W tokens
        """
        # merge the triplane axis into B for parallel computation
        B, group_size, N, C = x.shape  # N includes the [cls] token
        assert group_size == 3, 'triplane'

        flatten_token = lambda t: t.reshape(B * group_size, N, C)
        unflatten_token = lambda t: t.reshape(B, group_size, N, C)

        x = flatten_token(x)
        x = self.vit_blks[0](x)

        x = unflatten_token(x)
        x = self.attn_3d(self.norm_for_atten_3d(x)) + x  # residual 3D attention

        x = flatten_token(x)
        x = self.vit_blks[1](x)

        return unflatten_token(x)


class TriplaneFusionBlockv6_ldm_addCA_Init3DAttnfrom2D(
        TriplaneFusionBlockv5_ldm_addCA):
    # forward is inherited unchanged from TriplaneFusionBlockv5_ldm_addCA

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 *args,
                 **kwargs) -> None:
        super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk,
                         fusion_ca_blk, *args, **kwargs)


class TriplaneFusionBlockv5_ldm_add_dualCA(nn.Module):

    def __init__(self,
                 vit_blks,
                 num_heads,
                 embed_dim,
                 use_fusion_blk=True,
                 fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested,
                 *args,
                 **kwargs) -> None:
        super().__init__()
        self.num_branches = 3  # triplane
        self.vit_blks = vit_blks
        assert use_fusion_blk
        assert len(vit_blks) == 2

        # ! rather than replacing the attention, add a 3D attention block
        # after each ViT block
        self.norm_for_atten_3d_0 = deepcopy(self.vit_blks[0].norm1)
        self.norm_for_atten_3d_1 = deepcopy(self.vit_blks[1].norm1)

        # ViT settings copied from https://github.dev/facebookresearch/dinov2
        nh = num_heads
        dim = embed_dim
        mlp_ratio = 4  # shared by all DINOv2 models
        qkv_bias = True
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        drop_path_rate = 0.3  # default setting
        attn_drop = proj_drop = 0.0
        qk_scale = None  # TODO, double check

        self.attn_3d_0 = xformer_Conv3D_Aware_CrossAttention_xygrid(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop)
        self.attn_3d_1 = deepcopy(self.attn_3d_0)

    def forward(self, x):
        """x: B 3 N C, where N = H*W tokens
        """
        # merge the triplane axis into B for parallel computation
        B, group_size, N, C = x.shape  # N includes the [cls] token
        assert group_size == 3, 'triplane'

        flatten_token = lambda t: t.reshape(B * group_size, N, C)
        unflatten_token = lambda t: t.reshape(B, group_size, N, C)

        x = flatten_token(x)
        x = self.vit_blks[0](x)

        x = unflatten_token(x)
        x = self.attn_3d_0(self.norm_for_atten_3d_0(x)) + x

        x = flatten_token(x)
        x = self.vit_blks[1](x)

        x = unflatten_token(x)
        x = self.attn_3d_1(self.norm_for_atten_3d_1(x)) + x
        return x  # already B 3 N C


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # broadcastable shape: works with tensors of any rank, not just 2D ConvNets
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output
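

def _demo_drop_path_scaling():
    # Hedged sketch, not part of the original module: drop_path zeroes entire
    # samples with probability drop_prob and rescales survivors by
    # 1 / keep_prob, so the expected output matches the input. It is the
    # identity when training=False.
    x = torch.ones(10000, 4)
    assert torch.equal(drop_path(x, drop_prob=0.3, training=False), x)

    out = drop_path(x, drop_prob=0.3, training=True)
    # surviving rows are scaled to 1 / 0.7; the overall mean stays near 1
    assert abs(out.mean().item() - 1.0) < 0.05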


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
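

def _demo_mlp_expansion():
    # Hedged sketch, not part of the original module: the transformer MLP
    # expands tokens to hidden_features (typically mlp_ratio * dim) and
    # projects back, so the token width is preserved.
    mlp = Mlp(in_features=64, hidden_features=256)  # mlp_ratio = 4
    x = torch.randn(2, 17, 64)
    assert mlp(x).shape == x.shape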


class Block(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MemEffAttention(dim,
                                    num_heads=num_heads,
                                    qkv_bias=qkv_bias,
                                    qk_scale=qk_scale,
                                    attn_drop=attn_drop,
                                    proj_drop=drop)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

    def forward(self, x, return_attention=False):
        # MemEffAttention returns only the output tensor; the memory-efficient
        # kernel never materializes the attention map
        y = self.attn(self.norm1(x))
        if return_attention:
            raise NotImplementedError(
                'attention maps are not available with MemEffAttention')
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
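

def _demo_attention_block_shapes():
    # Hedged sketch, not part of the original module: the pre-norm residual
    # wiring in Block is shape-preserving, so blocks can be stacked to any
    # depth. Plain Attention (pure PyTorch) is used here instead of
    # MemEffAttention so the check runs without xformers.
    attn = Attention(dim=64, num_heads=4, qkv_bias=True)
    x = torch.randn(2, 17, 64)  # B N C, e.g. [cls] + 4x4 patch tokens
    y = x + attn(x)  # residual branch, as in Block.forward
    assert y.shape == x.shape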


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # B C H' W' -> B L C
        return x
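

def _demo_patch_embed_shapes():
    # Hedged sketch, not part of the original module: a strided conv turns a
    # 224x224 image into (224/16)^2 = 196 patch tokens of width embed_dim.
    pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    tokens = pe(torch.randn(1, 3, 224, 224))
    assert tokens.shape == (1, 196, 768)
    assert pe.num_patches == 196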


class VisionTransformer(nn.Module):
    """ Vision Transformer """

    def __init__(self,
                 img_size=[224],
                 patch_size=16,
                 in_chans=3,
                 num_classes=0,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer='nn.LayerNorm',
                 patch_embedding=True,
                 cls_token=True,
                 pixel_unshuffle=False,
                 **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size

        # the norm_layer argument is currently ignored; LayerNorm is always used
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        if patch_embedding:
            self.patch_embed = PatchEmbed(img_size=img_size[0],
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_dim=embed_dim)
            num_patches = self.patch_embed.num_patches
            self.img_size = self.patch_embed.img_size
        else:
            self.patch_embed = None
            self.img_size = img_size[0]
            num_patches = (img_size[0] // patch_size) * (img_size[0] //
                                                         patch_size)

        if cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.cls_token = None
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  drop_path=dpr[i],
                  norm_layer=norm_layer) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        if cls_token:
            trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

        # if pixel_unshuffle:
        #     self.decoder_pred = nn.Linear(embed_dim,
        #                                   patch_size**2 * out_chans,
        #                                   bias=True)  # decoder to patch

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        # NB: the slicing below assumes pos_embed reserves index 0 for [cls]
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)),
                                    dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(
            h0) == patch_pos_embed.shape[-1]
        # flatten back to 1 x L x C (the batch dim of pos_embed is 1)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        if self.cls_token is not None:
            class_pos_embed = self.pos_embed[:, 0]
            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed),
                             dim=1)
        return patch_pos_embed

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embedded patch tokens (if one is used)
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 1:]  # return spatial feature maps, not the [CLS] token

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output
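

def _demo_pos_embed_interpolation():
    # Hedged sketch, not part of the original module: adapting a ViT trained
    # at 224px to larger inputs by bicubically resizing its positional grid,
    # as interpolate_pos_encoding does above. 14x14 positions become 28x28.
    grid = torch.randn(1, 14, 14, 768).permute(0, 3, 1, 2)  # 1 C 14 14
    resized = nn.functional.interpolate(grid, scale_factor=2.0, mode='bicubic')
    resized = resized.permute(0, 2, 3, 1).view(1, -1, 768)
    assert resized.shape == (1, 28 * 28, 768)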


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(patch_size=patch_size,
                              embed_dim=192,
                              depth=12,
                              num_heads=3,
                              mlp_ratio=4,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),  # type: ignore
        **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(patch_size=patch_size,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              mlp_ratio=4,
                              qkv_bias=True,
                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                              **kwargs)
    return model


# short aliases
vits = vit_small
vitb = vit_base
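

def _demo_vit_factory():
    # Hedged sketch, not part of the original module: building a ViT-S/16 and
    # checking that a 224px image yields 14*14 = 196 spatial tokens of width
    # 384 (forward() drops the [cls] token). Guarded because the blocks use
    # MemEffAttention, which needs xformers and may require a CUDA device.
    if not XFORMERS_AVAILABLE:
        return
    model = vit_small(patch_size=16)
    feats = model(torch.randn(1, 3, 224, 224))
    assert feats.shape == (1, 196, 384)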