# FlexSED/src/models/sed_decoder.py
import copy
from functools import partial

import torch
import torch.nn as nn

from .dasheng import LayerScale, Attention, Mlp


class Decoder_Block(nn.Module):
    """Pre-norm transformer block with optional AdaLN conditioning.

    With ``fusion='adaln'`` a class embedding produces per-branch scale/shift/gate
    terms; otherwise the block is a plain attention/MLP block and any conditioning
    token has to be prepended to the input sequence by the caller.
    """

def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
attention_type='Attention',
fusion='adaln',
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.ls2 = LayerScale(
dim, init_values=init_values) if init_values else nn.Identity()
        self.fusion = fusion
        if fusion == 'adaln':
            # Projects the class embedding to shift/scale/gate terms for the
            # attention and MLP residual branches (3 terms x 2 branches = 6 * dim).
            self.adaln = nn.Linear(dim, 6 * dim, bias=True)

    def forward(self, x, c=None):
        # x: (B, T, C) frame embeddings; c: (B, C) class embedding (AdaLN fusion only).
        B, T, C = x.shape
if self.fusion == 'adaln':
ada = self.adaln(c)
(scale_msa, gate_msa, shift_msa,
scale_mlp, gate_mlp, shift_mlp) = ada.reshape(B, 6, -1).chunk(6, dim=1)
            # self attention: AdaLN scale/shift on the normed input, residual scaled by a
            # tanh gate (with the zero-initialised adaln projection it starts at tanh(1) ~ 0.76)
            x_norm = self.norm1(x) * (1 + scale_msa) + shift_msa
            tanh_gate_msa = torch.tanh(1 - gate_msa)
            x = x + tanh_gate_msa * self.ls1(self.attn(x_norm))
# mlp
x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
tanh_gate_mlp = torch.tanh(1 - gate_mlp)
x = x + tanh_gate_mlp * self.ls2(self.mlp(x_norm))
else:
x = x + self.ls1(self.attn(self.norm1(x)))
x = x + self.ls2(self.mlp(self.norm2(x)))
return x
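

# Illustrative usage sketch: shows the tensor shapes Decoder_Block expects with
# AdaLN fusion. dim/num_heads match the Decoder defaults below; the batch size
# and sequence length are arbitrary example values.
def _demo_decoder_block():
    dim, num_heads, B, T = 768, 8, 2, 250
    block = Decoder_Block(dim, num_heads, qkv_bias=True, fusion='adaln')
    x = torch.randn(B, T, dim)  # frame embeddings
    c = torch.randn(B, dim)     # one conditioning vector per sequence
    y = block(x, c)
    assert y.shape == (B, T, dim)
    return y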


class Decoder(nn.Module):
    """Class-conditional SED decoder.

    Runs a small stack of Decoder_Blocks over the encoder's frame embeddings,
    once per class query, and predicts a frame-level event-presence logit for
    each query.
    """

def __init__(
self,
embed_dim: int = 768,
depth: int = 2,
num_heads=8,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
cls_dim: int = 512,
fusion: str = 'adaln',
**kwargs
):
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
init_values = None
block_function = Decoder_Block
self.blocks = nn.ModuleList([
block_function(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
act_layer=act_layer,
attention_type="Attention",
fusion=fusion,
) for _ in range(depth)
])
self.fusion = fusion
cls_out = embed_dim
        # Maps class-query embeddings of size cls_dim into the decoder width.
        self.cls_embed = nn.Sequential(
            nn.Linear(cls_dim, embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(embed_dim, cls_out, bias=True),
        )
        # Per-frame, per-query event-presence logit.
        self.sed_head = nn.Linear(embed_dim, 1, bias=True)
self.norm = norm_layer(embed_dim)
        self.apply(self.init_weights)
        if self.fusion == 'adaln':
            # Zero-init the AdaLN projections so every block initially behaves like its
            # unconditioned counterpart.
            for block in self.blocks:
                nn.init.constant_(block.adaln.weight, 0)
                nn.init.constant_(block.adaln.bias, 0)
        # self.energy_head = nn.Linear(embed_dim, 1, bias=True)

def init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)

    def forward(self, x, cls):
        # x: (B, L, C) encoder frame embeddings; cls: (B, N, D) class-query embeddings.
        B, L, C = x.shape
        _, N, D = cls.shape
        # Expand x to shape (B, N, L, C) so every class query sees the full sequence
        x = x.unsqueeze(1).expand(-1, N, -1, -1)
        # Reshape to (B*N, L, C) and (B*N, D) for processing
        x = x.reshape(B * N, L, C)
        cls = cls.reshape(B * N, D)
        cls = self.cls_embed(cls)
        shift = 0
        if self.fusion == 'adaln':
            # The class embedding is injected inside each block via AdaLN; nothing to do here.
            pass
        elif self.fusion == 'token':
            # Prepend the class embedding as an extra token and drop it again after the blocks.
            cls = cls.unsqueeze(1)
            x = torch.cat([cls, x], dim=1)
            shift = 1
        else:
            raise NotImplementedError(f"unknown fusion: {self.fusion}")
        for block in self.blocks:
            x = block(x, cls)
        x = x[:, shift:]  # drop the prepended class token, if any
        x = self.norm(x)
        strong = self.sed_head(x)  # (B*N, L, 1)
        return strong.transpose(1, 2)  # (B*N, 1, L)
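

# Illustrative usage sketch for the Decoder: B clips, L frames, N class queries.
# embed_dim/cls_dim follow the constructor defaults; the batch/sequence/query
# sizes are arbitrary example values.
def _demo_decoder():
    B, L, N = 2, 250, 3
    decoder = Decoder(embed_dim=768, depth=2, num_heads=8, cls_dim=512, fusion='adaln')
    x = torch.randn(B, L, 768)    # encoder frame embeddings
    cls = torch.randn(B, N, 512)  # one query embedding per target sound class
    strong = decoder(x, cls)      # frame-level logits, one channel per (clip, query) pair
    assert strong.shape == (B * N, 1, L)
    return strong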


class TSED_Wrapper(nn.Module):
    """Couples an (optionally frozen) Dasheng encoder with the SED decoder.

    The encoder blocks listed in ``ft_blocks`` initialise the decoder blocks
    (together with the encoder's final norm) and are then removed from the
    encoder, so their weights continue to be trained as part of the decoder.
    """

    def __init__(
        self,
        encoder,
        decoder,
        ft_blocks=(11, 12),
        frozen_encoder=True,
    ):
super().__init__()
self.encoder = encoder
self.decoder = decoder
        print("Loading Dasheng weights into the decoder blocks...")
for i, blk_idx in enumerate(ft_blocks):
decoder_block = self.decoder.blocks[i]
encoder_block = self.encoder.blocks[blk_idx]
state_dict = copy.deepcopy(encoder_block.state_dict())
missing, unexpected = decoder_block.load_state_dict(state_dict, strict=False)
if missing or unexpected:
print(f"Block {blk_idx}:")
if missing:
print(f"✅ Expected missing keys: {missing}")
if unexpected:
print(f" Unexpected keys: {unexpected}")
# Copy norm_layer weights
self.decoder.norm.load_state_dict(copy.deepcopy(self.encoder.norm.state_dict()))
        # Remove the transplanted blocks from the encoder
        # (delete in reverse order to avoid index shifting)
        for blk_idx in sorted(ft_blocks, reverse=True):
            del self.encoder.blocks[blk_idx]
# Remove encoder norm layer
del self.encoder.norm
self.frozen_encoder = frozen_encoder
if frozen_encoder:
for param in self.encoder.parameters():
param.requires_grad = False

    def forward_to_spec(self, x):
        return self.encoder.forward_to_spec(x)

    def forward_encoder(self, x):
        if self.frozen_encoder:
            with torch.no_grad():
                x = self.encoder(x)
        else:
            x = self.encoder(x)
        return x

    def forward(self, x, cls):
        x = self.forward_encoder(x)
        pred = self.decoder(x, cls)
        return pred
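

# Illustrative usage sketch for TSED_Wrapper. _ToyEncoder is a hypothetical
# stand-in exposing only the interface the wrapper relies on (`blocks`, `norm`,
# and a forward that returns frame embeddings); it is NOT the real Dasheng encoder.
class _ToyEncoder(nn.Module):
    def __init__(self, dim=768, depth=13, num_heads=8):
        super().__init__()
        self.blocks = nn.ModuleList([
            Decoder_Block(dim, num_heads, qkv_bias=True, fusion='token')
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        # TSED_Wrapper deletes `norm` after copying it into the decoder.
        return self.norm(x) if hasattr(self, 'norm') else x


def _demo_tsed_wrapper():
    B, L, N = 2, 250, 3
    encoder = _ToyEncoder()
    decoder = Decoder(embed_dim=768, depth=2, num_heads=8, cls_dim=512)
    model = TSED_Wrapper(encoder, decoder, ft_blocks=(11, 12), frozen_encoder=True)
    x = torch.randn(B, L, 768)    # stand-in for encoder-ready features
    cls = torch.randn(B, N, 512)  # class-query embeddings
    return model(x, cls)          # -> (B * N, 1, L)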