import torch
import torch.nn as nn


class GLU(nn.Module):
    """Gated Linear Unit on 2D feature maps: multiplies a channel-wise linear
    projection of the input by a sigmoid of the raw input."""

    def __init__(self, input_num):
        super(GLU, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(input_num, input_num)

    def forward(self, x):
        # Move channels last so the linear layer acts on the channel axis:
        # (batch, channels, frames, freq) -> (batch, frames, freq, channels)
        lin = self.linear(x.permute(0, 2, 3, 1))
        lin = lin.permute(0, 3, 1, 2)
        # Gate the projection with a sigmoid of the unprojected input
        sig = self.sigmoid(x)
        res = lin * sig
        return res
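

# A minimal smoke test for GLU (illustrative sketch, not part of the original
# module; the name _demo_glu and the concrete shapes are assumptions). The
# (batch, channels, frames, freq) layout is inferred from the permutes above.
def _demo_glu():
    x = torch.randn(2, 64, 100, 16)  # hypothetical spectrogram-like batch
    out = GLU(input_num=64)(x)
    assert out.shape == x.shape  # gating preserves the input shape

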
class ContextGating(nn.Module):
    """Context gating on 2D feature maps: multiplies the input by a sigmoid
    of a channel-wise linear projection of that input."""

    def __init__(self, input_num):
        super(ContextGating, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(input_num, input_num)

    def forward(self, x):
        # Move channels last so the linear layer acts on the channel axis:
        # (batch, channels, frames, freq) -> (batch, frames, freq, channels)
        lin = self.linear(x.permute(0, 2, 3, 1))
        lin = lin.permute(0, 3, 1, 2)
        # Gate the raw input with a sigmoid of the projection
        sig = self.sigmoid(lin)
        res = x * sig
        return res
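

# Companion sketch for ContextGating (illustrative; _demo_context_gating and
# the shapes are assumptions, mirroring the GLU demo above). Note the
# asymmetry: GLU computes linear(x) * sigmoid(x), whereas ContextGating
# computes x * sigmoid(linear(x)).
def _demo_context_gating():
    x = torch.randn(2, 64, 100, 16)  # hypothetical spectrogram-like batch
    out = ContextGating(input_num=64)(x)
    assert out.shape == x.shape  # gating preserves the input shape

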
class CNN(nn.Module):
    def __init__(
        self,
        n_in_channel,
        activation="relu",
        conv_dropout=0,
        kernel_size=[3, 3, 3],
        padding=[1, 1, 1],
        stride=[1, 1, 1],
        nb_filters=[64, 64, 64],
        pooling=[(1, 4), (1, 4), (1, 4)],
        normalization="batch",
        **transformer_kwargs
    ):
        """
        Initialization of the CNN network.

        Args:
            n_in_channel: int, number of input channels
            activation: str, activation function, one of "relu", "leakyrelu",
                "glu" or "cg" (case-insensitive)
            conv_dropout: float, dropout probability applied after each
                convolutional block
            kernel_size: list of int, kernel size per convolutional layer
            padding: list of int, padding per convolutional layer
            stride: list of int, stride per convolutional layer
            nb_filters: list of int, number of filters per convolutional layer
            pooling: list of tuples, (time, frequency) pooling per
                convolutional layer
            normalization: str, "batch" for BatchNorm2d or "layer" for layer
                normalization (implemented as GroupNorm with a single group)
            **transformer_kwargs: additional keyword arguments, unused by this
                module
        """
        super(CNN, self).__init__()
        self.nb_filters = nb_filters
        cnn = nn.Sequential()

        def conv(i, normalization="batch", dropout=None, activ="relu"):
            # First layer takes the raw input channels, later ones the output
            # of the previous layer
            nIn = n_in_channel if i == 0 else nb_filters[i - 1]
            nOut = nb_filters[i]
            cnn.add_module(
                "conv{0}".format(i),
                nn.Conv2d(nIn, nOut, kernel_size[i], stride[i], padding[i]),
            )
            if normalization == "batch":
                cnn.add_module(
                    "batchnorm{0}".format(i),
                    nn.BatchNorm2d(nOut, eps=0.001, momentum=0.99),
                )
            elif normalization == "layer":
                # GroupNorm with one group normalizes over all channels
                cnn.add_module("layernorm{0}".format(i), nn.GroupNorm(1, nOut))

            if activ.lower() == "leakyrelu":
                cnn.add_module("relu{0}".format(i), nn.LeakyReLU(0.2))
            elif activ.lower() == "relu":
                cnn.add_module("relu{0}".format(i), nn.ReLU())
            elif activ.lower() == "glu":
                cnn.add_module("glu{0}".format(i), GLU(nOut))
            elif activ.lower() == "cg":
                cnn.add_module("cg{0}".format(i), ContextGating(nOut))

            if dropout is not None:
                cnn.add_module("dropout{0}".format(i), nn.Dropout(dropout))

        # Stack conv -> normalization -> activation -> dropout blocks, each
        # followed by average pooling
        for i in range(len(nb_filters)):
            conv(i, normalization=normalization, dropout=conv_dropout, activ=activation)
            cnn.add_module("pooling{0}".format(i), nn.AvgPool2d(pooling[i]))

        self.cnn = cnn

    def forward(self, x):
        """
        Forward step of the CNN module.

        Args:
            x (Tensor): input batch of size (batch_size, n_channels, n_frames, n_freq)

        Returns:
            Tensor: embedded batch
        """
        x = self.cnn(x)
        return x
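

# Hedged end-to-end sketch (not part of the original file; the function name,
# shapes, and hyperparameters are assumptions): push a dummy log-mel-like
# batch through the CNN with the default (1, 4) poolings, which leave the
# time axis intact and collapse 64 frequency bins to 1.
def _demo_cnn():
    x = torch.randn(2, 1, 628, 64)  # (batch, channels, frames, freq)
    cnn = CNN(n_in_channel=1, activation="glu", conv_dropout=0.5)
    out = cnn(x)
    # Frequency is divided by 4 * 4 * 4 = 64; time pooling of 1 changes nothing.
    assert out.shape == (2, 64, 628, 1)


if __name__ == "__main__":
    _demo_glu()
    _demo_context_gating()
    _demo_cnn()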