import torch
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram

from models.atstframe.audio_transformer import FrameASTModel
from models.transformer_wrapper import BaseModelWrapper

class ATSTWrapper(BaseModelWrapper):
    def __init__(self, atst_dropout=0.0) -> None:
        super().__init__()
        self.atst_mel = ATSTMel()
        self.atst = FrameASTModel(atst_dropout=atst_dropout)
        # Fixed frame count handed to ATST; it matches a 10 s, 16 kHz clip
        # run through the mel front-end below (160000 / 160 + 1 = 1001 frames).
        self.fake_length = torch.tensor([1001])
        self.cls_embed = None

    def mel_forward(self, x):
        return self.atst_mel(x)

    def forward(self, spec):
        # Return frame-level embeddings from the last transformer layer.
        atst_x = self.atst.get_intermediate_layers(
            spec,
            self.fake_length.to(spec).repeat(len(spec)),
            1,
            scene=False,
        )
        return atst_x
    def separate_params(self):
        # Bucket parameters by transformer depth so callers can apply
        # layer-wise learning-rate decay. Patch/positional embeddings join
        # the first block; the final frame norm joins the last.
        pt_params = [[] for _ in range(12)]
        for k, p in self.named_parameters():
            if k in ['atst.mask_embed', 'atst.pos_embed',
                     'atst.patch_embed.patch_embed.weight',
                     'atst.patch_embed.patch_embed.bias'] or "blocks.0." in k:
                pt_params[0].append(p)
            elif "blocks.1." in k:
                pt_params[1].append(p)
            elif "blocks.2." in k:
                pt_params[2].append(p)
            elif "blocks.3." in k:
                pt_params[3].append(p)
            elif "blocks.4." in k:
                pt_params[4].append(p)
            elif "blocks.5." in k:
                pt_params[5].append(p)
            elif "blocks.6." in k:
                pt_params[6].append(p)
            elif "blocks.7." in k:
                pt_params[7].append(p)
            elif "blocks.8." in k:  # trailing dot added for consistency with the other block checks
                pt_params[8].append(p)
            elif "blocks.9." in k:
                pt_params[9].append(p)
            elif "blocks.10." in k:
                pt_params[10].append(p)
            elif "blocks.11." in k or ".norm_frame." in k:
                pt_params[11].append(p)
            else:
                raise ValueError(f"Check separate params for ATST! Unknown key: {k}")
        # Reversed so that index 0 holds the deepest (last) block.
        return list(reversed(pt_params))

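
# Illustrative sketch, not part of the original file: one way the reversed
# group list from separate_params() could be turned into optimizer parameter
# groups with layer-wise learning-rate decay. The helper name, base_lr, and
# decay factor are assumptions, not values taken from this repo.
def build_lr_decayed_param_groups(wrapper, base_lr=1e-4, decay=0.75):
    groups = []
    for depth, params in enumerate(wrapper.separate_params()):
        # depth 0 is the deepest block (the list is reversed), so it keeps
        # the full base_lr; shallower layers get geometrically smaller rates.
        if params:
            groups.append({"params": params, "lr": base_lr * decay ** depth})
    return groups
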
class ATSTMel(torch.nn.Module):
    # Mel spectrogram front-end for ATST: 16 kHz input, 64 mel bins,
    # 1024-sample window, 160-sample (10 ms) hop.
    def __init__(self) -> None:
        super().__init__()
        self.mel_transform = MelSpectrogram(
            16000,
            f_min=60,
            f_max=7800,
            hop_length=160,
            win_length=1024,
            n_fft=1024,
            n_mels=64,
        )
        self.amp_to_db = AmplitudeToDB(stype="power", top_db=80)
        self.scaler = MinMax(min=-79.6482, max=50.6842)

    def amp2db(self, spec):
        return self.amp_to_db(spec).clamp(min=-50, max=80)

    def forward(self, audio):
        # The mel + dB conversion is numerically sensitive, so it is kept in
        # float32 even when the surrounding model runs under CUDA autocast.
        with torch.autocast(device_type="cuda", enabled=False):
            spec = self.mel_transform(audio)
            spec = self.scaler(self.amp2db(spec))
            spec = spec.unsqueeze(1)  # add channel dim: (batch, 1, mels, frames)
        return spec
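
# Shape note (derived from the settings above, not stated in the source):
# a (batch, 160000) waveform, i.e. 10 s at 16 kHz, comes out of ATSTMel as
# (batch, 1, 64, 1001): 64 mel bins and 1001 frames at a 160-sample hop,
# matching ATSTWrapper.fake_length.
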
class CustomAudioTransform:
    def __repr__(self):
        return self.__class__.__name__ + '()'


class MinMax(CustomAudioTransform):
    # Min-max normalization to [-1, 1]; with min=None the statistics are
    # computed per input tensor instead of using fixed dataset values.
    def __init__(self, min, max):
        self.min = min
        self.max = max

    def __call__(self, input):
        if self.min is None:
            min_ = torch.min(input)
            max_ = torch.max(input)
        else:
            min_ = self.min
            max_ = self.max
        input = (input - min_) / (max_ - min_) * 2. - 1.
        return input
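
# Minimal smoke test, assuming the repo's `models` package and its ATST
# dependencies resolve; the input sizes are illustrative, not prescribed here.
if __name__ == "__main__":
    mel = ATSTMel()
    audio = torch.randn(2, 160000)  # two random 10 s clips at 16 kHz
    spec = mel(audio)
    print(spec.shape)  # expected: torch.Size([2, 1, 64, 1001])
    # Running the full wrapper additionally needs the pretrained FrameASTModel:
    # wrapper = ATSTWrapper()
    # frame_embeddings = wrapper(spec)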