Spaces:
Build error
Build error
| import torch | |
| from deepafx_st.models.mobilenetv2 import MobileNetV2 | |
| from deepafx_st.models.efficient_net import EfficientNet | |
| class SpectralEncoder(torch.nn.Module): | |
| def __init__( | |
| self, | |
| num_params, | |
| sample_rate, | |
| encoder_model="mobilenet_v2", | |
| embed_dim=1028, | |
| width_mult=1, | |
| min_level_db=-80, | |
| ): | |
| """Encoder operating on spectrograms. | |
| Args: | |
| num_params (int): Number of processor parameters to generate. | |
| sample_rate (float): Audio sample rate for computing melspectrogram. | |
| encoder_model (str, optional): Encoder model architecture. Default: "mobilenet_v2" | |
| embed_dim (int, optional): Dimentionality of the encoder representations. | |
| width_mult (int, optional): Encoder size. Default: 1 | |
| min_level_db (float, optional): Minimal dB value for the spectrogram. Default: -80 | |
| """ | |
| super().__init__() | |
| self.num_params = num_params | |
| self.sample_rate = sample_rate | |
| self.encoder_model = encoder_model | |
| self.embed_dim = embed_dim | |
| self.width_mult = width_mult | |
| self.min_level_db = min_level_db | |
| # load model from torch.hub | |
| if encoder_model == "mobilenet_v2": | |
| self.encoder = MobileNetV2(embed_dim=embed_dim, width_mult=width_mult) | |
| elif encoder_model == "efficient_net": | |
| self.encoder = EfficientNet.from_name( | |
| "efficientnet-b2", | |
| in_channels=1, | |
| image_size=(128, 65), | |
| include_top=False, | |
| ) | |
| self.embedding_projection = torch.nn.Conv2d( | |
| in_channels=1408, | |
| out_channels=embed_dim, | |
| kernel_size=(1, 1), | |
| stride=(1, 1), | |
| padding=(0, 0), | |
| bias=True, | |
| ) | |
| else: | |
| raise ValueError(f"Invalid encoder_model: {encoder_model}.") | |
| self.window = torch.nn.Parameter(torch.hann_window(4096)) | |
| def forward(self, x): | |
| """ | |
| Args: | |
| x (Tensor): Input waveform of shape [batch x channels x samples] | |
| Returns: | |
| e (Tensor): Latent embedding produced by Encoder. [batch x embed_dim] | |
| """ | |
| bs, chs, samp = x.size() | |
| # compute spectrogram of waveform | |
| X = torch.stft( | |
| x.view(bs, -1), | |
| 4096, | |
| 2048, | |
| window=self.window, | |
| return_complex=True, | |
| ) | |
| X_db = torch.pow(X.abs() + 1e-8, 0.3) | |
| X_db_norm = X_db | |
| # standardize (0, 1) 0.322970 0.278452 | |
| X_db_norm -= 0.322970 | |
| X_db_norm /= 0.278452 | |
| X_db_norm = X_db_norm.unsqueeze(1).permute(0, 1, 3, 2) | |
| if self.encoder_model == "mobilenet_v2": | |
| # repeat channels by 3 to fit vision model | |
| X_db_norm = X_db_norm.repeat(1, 3, 1, 1) | |
| # pass melspectrogram through encoder | |
| e = self.encoder(X_db_norm) | |
| # apply avg pooling across time for encoder embeddings | |
| e = torch.nn.functional.adaptive_avg_pool2d(e, 1).reshape(e.shape[0], -1) | |
| # normalize by L2 norm | |
| norm = torch.norm(e, p=2, dim=-1, keepdim=True) | |
| e_norm = e / norm | |
| elif self.encoder_model == "efficient_net": | |
| # Efficient Net internal downsamples by 32 on time and freq axis, then average pools the rest | |
| e = self.encoder(X_db_norm) | |
| # Adding 1x1 conv to project down or up to the requested embedding size | |
| e = self.embedding_projection(e) | |
| e = torch.squeeze(e, dim=3) | |
| e = torch.squeeze(e, dim=2) | |
| # normalize by L2 norm | |
| norm = torch.norm(e, p=2, dim=-1, keepdim=True) | |
| e_norm = e / norm | |
| return e_norm | |