import torch
import torch.nn as nn


class GLU(nn.Module):
    def __init__(self, input_num):
        super(GLU, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(input_num, input_num)

    def forward(self, x):
        # Project the channel dimension: (bs, ch, frames, freq) -> (bs, frames, freq, ch)
        lin = self.linear(x.permute(0, 2, 3, 1))
        lin = lin.permute(0, 3, 1, 2)
        # Gate the projection with a sigmoid of the raw input
        sig = self.sigmoid(x)
        res = lin * sig
        return res
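
# Illustrative shape check for GLU (example values assumed, not taken from the repo):
#   >>> glu = GLU(64)
#   >>> glu(torch.randn(2, 64, 157, 8)).shape
#   torch.Size([2, 64, 157, 8])
# `input_num` must match the channel dimension (dim 1), since the Linear layer is
# applied channel-wise after the permute in forward().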


class ContextGating(nn.Module):
    def __init__(self, input_num):
        super(ContextGating, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(input_num, input_num)

    def forward(self, x):
        # Project the channel dimension: (bs, ch, frames, freq) -> (bs, frames, freq, ch)
        lin = self.linear(x.permute(0, 2, 3, 1))
        lin = lin.permute(0, 3, 1, 2)
        # Gate the raw input with a sigmoid of the projection
        sig = self.sigmoid(lin)
        res = x * sig
        return res
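
# Illustrative shape check for ContextGating (example values assumed, not from the repo):
#   >>> cg = ContextGating(64)
#   >>> cg(torch.randn(2, 64, 157, 8)).shape
#   torch.Size([2, 64, 157, 8])
# Unlike GLU above, the sigmoid gate is computed from the projected features and
# applied multiplicatively to the raw input.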


class CNN(nn.Module):
    def __init__(
        self,
        n_in_channel,
        activation="Relu",
        conv_dropout=0,
        kernel_size=[3, 3, 3],
        padding=[1, 1, 1],
        stride=[1, 1, 1],
        nb_filters=[64, 64, 64],
        pooling=[(1, 4), (1, 4), (1, 4)],
        normalization="batch",
        **transformer_kwargs
    ):
        """
        Initialization of the CNN network.

        Args:
            n_in_channel: int, number of input channels
            activation: str, activation function ("relu", "leakyrelu", "glu" or "cg")
            conv_dropout: float, dropout rate applied after each convolutional block
            kernel_size: list of int, kernel size of each convolution
            padding: list of int, padding of each convolution
            stride: list of int, stride of each convolution
            nb_filters: list of int, number of filters of each convolution
            pooling: list of tuples, time and frequency pooling after each block
            normalization: str, "batch" for BatchNorm2d or "layer" for layer normalization (GroupNorm with one group)
        """
        super(CNN, self).__init__()
        self.nb_filters = nb_filters
        cnn = nn.Sequential()

        def conv(i, normalization="batch", dropout=None, activ="relu"):
            nIn = n_in_channel if i == 0 else nb_filters[i - 1]
            nOut = nb_filters[i]
            cnn.add_module(
                "conv{0}".format(i),
                nn.Conv2d(nIn, nOut, kernel_size[i], stride[i], padding[i]),
            )
            if normalization == "batch":
                cnn.add_module(
                    "batchnorm{0}".format(i),
                    nn.BatchNorm2d(nOut, eps=0.001, momentum=0.99),
                )
            elif normalization == "layer":
                cnn.add_module("layernorm{0}".format(i), nn.GroupNorm(1, nOut))

            if activ.lower() == "leakyrelu":
                cnn.add_module("relu{0}".format(i), nn.LeakyReLU(0.2))
            elif activ.lower() == "relu":
                cnn.add_module("relu{0}".format(i), nn.ReLU())
            elif activ.lower() == "glu":
                cnn.add_module("glu{0}".format(i), GLU(nOut))
            elif activ.lower() == "cg":
                cnn.add_module("cg{0}".format(i), ContextGating(nOut))

            if dropout is not None:
                cnn.add_module("dropout{0}".format(i), nn.Dropout(dropout))

        # 128x862x64
        for i in range(len(nb_filters)):
            conv(i, normalization=normalization, dropout=conv_dropout, activ=activation)
            cnn.add_module(
                "pooling{0}".format(i), nn.AvgPool2d(pooling[i])
            )  # bs x tframe x mels

        self.cnn = cnn
    def forward(self, x):
        """
        Forward step of the CNN module

        Args:
            x (Tensor): input batch of size (batch_size, n_channels, n_frames, n_freq)
        Returns:
            Tensor: batch embedded
        """
        # conv features
        x = self.cnn(x)
        return x
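

if __name__ == "__main__":
    # Minimal smoke test, added as an illustrative sketch: the batch size, the
    # 862-frame / 64-mel input shape, and the choice of "glu" activation are
    # assumptions, not values taken from the original training configuration.
    model = CNN(n_in_channel=1, activation="glu", conv_dropout=0.5)
    dummy = torch.randn(4, 1, 862, 64)  # bs x channels x frames x freq
    out = model(dummy)
    # With the default (1, 4) pooling per block, the time axis is preserved and
    # the frequency axis shrinks 64 -> 16 -> 4 -> 1.
    print(out.shape)  # expected: torch.Size([4, 64, 862, 1])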