import functools

import torch
import torch.nn.functional as F

import crepe


###########################################################################
# Model definition
###########################################################################


class Crepe(torch.nn.Module):
    """Crepe model definition"""

    def __init__(self, model='full'):
        super().__init__()

        # Model-specific layer parameters
        if model == 'full':
            in_channels = [1, 1024, 128, 128, 128, 256]
            out_channels = [1024, 128, 128, 128, 256, 512]
            self.in_features = 2048
        elif model == 'tiny':
            in_channels = [1, 128, 16, 16, 16, 32]
            out_channels = [128, 16, 16, 16, 32, 64]
            self.in_features = 256
        else:
            raise ValueError(f'Model {model} is not supported')

        # Shared layer parameters
        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]

        # Use the eps and momentum given by the MMdnn conversion
        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d,
                                          eps=0.0010000000474974513,
                                          momentum=0.0)
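
        # Note: momentum=0.0 means PyTorch never updates the running
        # statistics, so the mean and variance imported by the conversion
        # stay fixed even if the module is put in training mode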

        # Layer definitions
        self.conv1 = torch.nn.Conv2d(
            in_channels=in_channels[0],
            out_channels=out_channels[0],
            kernel_size=kernel_sizes[0],
            stride=strides[0])
        self.conv1_BN = batch_norm_fn(
            num_features=out_channels[0])

        self.conv2 = torch.nn.Conv2d(
            in_channels=in_channels[1],
            out_channels=out_channels[1],
            kernel_size=kernel_sizes[1],
            stride=strides[1])
        self.conv2_BN = batch_norm_fn(
            num_features=out_channels[1])

        self.conv3 = torch.nn.Conv2d(
            in_channels=in_channels[2],
            out_channels=out_channels[2],
            kernel_size=kernel_sizes[2],
            stride=strides[2])
        self.conv3_BN = batch_norm_fn(
            num_features=out_channels[2])

        self.conv4 = torch.nn.Conv2d(
            in_channels=in_channels[3],
            out_channels=out_channels[3],
            kernel_size=kernel_sizes[3],
            stride=strides[3])
        self.conv4_BN = batch_norm_fn(
            num_features=out_channels[3])

        self.conv5 = torch.nn.Conv2d(
            in_channels=in_channels[4],
            out_channels=out_channels[4],
            kernel_size=kernel_sizes[4],
            stride=strides[4])
        self.conv5_BN = batch_norm_fn(
            num_features=out_channels[4])

        self.conv6 = torch.nn.Conv2d(
            in_channels=in_channels[5],
            out_channels=out_channels[5],
            kernel_size=kernel_sizes[5],
            stride=strides[5])
        self.conv6_BN = batch_norm_fn(
            num_features=out_channels[5])

        self.classifier = torch.nn.Linear(
            in_features=self.in_features,
            out_features=crepe.PITCH_BINS)
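
        # The classifier input size is the flattened conv6 output: conv1
        # produces 256 time steps, six max-pools halve that to 4, and
        # 4 * out_channels[5] gives 2048 ('full') or 256 ('tiny')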

    def forward(self, x, embed=False):
        # Forward pass through first five layers
        x = self.embed(x)

        if embed:
            return x

        # Forward pass through layer six
        x = self.layer(x, self.conv6, self.conv6_BN)

        # shape=(batch, self.in_features)
        x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)

        # Compute pitch bin probabilities from the classifier logits
        return torch.sigmoid(self.classifier(x))

    ###########################################################################
    # Forward pass utilities
    ###########################################################################

    def embed(self, x):
        """Map input audio to pitch embedding"""
        # shape=(batch, 1, 1024, 1)
        x = x[:, None, :, None]

        # Forward pass through first five layers
        x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
        x = self.layer(x, self.conv2, self.conv2_BN)
        x = self.layer(x, self.conv3, self.conv3_BN)
        x = self.layer(x, self.conv4, self.conv4_BN)
        x = self.layer(x, self.conv5, self.conv5_BN)

        return x

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        """Forward pass through one layer"""
        # Padding is (left, right, top, bottom) over the last two dims; the
        # default pads the sample axis by 31 + 32 = 63, i.e. 'same' padding
        # for the (64, 1) kernels with stride 1
        x = F.pad(x, padding)
        x = conv(x)
        x = F.relu(x)
        x = batch_norm(x)
        return F.max_pool2d(x, (2, 1), (2, 1))
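

###########################################################################
# Usage sketch
###########################################################################


# A minimal smoke test, assuming the `crepe` package imported above is
# available and defines PITCH_BINS (360 in the original CREPE). Each input
# frame is 1024 audio samples, so a batch of N frames has shape (N, 1024).
if __name__ == '__main__':
    model = Crepe(model='tiny')
    model.eval()

    # Two dummy frames, shape=(2, 1024)
    frames = torch.randn(2, 1024)

    with torch.no_grad():
        # Pitch bin probabilities, shape=(2, crepe.PITCH_BINS)
        probabilities = model(frames)

        # Intermediate embedding, shape=(2, 32, 8, 1) for 'tiny'
        embedding = model(frames, embed=True)

    print(probabilities.shape, embedding.shape)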