# Source: asr-inference / age_gender_detector.py (Hugging Face Hub file page)
# Last commit: "fase_1, fase_2 releases (#46)" by AbirMessaoudi, 1619dcb (verified)
# File size: 11.5 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch import tensor
from transformers import Wav2Vec2FeatureExtractor, WavLMModel
import transformers.models.wavlm.modeling_wavlm as wavlm
from huggingface_hub import PyTorchModelHubMixin
from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
class RevGrad(Function):
    """Gradient-reversal autograd function.

    The forward pass is the identity. The backward pass multiplies the incoming
    gradient by ``-alpha_``. ``alpha_`` must be a tensor, because it is stored
    with ``save_for_backward``, and it never receives a gradient.
    """

    @staticmethod
    def forward(ctx, input_, alpha_):
        # backward only needs alpha_. Saving input_ as well would keep the
        # activation referenced by the autograd graph for no benefit.
        ctx.save_for_backward(alpha_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        (alpha_,) = ctx.saved_tensors
        grad_input = -grad_output * alpha_ if ctx.needs_input_grad[0] else None
        # No gradient flows to alpha_.
        return grad_input, None


# Functional alias (used by RevGradLayer.forward).
revgrad = RevGrad.apply
class RevGradLayer(nn.Module):
    """Module wrapper around the gradient-reversal function with a fixed scale.

    Args:
        alpha: Factor applied to the negated gradient during backward.
    """

    def __init__(self, alpha=1.):
        super().__init__()
        # Kept as a plain tensor attribute: not a Parameter and not a buffer.
        scale = tensor(alpha, requires_grad=False)
        self._alpha = scale

    def forward(self, x):
        """Return ``x`` unchanged; the backward pass scales its gradient by ``-alpha``."""
        alpha = self._alpha
        return revgrad(x, alpha)
class WavLMEncoderLayer(nn.Module):
    """Post-norm WavLM transformer block.

    Data flow: attention, then residual add and ``layer_norm``, then the
    feed-forward network, then residual add and ``final_layer_norm``.

    Submodule attribute names become ``state_dict`` keys. Keep them stable so
    that pretrained backbone weights can still be loaded into this layer.

    Args:
        layer_idx: Index of the layer in the encoder (not used by this block).
        config: Backbone model config providing sizes, dropout rates and eps.
        has_relative_position_bias: Forwarded to ``WavLMAttention``.
    """

    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.attention = wavlm.WavLMAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(width, eps=eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(width, eps=eps)
        self.config = config

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        """Apply the block.

        Returns ``(hidden_states, position_bias)``. When ``output_attentions``
        is set, the attention weights are appended as a third element.
        """
        residual = hidden_states
        attended, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        # Post-norm: normalise after each residual addition.
        normed = self.layer_norm(residual + self.dropout(attended))
        result = self.final_layer_norm(normed + self.feed_forward(normed))

        if output_attentions:
            return (result, position_bias, attn_weights)
        return (result, position_bias)
class WavLMEncoderLayerStableLayerNorm(nn.Module):
    """Pre-norm WavLM transformer block.

    ``layer_norm`` is applied before attention and ``final_layer_norm`` before
    the feed-forward network. Each sub-layer output is added back to its
    un-normalised input.

    Submodule attribute names become ``state_dict`` keys. Keep them stable so
    that pretrained backbone weights can still be loaded into this layer.

    Args:
        layer_idx: Index of the layer in the encoder (not used by this block).
        config: Backbone model config providing sizes, dropout rates and eps.
        has_relative_position_bias: Forwarded to ``WavLMAttention``.
    """

    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.attention = wavlm.WavLMAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(width, eps=eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(width, eps=eps)
        self.config = config

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        """Apply the block.

        Returns ``(hidden_states, position_bias)``. When ``output_attentions``
        is set, the attention weights are appended as a third element.
        """
        residual = hidden_states
        attended, attn_weights, position_bias = self.attention(
            self.layer_norm(hidden_states),
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        result = residual + self.dropout(attended)
        result = result + self.feed_forward(self.final_layer_norm(result))

        if output_attentions:
            return (result, position_bias, attn_weights)
        return (result, position_bias)
class WavLMWrapper(nn.Module, PyTorchModelHubMixin):
    """WavLM backbone with a learned layer mix, a conv projection and age/sex heads.

    Args:
        pretrain_model: ``"wavlm"`` (microsoft/wavlm-base-plus) or
            ``"wavlm_large"`` (microsoft/wavlm-large, with its feature extractor).
        hidden_dim: Channel width of the conv projection and of the head inputs.
        freeze_params: If True, disable gradients for every backbone parameter.
        output_class_num: Accepted for configuration compatibility; unused here.
        use_conv_output: If True, every entry of the backbone's ``hidden_states``
            is mixed. If False, the first entry is dropped. Presumably that entry
            is the pre-transformer (conv/embedding) output.
        apply_reg: If True, the age head ends in a single sigmoid unit.
            Otherwise it emits 7 logits.

    Raises:
        ValueError: If ``pretrain_model`` is not one of the supported names.
    """

    def __init__(
        self,
        pretrain_model="wavlm_large",
        hidden_dim=256,
        freeze_params=True,
        output_class_num=4,
        use_conv_output=True,
        apply_reg=False
    ):
        super().__init__()
        self.pretrain_model = pretrain_model
        self.use_conv_output = use_conv_output

        # Load the pretrained backbone and pick the matching local encoder layer.
        if self.pretrain_model == "wavlm":
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-base-plus",
                output_hidden_states=True,
            )
            layer_cls = WavLMEncoderLayer
        elif self.pretrain_model == "wavlm_large":
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-large",
                output_hidden_states=True,
            )
            layer_cls = WavLMEncoderLayerStableLayerNorm
        else:
            # Fail fast. Otherwise the missing backbone would surface as an
            # AttributeError further down.
            raise ValueError(
                f"Unsupported pretrain_model {pretrain_model!r}; "
                "expected 'wavlm' or 'wavlm_large'"
            )

        # Swap in the local (non-LoRA) encoder layers, then reload the pretrained
        # weights captured beforehand. With strict=False, missing or unexpected
        # keys are ignored instead of raising.
        state_dict = self.backbone_model.state_dict()
        self.model_config = self.backbone_model.config
        self.backbone_model.encoder.layers = nn.ModuleList(
            [layer_cls(i, self.model_config, has_relative_position_bias=(i == 0))
             for i in range(self.model_config.num_hidden_layers)]
        )
        self.backbone_model.load_state_dict(state_dict, strict=False)

        if freeze_params:
            for p in self.backbone_model.parameters():
                p.requires_grad = False

        # Pointwise (kernel size 1) conv projection from backbone width to hidden_dim.
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1)
        )

        # One learnable mixing weight per stacked hidden state; softmax-normalised in forward.
        num_layers = self.model_config.num_hidden_layers + 1 if use_conv_output else self.model_config.num_hidden_layers
        self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        # Output heads.
        if apply_reg:
            self.age_dist_layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
                nn.Sigmoid()
            )
        else:
            self.age_dist_layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 7)
            )
        self.sex_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, x, length=None, return_feature=False, pred="age_dist_sex"):
        """Run the backbone, mix its layers, pool over time and apply both heads.

        Args:
            x: Batch of raw waveforms. Assumed to be ``(batch, samples)`` at
                16 kHz; TODO confirm. The padding mask uses ``x.shape[1]``, and
                the processor is called with ``sampling_rate=16_000``.
            length: Optional 1-D tensor of per-item input lengths. It builds the
                relative padding mask (``length / length.max()``) and, after
                conversion to frame counts, bounds the mean pooling.
            return_feature: If True, also return the pooled features.
            pred: Unused; kept for interface compatibility.

        Returns:
            ``(age_pred, sex_pred)``, or ``(age_pred, sex_pred, features)`` when
            ``return_feature`` is set.

            - ``age_pred`` is ``(batch, 1)`` sigmoid output with
              ``apply_reg=True``, else ``(batch, 7)`` logits.
            - ``sex_pred`` is ``(batch, 2)`` logits.
        """
        # Feature extraction (large model only: padding mask + processor normalisation).
        if self.pretrain_model == "wavlm_large":
            with torch.no_grad():
                if length is not None:
                    attention_mask = make_padding_masks(x, wav_len=length / length.max()).to(x.device)
                else:
                    # One relative length (1.0 = unpadded) per batch item, so the
                    # mask has one row per waveform instead of a single row.
                    full_len = torch.ones(x.shape[0], device=x.device)
                    attention_mask = make_padding_masks(x, wav_len=full_len).to(x.device)
                signal = []
                for sample in x:
                    input_vals = self.processor(sample, sampling_rate=16_000, return_tensors="pt", padding=True)
                    signal.append(input_vals["input_values"][0].to(x.device))
                signal = torch.stack(signal)

        if length is not None:
            # Convert input lengths to post-conv frame counts. Keep them on the
            # input's device rather than assuming CUDA, so CPU inference works.
            length = self.get_feat_extract_output_lengths(length.detach().cpu()).to(x.device)

        if self.pretrain_model == "wavlm":
            x = self.backbone_model(x, output_hidden_states=True).hidden_states
        else:
            x = self.backbone_model(signal, attention_mask=attention_mask, output_hidden_states=True).hidden_states

        # Softmax-weighted sum over the stacked hidden states.
        stacked_feature = torch.stack(x, dim=0) if self.use_conv_output else torch.stack(x, dim=0)[1:]
        _, *origin_shape = stacked_feature.shape
        stacked_feature = stacked_feature.view(stacked_feature.shape[0], -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)

        # Conv1d expects (batch, channels, time); transpose in and back out.
        features = self.model_seq(features.transpose(1, 2)).transpose(1, 2)

        # Mean-pool over time, restricted to each item's valid frames when known.
        if length is not None:
            pooled = [features[i, :length[i]].mean(dim=0) for i in range(features.shape[0])]
            features = torch.stack(pooled)
        else:
            features = torch.mean(features, dim=1)

        age_pred = self.age_dist_layer(features)
        sex_pred = self.sex_layer(features)
        if return_feature:
            return age_pred, sex_pred, features
        return age_pred, sex_pred

    def get_feat_extract_output_lengths(self, input_length):
        """Map input lengths to frame counts after the backbone's conv feature extractor.

        Applies ``(L - kernel) // stride + 1`` once for every ``(kernel, stride)``
        pair in the backbone config's ``conv_kernel`` / ``conv_stride``.
        """
        def _conv_out_length(input_length, kernel_size, stride):
            return (input_length - kernel_size) // stride + 1

        config = self.backbone_model.config
        for kernel_size, stride in zip(config.conv_kernel, config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length
def age_gender(audio_waveform_np, model, device):
    """Predict age and sex for one waveform and format the result with Spanish labels.

    Args:
        audio_waveform_np: Waveform as a NumPy array, a ``torch.Tensor``, or any
            other array-like accepted by ``np.asarray``. A 1-D input is treated
            as a batch of one.
        model: Callable mapping a batch tensor to ``(age_output, sex_logits)``,
            e.g. a ``WavLMWrapper``. Only the first batch item is reported.
        device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``).

    Returns:
        ``(age_str, sex_label, age_group)``:

        - ``age_str`` is the age rounded to an integer, as a string, or
          ``"N/A"`` if it cannot be computed.
        - ``sex_label`` is ``"Femenino"`` or ``"Masculino"``.
        - ``age_group`` is a Spanish age-bracket label, or ``"desconocido"``
          when the age is unavailable.
    """
    # Accept tensors and ndarrays directly; coerce other array-likes (e.g. lists).
    if isinstance(audio_waveform_np, torch.Tensor):
        waveform = audio_waveform_np
    elif isinstance(audio_waveform_np, np.ndarray):
        waveform = torch.from_numpy(audio_waveform_np)
    else:
        waveform = torch.from_numpy(np.asarray(audio_waveform_np, dtype=np.float32))

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # add the batch dimension
    waveform = waveform.to(torch.device(device))
    if waveform.dtype not in (torch.float32, torch.float16):
        waveform = waveform.float()

    with torch.no_grad():
        age_out, sex_logits = model(waveform)

    # The age output is scaled by 100. This presumably assumes a [0, 1]
    # regression head (WavLMWrapper with apply_reg=True); with the 7-logit head
    # the value is not a meaningful age.
    age_pred = age_out.detach().cpu().numpy().flatten() * 100.0

    # Use the first batch item, matching age_pred[0] below. An argmax over the
    # flattened (batch, 2) tensor could give an index >= 2 for batched input.
    sex_prob = F.softmax(sex_logits, dim=1)
    sex_labels_es = ["Femenino", "Masculino"]
    sex_idx = int(sex_prob[0].argmax().item())
    sex_pred = sex_labels_es[sex_idx]

    try:
        age_value = int(round(float(age_pred[0])))
    except (IndexError, ValueError, OverflowError, TypeError):
        # Empty output, NaN or infinite age: report it as unknown.
        return "N/A", sex_pred, "desconocido"

    if age_value < 20:
        age_group = "joven (menor de 20)"
    elif age_value < 35:
        age_group = "adulto (20–35)"
    elif age_value < 60:
        age_group = "mediana edad (35–60)"
    else:
        age_group = "mayor (60+)"
    return str(age_value), sex_pred, age_group