import torch.nn as nn
import torch


class GLU(nn.Module):
    """Gated linear unit applied over the channel dimension of an NCHW tensor.

    Computes ``Linear(x) * sigmoid(x)`` per channel.

    NOTE(review): the canonical GLU gates with ``sigmoid`` of a *second linear
    projection*; here the gate is the raw input ``x`` itself. Kept as-is to
    stay compatible with existing trained checkpoints — confirm intent.
    """

    def __init__(self, input_num):
        """
        Args:
            input_num: int, number of input (and output) channels.
        """
        super(GLU, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(input_num, input_num)

    def forward(self, x):
        """Apply the gate to ``x`` of shape (batch, channels, H, W)."""
        # nn.Linear acts on the last dim, so move channels last, then back.
        lin = self.linear(x.permute(0, 2, 3, 1))
        lin = lin.permute(0, 3, 1, 2)
        sig = self.sigmoid(x)  # gate computed from the input itself (see class note)
        res = lin * sig
        return res


class ContextGating(nn.Module):
    """Context gating over channels: ``x * sigmoid(Linear(x))``."""

    def __init__(self, input_num):
        """
        Args:
            input_num: int, number of input (and output) channels.
        """
        super(ContextGating, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(input_num, input_num)

    def forward(self, x):
        """Apply the gate to ``x`` of shape (batch, channels, H, W)."""
        # nn.Linear acts on the last dim, so move channels last, then back.
        lin = self.linear(x.permute(0, 2, 3, 1))
        lin = lin.permute(0, 3, 1, 2)
        sig = self.sigmoid(lin)
        res = x * sig
        return res


class CNN(nn.Module):
    def __init__(
        self,
        n_in_channel,
        activation="Relu",
        conv_dropout=0,
        kernel_size=(3, 3, 3),
        padding=(1, 1, 1),
        stride=(1, 1, 1),
        nb_filters=(64, 64, 64),
        pooling=((1, 4), (1, 4), (1, 4)),
        normalization="batch",
        **transformer_kwargs
    ):
        """
        Initialization of CNN network s
        Args:
            n_in_channel: int, number of input channel
            activation: str, activation function ("relu", "leakyrelu", "glu"
                or "cg"; compared case-insensitively)
            conv_dropout: float, dropout
            kernel_size: kernel size
            padding: padding
            stride: list, stride
            nb_filters: number of filters
            pooling: list of tuples, time and frequency pooling
            normalization: choose between "batch" for BatchNormalization
                and "layer" for LayerNormalization.
        """
        # Defaults are tuples (not lists) so the shared-mutable-default
        # pitfall cannot bite; per-layer values are only indexed below, so
        # lists passed by callers keep working unchanged.
        super(CNN, self).__init__()
        self.nb_filters = nb_filters
        cnn = nn.Sequential()

        def conv(i, normalization="batch", dropout=None, activ="relu"):
            """Append conv -> norm -> activation -> dropout for stage ``i``."""
            nIn = n_in_channel if i == 0 else nb_filters[i - 1]
            nOut = nb_filters[i]
            cnn.add_module(
                "conv{0}".format(i),
                nn.Conv2d(nIn, nOut, kernel_size[i], stride[i], padding[i]),
            )
            if normalization == "batch":
                # NOTE(review): momentum=0.99 looks like an unconverted Keras
                # value — PyTorch's momentum is the complement (Keras 0.99 ==
                # PyTorch 0.01). Kept as-is for checkpoint compatibility;
                # confirm against the original baseline.
                cnn.add_module(
                    "batchnorm{0}".format(i),
                    nn.BatchNorm2d(nOut, eps=0.001, momentum=0.99),
                )
            elif normalization == "layer":
                # GroupNorm with a single group == layer norm over C,H,W.
                cnn.add_module("layernorm{0}".format(i), nn.GroupNorm(1, nOut))

            if activ.lower() == "leakyrelu":
                cnn.add_module("relu{0}".format(i), nn.LeakyReLU(0.2))
            elif activ.lower() == "relu":
                cnn.add_module("relu{0}".format(i), nn.ReLU())
            elif activ.lower() == "glu":
                cnn.add_module("glu{0}".format(i), GLU(nOut))
            elif activ.lower() == "cg":
                cnn.add_module("cg{0}".format(i), ContextGating(nOut))

            if dropout is not None:
                cnn.add_module("dropout{0}".format(i), nn.Dropout(dropout))

        # 128x862x64
        for i in range(len(nb_filters)):
            conv(i, normalization=normalization, dropout=conv_dropout, activ=activation)
            cnn.add_module(
                "pooling{0}".format(i), nn.AvgPool2d(pooling[i])
            )  # bs x tframe x mels

        self.cnn = cnn

    def forward(self, x):
        """
        Forward step of the CNN module

        Args:
            x (Tensor): input batch of size (batch_size, n_channels, n_frames, n_freq)

        Returns:
            Tensor: batch embedded
        """
        # conv features
        x = self.cnn(x)
        return x