import copy
from functools import partial

import torch
import torch.nn as nn

from .dasheng import LayerScale, Attention, Mlp


class Decoder_Block(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        drop=0.,
        attn_drop=0.,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attention_type='Attention',
        fusion='adaln',
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              attn_drop=attn_drop,
                              proj_drop=drop)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        self.fusion = fusion
        if fusion == 'adaln':
            # Conditioning projection: one scale/gate/shift triple for the
            # attention branch and one for the MLP branch (6 * dim outputs).
            self.adaln = nn.Linear(dim, 6 * dim, bias=True)

    def forward(self, x, c=None):
        B, T, C = x.shape

        if self.fusion == 'adaln':
            ada = self.adaln(c)
            (scale_msa, gate_msa, shift_msa,
             scale_mlp, gate_mlp, shift_mlp) = ada.reshape(B, 6, -1).chunk(6, dim=1)

            x_norm = self.norm1(x) * (1 + scale_msa) + shift_msa
            tanh_gate_msa = torch.tanh(1 - gate_msa)
            x = x + tanh_gate_msa * self.ls1(self.attn(x_norm))

            x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
            tanh_gate_mlp = torch.tanh(1 - gate_mlp)
            x = x + tanh_gate_mlp * self.ls2(self.mlp(x_norm))
        else:
            x = x + self.ls1(self.attn(self.norm1(x)))
            x = x + self.ls2(self.mlp(self.norm2(x)))
        return x
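
# Note on the 'adaln' branch above (explanatory comment, not original code):
# each residual branch is modulated by the conditioning vector c roughly as
#
#     h = norm(x) * (1 + scale) + shift
#     x = x + tanh(1 - gate) * branch(h)
#
# with (scale, gate, shift) produced by self.adaln(c). Because Decoder
# zero-initialises self.adaln, every gate starts at tanh(1), about 0.76,
# rather than at 0 as in the standard adaLN-Zero formulation.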


class Decoder(nn.Module):

    def __init__(
        self,
        embed_dim: int = 768,
        depth: int = 2,
        num_heads=8,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_rate=0.,
        attn_drop_rate=0.,
        cls_dim: int = 512,
        fusion: str = 'adaln',
        **kwargs
    ):
        super().__init__()

        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU
        init_values = None

        block_function = Decoder_Block
        self.blocks = nn.ModuleList([
            block_function(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                act_layer=act_layer,
                attention_type="Attention",
                fusion=fusion,
            ) for _ in range(depth)
        ])

        self.fusion = fusion
        cls_out = embed_dim

        # Project the conditioning (class) embedding to the model width.
        self.cls_embed = nn.Sequential(
            nn.Linear(cls_dim, embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(embed_dim, cls_out, bias=True),
        )

        self.sed_head = nn.Linear(embed_dim, 1, bias=True)
        self.norm = norm_layer(embed_dim)
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

        if self.fusion == 'adaln':
            # Zero-initialise the conditioning projection of every block
            # (adaLN-Zero-style initialisation).
            for block in self.blocks:
                nn.init.constant_(block.adaln.weight, 0)
                nn.init.constant_(block.adaln.bias, 0)

    def forward(self, x, cls):
        B, L, C = x.shape
        _, N, D = cls.shape

        # Replicate the encoded sequence for each of the N conditioning
        # embeddings and fold the query axis into the batch axis.
        x = x.unsqueeze(1).expand(-1, N, -1, -1)
        x = x.reshape(B * N, L, C)
        cls = cls.reshape(B * N, D)

        cls = self.cls_embed(cls)

        shift = 0
        if self.fusion == 'adaln':
            pass
        elif self.fusion == 'token':
            # Prepend the conditioning embedding as an extra token.
            cls = cls.unsqueeze(1)
            x = torch.cat([cls, x], dim=1)
            shift = 1
        else:
            raise NotImplementedError(f"unknown fusion: {self.fusion}")

        for block in self.blocks:
            x = block(x, cls)

        # Drop the prepended conditioning token (if any) before the head.
        x = x[:, shift:]
        x = self.norm(x)

        strong = self.sed_head(x)
        return strong.transpose(1, 2)
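
# Shape walk-through for Decoder.forward (explanatory sketch inferred from the
# code above, not part of the original comments):
#   x:   (B, L, C)  encoder features, C == embed_dim
#   cls: (B, N, D)  N conditioning embeddings per clip, D == cls_dim
# The features are replicated per conditioning embedding and the query axis is
# folded into the batch, so every block processes (B * N, L, C); the returned
# frame-level ("strong") predictions have shape (B * N, 1, L).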


class TSED_Wrapper(nn.Module):

    def __init__(
        self,
        encoder,
        decoder,
        ft_blocks=[11, 12],
        frozen_encoder=True,
    ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        # Initialise the decoder blocks from the selected encoder (Dasheng) blocks.
        print("Loading Dasheng weights for decoder blocks...")
        for i, blk_idx in enumerate(ft_blocks):
            decoder_block = self.decoder.blocks[i]
            encoder_block = self.encoder.blocks[blk_idx]
            state_dict = copy.deepcopy(encoder_block.state_dict())
            missing, unexpected = decoder_block.load_state_dict(state_dict, strict=False)
            if missing or unexpected:
                print(f"Block {blk_idx}:")
                if missing:
                    # The adaln projection has no counterpart in the encoder,
                    # so its keys are expected to be missing here.
                    print(f"  ✅ Expected missing keys: {missing}")
                if unexpected:
                    print(f"  Unexpected keys: {unexpected}")

        # The decoder also takes over the encoder's final LayerNorm.
        self.decoder.norm.load_state_dict(copy.deepcopy(self.encoder.norm.state_dict()))

        # The transferred blocks and the final norm now live in the decoder, so
        # drop them from the encoder (delete in reverse order to keep indices valid).
        for blk_idx in sorted(ft_blocks, reverse=True):
            del self.encoder.blocks[blk_idx]
        del self.encoder.norm

        self.frozen_encoder = frozen_encoder
        if frozen_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

    def forward_to_spec(self, x):
        return self.encoder.forward_to_spec(x)

    def forward_encoder(self, x):
        if self.frozen_encoder:
            with torch.no_grad():
                x = self.encoder(x)
        else:
            x = self.encoder(x)
        return x

    def forward(self, x, cls):
        x = self.forward_encoder(x)
        pred = self.decoder(x, cls)
        return pred
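

if __name__ == "__main__":
    # Minimal shape check for the Decoder alone (an illustrative sketch, not
    # part of the original training code). It assumes the module is executed
    # with `python -m <package>.<module>` so that the relative import of
    # .dasheng resolves; the batch/sequence sizes below are arbitrary.
    decoder = Decoder(embed_dim=768, depth=2, num_heads=8, cls_dim=512, fusion='adaln')
    dummy_features = torch.randn(2, 62, 768)   # (B, L, C) encoder-like features
    dummy_cls = torch.randn(2, 3, 512)         # (B, N, D) conditioning embeddings
    strong = decoder(dummy_features, dummy_cls)
    print(strong.shape)                        # expected: torch.Size([6, 1, 62])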