|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
from transformers import Wav2Vec2Model, Wav2Vec2Config |
|
|
from .conformer import FinalConformer |
|
|
|
|
|
class DF_Arena_1B(nn.Module):
    """Deepfake-detection head on top of a wav2vec2 XLS-R 1B encoder.

    Pipeline: SSL hidden states from every encoder layer -> per-layer
    attentive pooling over time -> learned per-layer weighting of the raw
    layer features -> BatchNorm2d + SELU -> Conformer classifier.
    """

    def __init__(self):
        super().__init__()
        # NOTE: module construction order is kept exactly as written so that
        # parameter initialization under a fixed RNG seed is reproducible.
        self.ssl_model = Wav2Vec2Model(Wav2Vec2Config.from_pretrained("facebook/wav2vec2-xls-r-1b"))
        # Expose all intermediate hidden states; forward() consumes them.
        self.ssl_model.config.output_hidden_states = True
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)
        self.fc0 = nn.Linear(1280, 1)
        self.sig = nn.Sigmoid()
        self.conformer = FinalConformer(emb_size=1280, heads=4, ffmult=4, exp_fac=2, kernel_size=31, n_encoders=4)
        # Scores each frame for the attentive pooling in get_attenF1Dpooling.
        self.attn_scores = nn.Linear(1280, 1, bias=False)

    def get_attenF1Dpooling(self, x):
        """Attention-pool ``x`` over its time axis (dim 1).

        Frames are scored by ``self.attn_scores``, normalized with a softmax
        over time, and used as weights for a weighted sum of frames.
        Returns a tensor with dim 1 reduced to size 1 (kept).
        """
        frame_scores = self.attn_scores(x)
        attn = frame_scores.softmax(dim=1)
        return (attn * x).sum(dim=1, keepdim=True)

    def get_attenF1D(self, layerResult):
        """Pool every SSL layer output and also stack the raw outputs.

        Returns ``(pooled, stacked)`` where ``pooled`` concatenates the
        per-layer attentive poolings along dim 1 and ``stacked`` holds the
        unpooled layer outputs with a new layer axis at dim 1.
        """
        per_layer = [self.get_attenF1Dpooling(hidden) for hidden in layerResult]
        pooled = torch.cat(per_layer, dim=1)
        stacked = torch.stack(list(layerResult), dim=1)
        return pooled, stacked

    def forward(self, x):
        """Score input ``x`` through SSL features, layer weighting, and the conformer.

        NOTE(review): ``x`` is unsqueezed with a leading axis before the SSL
        model, so this appears to expect a single un-batched waveform —
        confirm against callers.
        """
        ssl_out = self.ssl_model(x.unsqueeze(0))

        # pooled: (B, n_layers, D); features: (B, n_layers, T, D)
        pooled, features = self.get_attenF1D(ssl_out.hidden_states)

        # One sigmoid gate per layer, broadcast over time and feature dims.
        layer_gate = self.sig(self.fc0(pooled)).unsqueeze(-1)  # (B, n_layers, 1, 1)
        fused = (features * layer_gate).sum(dim=1)             # (B, T, D)

        # Normalize as a single-channel 2D map, then activate.
        fused = self.selu(self.first_bn(fused.unsqueeze(1)))

        output, _ = self.conformer(fused.squeeze(1))
        return output