| import torch | |
| import speechbrain as sb | |
class FeatureScaler(torch.nn.Module):
    """Elementwise feature scaling by a fixed per-feature factor.

    Multiplies the last dimension of the input by a constant vector of
    length ``num_in`` whose entries are all ``scale``.

    Arguments
    ---------
    num_in : int
        Number of input features (size of the scaling vector).
    scale : float
        Constant multiplier applied to every feature.
    """

    def __init__(self, num_in, scale):
        super().__init__()
        # Register as a buffer (not a plain attribute) so the tensor
        # follows the module across .to(device)/.cuda() calls and is
        # included in state_dict() for saving/loading.
        self.register_buffer("scaler", torch.ones((num_in,)) * scale)

    def forward(self, x):
        """Return ``x`` scaled elementwise; broadcasts over leading dims."""
        return x * self.scaler
class CustomInterface(sb.pretrained.interfaces.Pretrained):
    """Pretrained inference interface that extracts normalized, scaled features.

    Expects the loaded hparams to provide ``feature_extractor`` and the
    loaded modules to provide ``normalizer`` and ``feature_scaler``.
    """

    # "feature_scaler" is used via self.mods in feats_from_audio, so it
    # must be declared as a needed module alongside "normalizer".
    MODULES_NEEDED = ["normalizer", "feature_scaler"]
    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Extract normalized, scaled features from a batch of audio.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of audio signals, shape ``(batch, time)``.
        lengths : torch.Tensor, optional
            Relative lengths of each signal in the batch. Defaults to
            all-ones (full-length signals), created per call on the
            input's device — a module-level tensor default would be
            shared across calls, pinned to CPU, and fixed to batch
            size 1.

        Returns
        -------
        torch.Tensor
            Normalized and scaled features.
        """
        if lengths is None:
            lengths = torch.ones(audio.shape[0], device=audio.device)
        feats = self.hparams.feature_extractor(audio)
        normalized = self.mods.normalizer(feats, lengths)
        scaled = self.mods.feature_scaler(normalized)
        return scaled

    def feats_from_file(self, path):
        """Extract features for a single audio file at ``path``.

        Loads the audio, adds a batch dimension for processing, and
        strips it again so the result is a single (unbatched) feature
        tensor.
        """
        audio = self.load_audio(path)
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)