|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
from common.cache import Cache |
|
|
|
|
|
from .attention.mmattn import NaSwinAttention |
|
|
from ..mm import MMArg |
|
|
from ..modulation import ada_layer_type |
|
|
from ..normalization import norm_layer_type |
|
|
from ..mm import MMArg, MMModule |
|
|
from ..mlp import get_mlp |
|
|
|
|
|
|
|
|
class NaMMSRTransformerBlock(nn.Module):
    """Two-stream (video / text) transformer block: attention, then MLP.

    Both sub-layers follow the same recipe in ``forward``: normalise the
    inputs, apply the ``ada`` module with ``mode="in"``, run the sub-layer,
    apply ``ada`` with ``mode="out"``, and add the sub-layer's input back as
    a residual.
    """

    def __init__(
        self,
        *,
        vid_dim: int,
        txt_dim: int,
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: norm_layer_type,
        norm_eps: float,
        ada: ada_layer_type,
        qk_bias: bool,
        qk_norm: norm_layer_type,
        mlp_type: str,
        shared_weights: bool,
        rope_type: str,
        rope_dim: int,
        is_last_layer: bool,
        **kwargs,
    ):
        """Create the norm, attention, MLP and ada sub-modules.

        All parameters are keyword-only.

        Args:
            vid_dim: Paired with ``txt_dim`` into one ``MMArg`` that is passed
                as ``dim`` to the norm, MLP and ada wrappers; also handed
                directly to the attention module.
            txt_dim: See ``vid_dim``.
            emb_dim: Forwarded to the ada module as ``emb_dim``.
            heads: Forwarded to the attention module.
            head_dim: Forwarded to the attention module.
            expand_ratio: Forwarded to the MLP.
            norm: Layer type wrapped by ``MMModule`` for both norms.
            norm_eps: Used as ``eps`` for both norms and as ``qk_norm_eps``
                for the attention module.
            ada: Layer type wrapped by ``MMModule`` for the ada module.
            qk_bias: Forwarded to the attention module.
            qk_norm: Forwarded to the attention module.
            mlp_type: Resolved through ``get_mlp`` to the MLP layer type.
            shared_weights: Forwarded to every sub-module.
            rope_type: Forwarded to the attention module.
            rope_dim: Forwarded to the attention module.
            is_last_layer: Stored on the instance and forwarded as
                ``vid_only`` to the MLP norm, the MLP and the ada module
                (not to the attention norm or the attention module).
                NOTE(review): presumably ``vid_only`` makes ``MMModule``
                skip the text branch -- confirm in ``MMModule``.
            **kwargs: Only ``window`` and ``window_method`` are read (each
                defaults to ``None``) and forwarded to the attention module;
                any other extra keyword is ignored.
        """
        super().__init__()

        mm_dim = MMArg(vid_dim, txt_dim)

        # Settings shared by the pre-attention and pre-MLP norms.
        norm_cfg = dict(
            dim=mm_dim,
            eps=norm_eps,
            elementwise_affine=False,
            shared_weights=shared_weights,
        )

        # Sub-modules are assigned in a fixed order (attn_norm, attn,
        # mlp_norm, mlp, ada); nn.Module registers children in assignment
        # order, so do not reorder these statements.
        self.attn_norm = MMModule(norm, **norm_cfg)

        window = kwargs.pop("window", None)
        window_method = kwargs.pop("window_method", None)
        self.attn = NaSwinAttention(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            heads=heads,
            head_dim=head_dim,
            qk_bias=qk_bias,
            qk_norm=qk_norm,
            qk_norm_eps=norm_eps,
            rope_type=rope_type,
            rope_dim=rope_dim,
            shared_weights=shared_weights,
            window=window,
            window_method=window_method,
        )

        self.mlp_norm = MMModule(norm, **norm_cfg, vid_only=is_last_layer)

        mlp_layer = get_mlp(mlp_type)
        self.mlp = MMModule(
            mlp_layer,
            dim=mm_dim,
            expand_ratio=expand_ratio,
            shared_weights=shared_weights,
            vid_only=is_last_layer,
        )

        # A single ada module serves both sub-layers; the sub-layer is
        # selected per call through the ``layer`` keyword in ``forward``.
        self.ada = MMModule(
            ada,
            dim=mm_dim,
            emb_dim=emb_dim,
            layers=["attn", "mlp"],
            shared_weights=shared_weights,
            vid_only=is_last_layer,
        )

        self.is_last_layer = is_last_layer

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        vid_shape: torch.LongTensor,
        txt_shape: torch.LongTensor,
        emb: torch.FloatTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.LongTensor,
    ]:
        """Run the attention sub-layer followed by the MLP sub-layer.

        Args:
            vid: Video-branch hidden states.
            txt: Text-branch hidden states.
            vid_shape: Shape tensor for the video branch; its product over
                the last axis is requested from ``cache`` under "vid_len".
            txt_shape: Shape tensor for the text branch; its product over
                the last axis is requested from ``cache`` under "txt_len".
            emb: Passed to every ada call as ``emb``.
            cache: Called as ``cache(key, factory)`` here and also handed to
                the attention and ada modules.
                NOTE(review): presumably memoises the factory result per
                key -- confirm in ``common.cache.Cache``.

        Returns:
            ``(vid, txt, vid_shape, txt_shape)`` where ``vid`` / ``txt`` are
            the block outputs and the two shape tensors are returned
            unchanged.
        """
        seq_len = MMArg(
            cache("vid_len", lambda: vid_shape.prod(-1)),
            cache("txt_len", lambda: txt_shape.prod(-1)),
        )

        # Keyword arguments common to all four ada calls below.
        modulation_args = dict(
            emb=emb,
            hid_len=seq_len,
            cache=cache,
            branch_tag=MMArg("vid", "txt"),
        )

        # --- attention sub-layer ---
        v, t = self.attn_norm(vid, txt)
        v, t = self.ada(v, t, layer="attn", mode="in", **modulation_args)
        v, t = self.attn(v, t, vid_shape, txt_shape, cache)
        v, t = self.ada(v, t, layer="attn", mode="out", **modulation_args)
        vid_after_attn = v + vid
        txt_after_attn = t + txt

        # --- MLP sub-layer ---
        v, t = self.mlp_norm(vid_after_attn, txt_after_attn)
        v, t = self.ada(v, t, layer="mlp", mode="in", **modulation_args)
        v, t = self.mlp(v, t)
        v, t = self.ada(v, t, layer="mlp", mode="out", **modulation_args)
        vid_out = v + vid_after_attn
        txt_out = t + txt_after_attn

        return vid_out, txt_out, vid_shape, txt_shape
|
|
|