|
|
from functools import partial
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torchaudio.transforms as audio_transforms
from einops import rearrange
from einops.layers.torch import Rearrange
from torch.cuda.amp import autocast

from .dasheng import AudioPatchEmbed, Block


class Dasheng_Encoder(nn.Module):
    """Dasheng audio transformer encoder.

    Turns a raw 16 kHz waveform into a log-mel spectrogram, embeds it into
    (frequency, time) patches, and runs the patch sequence through a stack
    of transformer blocks.
    """

    def __init__(self,
                 patch_size: Tuple[int, int] = (64, 4),
                 patch_stride: Tuple[int, int] = (64, 4),
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 8,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 norm_layer=None,
                 act_layer=None,
                 init_values=None,
                 target_length: int = 1008,
                 pooling: str = 'mean',
                 time_patch_out: Optional[float] = None,
                 freq_patch_out: Optional[float] = None,
                 block_type: str = 'Block',
                 attention_type: str = 'Attention',
                 eval_avg: str = 'cat',
                 n_fft: int = 512,
                 n_mels: int = 64,
                 hop_size: int = 160,
                 win_size: int = 512,
                 f_min: int = 0,
                 f_max: int = 8000,
                 center: bool = True,
                 **kwargs):
        super().__init__()
        self.pooling = pooling
        self.embed_dim = embed_dim
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.n_mels = n_mels
        self.eval_avg = eval_avg
        self.time_patch_out = time_patch_out
        self.freq_patch_out = freq_patch_out

        # Waveform -> mel magnitude spectrogram (power=1), fixed 16 kHz input.
        self.front_end = nn.Sequential(
            audio_transforms.MelSpectrogram(sample_rate=16000,
                                            n_fft=n_fft,
                                            win_length=win_size,
                                            hop_length=hop_size,
                                            n_mels=self.n_mels,
                                            f_min=f_min,
                                            f_max=f_max,
                                            center=center,
                                            power=1))

        self.to_db = audio_transforms.AmplitudeToDB(
            stype='magnitude', top_db=kwargs.get('top_db', 120))

        # BatchNorm over the mel bins: swap the channel and frequency axes so
        # BatchNorm2d normalizes each mel bin, then swap back.
        self.init_bn = nn.Sequential(
            Rearrange('b c f t -> b f c t'),
            nn.BatchNorm2d(self.n_mels, momentum=0.01),
            Rearrange('b f c t -> b c f t'))

        self.target_length = target_length
        self.patch_embed = AudioPatchEmbed(input_size=(self.n_mels, target_length),
                                           embed_dim=self.embed_dim,
                                           patch_size=self.patch_size,
                                           flatten=False,
                                           patch_stride=self.patch_stride)
        self.num_patches = self.patch_embed.num_patches

        if pooling == 'token':
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.token_pos_embed = nn.Parameter(torch.randn(1, embed_dim) * .02)

        # Separable positional embeddings: one along time, one along frequency.
        self.time_pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, 1, self.patch_embed.grid_size[1]) * .02)
        self.freq_pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, self.patch_embed.grid_size[0], 1) * .02)

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Note: `block_type` is accepted for config compatibility, but the
        # encoder always instantiates `Block`.
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim,
                  num_heads=num_heads,
                  mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias,
                  init_values=init_values,
                  drop=drop_rate,
                  attn_drop=attn_drop_rate,
                  norm_layer=norm_layer,
                  act_layer=act_layer,
                  attention_type=attention_type) for _ in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.apply(self.init_weights)
        if hasattr(self, 'cls_token') and self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward_features(self, x):
        x = self.patch_embed(x)  # (b, embed_dim, f, t) patch grid
        b, c, f, t = x.shape
        # Add separable positional embeddings, truncating the time embedding
        # to the actual number of time patches.
        x = x + self.time_pos_embed[:, :, :, :t]
        x = x + self.freq_pos_embed[:, :, :, :]
        x = rearrange(x, 'b c f t -> b (f t) c')
        if self.pooling == 'token':
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            cls_token = cls_token + self.token_pos_embed[:, :]
            x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x)
        for block in self.blocks:
            x = block(x)
        # Final LayerNorm from __init__; output is (b, f*t, embed_dim), plus
        # one leading CLS token when pooling == 'token'.
        x = self.norm(x)
        return x

    def load_state_dict(self, state_dict, **kwargs):
        if ('time_pos_embed' in state_dict
                and self.time_pos_embed.shape != state_dict['time_pos_embed'].shape):
            print("Positional embedding shape does not match the model, resizing!")
            self.change_pos_embedding(state_dict)

        missing_keys, unexpected_keys = super().load_state_dict(state_dict,
                                                                strict=False,
                                                                **kwargs)
        if missing_keys:
            print("Missing keys:", missing_keys)
        if unexpected_keys:
            print("Unexpected keys:", unexpected_keys)
        return missing_keys, unexpected_keys

    def change_pos_embedding(self, state_dict):
        target_time_pos_embed_length = self.time_pos_embed.shape[-1]
        target_freq_pos_embed_length = self.freq_pos_embed.shape[-2]

        pretrained_time_pos_embed = state_dict['time_pos_embed']
        pretrained_freq_pos_embed = state_dict['freq_pos_embed']

        # Truncate when the target is shorter than the pretrained embedding;
        # otherwise interpolate up to the target length.
        if target_time_pos_embed_length <= pretrained_time_pos_embed.shape[-1]:
            state_dict['time_pos_embed'] = pretrained_time_pos_embed[
                ..., :target_time_pos_embed_length]
        else:
            state_dict['time_pos_embed'] = torch.nn.functional.interpolate(
                pretrained_time_pos_embed,
                size=(1, target_time_pos_embed_length),
                align_corners=False,
                mode='bilinear')
        if target_freq_pos_embed_length <= pretrained_freq_pos_embed.shape[-2]:
            state_dict['freq_pos_embed'] = pretrained_freq_pos_embed[
                :, :, :target_freq_pos_embed_length, :]
        else:
            state_dict['freq_pos_embed'] = torch.nn.functional.interpolate(
                pretrained_freq_pos_embed,
                size=(target_freq_pos_embed_length, 1),
                align_corners=False,
                mode='bilinear')
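
    # Hedged loading sketch (the checkpoint filename is hypothetical): an
    # encoder built with a different `target_length` gets the pretrained
    # time/freq positional embeddings truncated or bilinearly interpolated
    # by change_pos_embedding() before the non-strict load above.
    #
    #   enc = Dasheng_Encoder(target_length=504)
    #   enc.load_state_dict(torch.load("dasheng_ckpt.pt", map_location="cpu"))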
|
|
|
|
|
    def forward_to_spec(self, x):
        # Run the mel front end in fp32 even under mixed-precision training.
        with autocast(enabled=False):
            x = self.front_end(x)
        return x
|
|
|
|
|
    def forward(self, x):
        # `x` is a raw waveform batch of shape (batch, samples) at 16 kHz;
        # convert it to a mel spectrogram in fp32 first.
        x = self.forward_to_spec(x)
        with autocast(enabled=False):
            x = self.to_db(x)
        # (b, f, t) -> (b, 1, f, t) for the 2D patch embedding.
        x = rearrange(x, 'b f t -> b 1 f t')
        x = self.init_bn(x)
        x = self.forward_features(x)
        return x
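

# Minimal usage sketch: one second of random 16 kHz audio through the
# encoder. Because of the relative import at the top, run this as a module
# (python -m <package>.<this_module>) rather than as a script. With the
# defaults (hop_size=160, center=True, patch stride (64, 4)), 16000 samples
# give 101 frames -> 25 time patches, so `emb` should come out as
# (2, 25, 768), assuming AudioPatchEmbed applies no time padding.
if __name__ == "__main__":
    encoder = Dasheng_Encoder().eval()
    wave = torch.randn(2, 16000)  # (batch, samples) at 16 kHz
    with torch.no_grad():
        emb = encoder(wave)
    print(emb.shape)  # expected: torch.Size([2, 25, 768])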