import warnings

import torch
import torch.nn as nn

from .CNN import CNN
from .RNN import BidirectionalGRU


class CRNN(nn.Module):
    def __init__(
        self,
        n_in_channel=1,
        nclass=10,
        attention=True,
        activation="glu",
        dropout=0.5,
        train_cnn=True,
        rnn_type="BGRU",
        n_RNN_cell=128,
        n_layers_RNN=2,
        dropout_recurrent=0,
        cnn_integration=False,
        freeze_bn=False,
        use_embeddings=False,
        embedding_size=527,
        embedding_type="global",
        frame_emb_enc_dim=512,
        aggregation_type="global",
        **kwargs,
    ):
""" |
|
|
Initialization of CRNN model |
|
|
|
|
|
Args: |
|
|
n_in_channel: int, number of input channel |
|
|
n_class: int, number of classes |
|
|
attention: bool, adding attention layer or not |
|
|
activation: str, activation function |
|
|
dropout: float, dropout |
|
|
train_cnn: bool, training cnn layers |
|
|
rnn_type: str, rnn type |
|
|
n_RNN_cell: int, RNN nodes |
|
|
n_layer_RNN: int, number of RNN layers |
|
|
dropout_recurrent: float, recurrent layers dropout |
|
|
cnn_integration: bool, integration of cnn |
|
|
freeze_bn: |
|
|
**kwargs: keywords arguments for CNN. |
|
|
""" |
|
|
        super(CRNN, self).__init__()

        self.n_in_channel = n_in_channel
        self.attention = attention
        self.cnn_integration = cnn_integration
        self.freeze_bn = freeze_bn
        self.use_embeddings = use_embeddings
        self.embedding_type = embedding_type
        self.aggregation_type = aggregation_type

        n_in_cnn = n_in_channel
        if cnn_integration:
            n_in_cnn = 1

        self.cnn = CNN(
            n_in_channel=n_in_cnn, activation=activation, conv_dropout=dropout, **kwargs
        )

        self.train_cnn = train_cnn
        if not train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

        if rnn_type == "BGRU":
            nb_in = self.cnn.nb_filters[-1]
            if self.cnn_integration:
                # the CNN processes each input channel separately, so the RNN
                # input concatenates the features of all channels
                nb_in = nb_in * n_in_channel
            self.rnn = BidirectionalGRU(
                n_in=nb_in,
                n_hidden=n_RNN_cell,
                dropout=dropout_recurrent,
                num_layers=n_layers_RNN,
            )
        else:
            raise NotImplementedError("Only BGRU supported for CRNN for now")

        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()

        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

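        # Optional fusion of pretrained embeddings with the CNN features;
        # aggregation_type selects how the embeddings are aligned with the CNN
        # frame rate before the two are concatenated.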
        if self.use_embeddings:
            if self.aggregation_type == "frame":
                # hidden_size follows frame_emb_enc_dim so that shrink_emb
                # receives inputs of size 2 * frame_emb_enc_dim
                self.frame_embs_encoder = nn.GRU(
                    batch_first=True,
                    input_size=embedding_size,
                    hidden_size=frame_emb_enc_dim,
                    bidirectional=True,
                )
                self.shrink_emb = torch.nn.Sequential(
                    torch.nn.Linear(2 * frame_emb_enc_dim, nb_in),
                    torch.nn.LayerNorm(nb_in),
                )
                self.cat_tf = torch.nn.Linear(2 * nb_in, nb_in)
            elif self.aggregation_type == "global":
                self.shrink_emb = torch.nn.Sequential(
                    torch.nn.Linear(embedding_size, nb_in),
                    torch.nn.LayerNorm(nb_in),
                )
                self.cat_tf = torch.nn.Linear(2 * nb_in, nb_in)
            elif self.aggregation_type == "interpolate":
                self.cat_tf = torch.nn.Linear(nb_in + embedding_size, nb_in)
            elif self.aggregation_type == "pool1d":
                self.cat_tf = torch.nn.Linear(nb_in + embedding_size, nb_in)
            else:
                self.cat_tf = torch.nn.Linear(2 * nb_in, nb_in)

    def forward(self, x, pad_mask=None, embeddings=None):
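        """
        Apply the CRNN to a batch of spectrograms.

        Args:
            x: torch.Tensor, input of shape (batch, n_freq, n_frames).
            pad_mask: torch.Tensor or None, boolean mask marking padded frames,
                used to zero out their attention weights.
            embeddings: torch.Tensor or None, pretrained embeddings fused with
                the CNN features when use_embeddings is True.

        Returns:
            tuple: strong (frame-level) predictions of shape
            (batch, nclass, n_frames after CNN pooling) and weak (clip-level)
            predictions of shape (batch, nclass).
        """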
        # (batch, n_freq, n_frames) -> (batch, 1, n_frames, n_freq)
        x = x.transpose(1, 2).unsqueeze(1)

        if self.cnn_integration:
            # fold the input channels into the batch dimension so the CNN
            # sees them as independent single-channel inputs
            bs_in, nc_in = x.size(0), x.size(1)
            x = x.view(bs_in * nc_in, 1, *x.shape[2:])

        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if self.cnn_integration:
            x = x.reshape(bs_in, chan * nc_in, frames, freq)

        if freq != 1:
            warnings.warn(
                f"Output shape is {(bs, frames, chan * freq)}: {freq} frequency "
                f"bins remain and are flattened into the channel dimension"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # (batch, n_frames, n_channels)

        if self.use_embeddings:
            if self.aggregation_type == "global":
                # tile the clip-level embedding over all frames
                emb = self.shrink_emb(embeddings).unsqueeze(1).repeat(1, x.shape[1], 1)
                x = self.cat_tf(torch.cat((x, emb), -1))
            elif self.aggregation_type == "frame":
                # summarize the frame-level embeddings with a GRU and use the
                # last output as a clip-level embedding
                last, _ = self.frame_embs_encoder(embeddings.transpose(1, 2))
                embeddings = last[:, -1]
                emb = self.shrink_emb(embeddings).unsqueeze(1).repeat(1, x.shape[1], 1)
                x = self.cat_tf(torch.cat((x, emb), -1))
            elif self.aggregation_type == "interpolate":
                # resample the frame-level embeddings to the CNN frame rate
                output_shape = (embeddings.shape[1], x.shape[1])
                reshape_emb = (
                    torch.nn.functional.interpolate(
                        embeddings.unsqueeze(1), size=output_shape, mode="nearest-exact"
                    )
                    .squeeze(1)
                    .transpose(1, 2)
                )
                x = self.cat_tf(torch.cat((x, reshape_emb), -1))
            elif self.aggregation_type == "pool1d":
                # average-pool the frame-level embeddings to the CNN frame rate
                reshape_emb = torch.nn.functional.adaptive_avg_pool1d(
                    embeddings, x.shape[1]
                ).transpose(1, 2)
                x = self.cat_tf(torch.cat((x, reshape_emb), -1))
            else:
                pass

        x = self.rnn(x)
        x = self.dropout(x)
        strong = self.dense(x)  # frame-level predictions, (batch, n_frames, nclass)
        strong = self.sigmoid(strong)
        if self.attention:
            sof = self.dense_softmax(x)  # attention weights for clip-level pooling
            if pad_mask is not None:
                # padded frames get a large negative logit, i.e. ~zero attention
                sof = sof.masked_fill(pad_mask.transpose(1, 2), -1e30)
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / sof.sum(1)
        else:
            weak = strong.mean(1)
        return strong.transpose(1, 2), weak

    def train(self, mode=True):
        """
        Override the default train() to keep the BatchNorm layers frozen.
        """
        super(CRNN, self).train(mode)
        if self.freeze_bn:
            print("Freezing Mean/Var of BatchNorm2D.")
            print("Freezing Weight/Bias of BatchNorm2D.")
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    # keep running statistics fixed and stop gradient updates
                    m.eval()
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
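

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes
    # the default CNN configuration pools a 128-bin frequency axis down to a
    # single bin, as in the DCASE task 4 baseline; run it as a module
    # (python -m <package>.CRNN) so the relative imports resolve.
    model = CRNN(n_in_channel=1, nclass=10)
    dummy = torch.randn(4, 128, 626)  # (batch, n_freq, n_frames), arbitrary sizes
    strong, weak = model(dummy)
    print(strong.shape)  # frame-level predictions: (batch, nclass, n_frames_out)
    print(weak.shape)    # clip-level predictions: (batch, nclass)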