import torch
import torchaudio
import typing as T


class MelspecDiscriminator(torch.nn.Module):
    """mel spectrogram (frequency domain) discriminator"""

    def __init__(self) -> None:
        super().__init__()
        self.SAMPLE_RATE = 48000

        # mel filterbank transform
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.SAMPLE_RATE,
            n_fft=2048,
            win_length=int(0.025 * self.SAMPLE_RATE),
            hop_length=int(0.010 * self.SAMPLE_RATE),
            n_mels=128,
            power=1,
        )

        # time-frequency 2D convolutions
        kernel_sizes = [(7, 7), (4, 4), (4, 4), (4, 4)]
        strides = [(1, 2), (1, 2), (1, 2), (1, 2)]
        self._convs = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_channels=1 if i == 0 else 32,
                        out_channels=64,
                        kernel_size=k,
                        stride=s,
                        padding=(1, 2),
                        bias=False,
                    ),
                    torch.nn.BatchNorm2d(num_features=64),
                    torch.nn.GLU(dim=1),  # halves the channel count from 64 to 32
                )
                for i, (k, s) in enumerate(zip(kernel_sizes, strides))
            ]
        )

        # output adversarial projection
        self._postnet = torch.nn.Conv2d(
            in_channels=32,
            out_channels=1,
            kernel_size=(15, 3),
            stride=(1, 2),
        )
    def forward(self, x: torch.Tensor) -> T.List[T.Tuple[T.List[torch.Tensor], torch.Tensor]]:
        # x: waveform of shape (batch, 1, samples), so the mel spectrogram
        # keeps a singleton channel dimension for the 2D convolutions

        # apply the log-scale mel spectrogram transform
        x = torch.log(self._melspec(x) + 1e-5)

        # compute hidden layers and feature maps
        f = []
        for c in self._convs:
            x = c(x)
            f.append(x)

        # apply the output projection and global average pooling
        x = self._postnet(x)
        x = x.mean(dim=[-2, -1])

        return [(f, x)]
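
A minimal usage sketch, assuming a one-second mono batch at the model's 48 kHz sample rate; the batch size, clip length, and variable names below are illustrative and not part of the original code.

if __name__ == "__main__":
    disc = MelspecDiscriminator()

    # hypothetical input: 4 random waveforms of shape (batch, channels, samples)
    waveform = torch.randn(4, 1, 48000)

    # the discriminator returns a list of (feature_maps, score) pairs
    for feature_maps, score in disc(waveform):
        # per-layer feature maps, usable for a feature-matching loss
        print([fm.shape for fm in feature_maps])
        # pooled adversarial score, shape (batch, 1)
        print(score.shape)

The list-of-tuples return value mirrors the convention of multi-discriminator GAN setups, where each discriminator contributes both its intermediate activations and a final realness score.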