import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch import tensor
from transformers import Wav2Vec2FeatureExtractor, WavLMModel
import transformers.models.wavlm.modeling_wavlm as wavlm
from huggingface_hub import PyTorchModelHubMixin
from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks


class RevGrad(Function):
    """Gradient reversal: identity in the forward pass, sign-flipped and alpha-scaled gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input_, alpha_):
        ctx.save_for_backward(input_, alpha_)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        _, alpha_ = ctx.saved_tensors
        grad_input = -grad_output * alpha_ if ctx.needs_input_grad[0] else None
        return grad_input, None


revgrad = RevGrad.apply


class RevGradLayer(nn.Module):
    def __init__(self, alpha=1.):
        super().__init__()
        self._alpha = tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self._alpha)
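

# --- Illustrative sketch, not part of the original module ---
# Minimal check of the gradient-reversal layer above: the forward pass is the
# identity, while the backward pass flips the gradient sign and scales it by alpha.
# The helper name below is hypothetical.
def _revgrad_sanity_check():
    x = torch.ones(3, requires_grad=True)
    RevGradLayer(alpha=0.5)(x).sum().backward()
    # A plain sum would give a gradient of +1 per element; reversal yields -0.5.
    assert torch.allclose(x.grad, torch.full_like(x, -0.5))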


class WavLMEncoderLayer(nn.Module):
    """Re-implementation of the stock WavLM post-LayerNorm encoder layer (used by wavlm-base-plus)."""

    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = wavlm.WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        attn_residual = hidden_states
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states, position_bias)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class WavLMEncoderLayerStableLayerNorm(nn.Module):
    """Re-implementation of the WavLM stable-layer-norm (pre-LayerNorm) encoder layer (used by wavlm-large)."""

    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = wavlm.WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        outputs = (hidden_states, position_bias)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
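

# --- Illustrative sketch, not part of the original module ---
# Standalone call of one re-implemented encoder layer, assuming the default
# transformers WavLMConfig and a transformers version that exposes
# wavlm.WavLMAttention / wavlm.WavLMFeedForward with the signatures used above.
# It only shows the (hidden_states, position_bias) tuple these layers return,
# mirroring the stock implementation; the helper name is hypothetical.
def _encoder_layer_shape_check():
    from transformers import WavLMConfig
    cfg = WavLMConfig()
    layer = WavLMEncoderLayerStableLayerNorm(0, cfg, has_relative_position_bias=True)
    with torch.no_grad():
        hidden, position_bias = layer(torch.randn(1, 49, cfg.hidden_size))
    return hidden.shape, position_bias.shape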


class WavLMWrapper(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        pretrain_model="wavlm_large",
        hidden_dim=256,
        freeze_params=True,
        output_class_num=4,
        use_conv_output=True,
        apply_reg=False
    ):
        super().__init__()
        self.pretrain_model = pretrain_model
        self.use_conv_output = use_conv_output

        # Load backbone
        if self.pretrain_model == "wavlm":
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-base-plus",
                output_hidden_states=True,
            )
        elif self.pretrain_model == "wavlm_large":
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-large",
                output_hidden_states=True,
            )

        # Keep original encoder layers (no LoRA): rebuild them with the re-implemented
        # classes above, then reload the pretrained weights.
        state_dict = self.backbone_model.state_dict()
        self.model_config = self.backbone_model.config
        if self.pretrain_model == "wavlm":
            self.backbone_model.encoder.layers = nn.ModuleList(
                [WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0))
                 for i in range(self.model_config.num_hidden_layers)]
            )
        else:
            self.backbone_model.encoder.layers = nn.ModuleList(
                [WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0))
                 for i in range(self.model_config.num_hidden_layers)]
            )
        self.backbone_model.load_state_dict(state_dict, strict=False)

        # Freeze weights if requested
        if freeze_params:
            for p in self.backbone_model.parameters():
                p.requires_grad = False

        # Conv projection layers
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1)
        )

        # Learnable weights for the layer-wise weighted sum (+1 when the conv feature-extractor output is included)
        num_layers = self.model_config.num_hidden_layers + 1 if use_conv_output else self.model_config.num_hidden_layers
        self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)

        # Output heads: age as a sigmoid-scaled regression value or a 7-class distribution, plus a binary sex classifier
        if apply_reg:
            self.age_dist_layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
                nn.Sigmoid()
            )
        else:
            self.age_dist_layer = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 7)
            )
        self.sex_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, x, length=None, return_feature=False, pred="age_dist_sex"):
        # Feature extraction
        if self.pretrain_model == "wavlm_large":
            with torch.no_grad():
                signal = []
                if length is not None:
                    attention_mask = make_padding_masks(x, wav_len=length / length.max()).to(x.device)
                else:
                    attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
                for idx in range(len(x)):
                    input_vals = self.processor(x[idx], sampling_rate=16_000, return_tensors="pt", padding=True)
                    signal.append(input_vals["input_values"][0].to(x.device))
                signal = torch.stack(signal)

        # Convert sample lengths to encoder-frame lengths for the masked pooling below
        # (use the input device rather than assuming CUDA)
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu()).to(x.device)

        if self.pretrain_model == "wavlm":
            x = self.backbone_model(x, output_hidden_states=True).hidden_states
        else:
            x = self.backbone_model(signal, attention_mask=attention_mask, output_hidden_states=True).hidden_states

        # Weighted sum of layers
        stacked_feature = torch.stack(x, dim=0) if self.use_conv_output else torch.stack(x, dim=0)[1:]
        _, *origin_shape = stacked_feature.shape
        stacked_feature = stacked_feature.view(stacked_feature.shape[0], -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)

        # Conv projection
        features = self.model_seq(features.transpose(1, 2)).transpose(1, 2)

        # Pooling: masked mean over valid frames if lengths are given, plain mean otherwise
        if length is not None:
            mean = []
            for snt_id in range(features.shape[0]):
                actual_size = length[snt_id]
                mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
            features = torch.stack(mean)
        else:
            features = torch.mean(features, dim=1)

        # Predictions
        age_pred = self.age_dist_layer(features)
        sex_pred = self.sex_layer(features)
        if return_feature:
            return age_pred, sex_pred, features
        return age_pred, sex_pred

    # Hugging Face conv feature-extractor output length helper
    def get_feat_extract_output_lengths(self, input_length):
        def _conv_out_length(input_length, kernel_size, stride):
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length
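

# --- Illustrative sketch, not part of the original module ---
# Worked example of get_feat_extract_output_lengths, assuming the standard
# WavLM/wav2vec2 feature extractor (conv_kernel = (10, 3, 3, 3, 3, 2, 2),
# conv_stride = (5, 2, 2, 2, 2, 2, 2)): 1 s of 16 kHz audio (16000 samples)
# maps to 16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49 encoder frames,
# which is why `length` is converted before the frame-level pooling above.
# The helper name and its defaults are illustrative only.
def _expected_num_frames(num_samples, kernels=(10, 3, 3, 3, 3, 2, 2), strides=(5, 2, 2, 2, 2, 2, 2)):
    for k, s in zip(kernels, strides):
        num_samples = (num_samples - k) // s + 1
    return num_samples  # _expected_num_frames(16000) == 49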


def age_gender(audio_waveform_np, model, device):
    # Convert the input to a (batch, samples) float tensor on the target device
    if isinstance(audio_waveform_np, np.ndarray):
        tensor = torch.from_numpy(audio_waveform_np)
    elif isinstance(audio_waveform_np, torch.Tensor):
        tensor = audio_waveform_np
    else:
        raise TypeError("audio_waveform_np must be a numpy array or a torch.Tensor")
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    tensor = tensor.to(torch.device(device))
    if tensor.dtype not in (torch.float32, torch.float16):
        tensor = tensor.float()

    with torch.no_grad():
        wavlm_outputs, wavlm_sex_outputs = model(tensor)

    # Age head is assumed to be the regression variant (apply_reg=True): a sigmoid output in [0, 1] scaled to years
    age_pred = wavlm_outputs.detach().cpu().numpy().flatten() * 100.0
    sex_prob = F.softmax(wavlm_sex_outputs, dim=1)
    sex_labels_es = ["Femenino", "Masculino"]  # Spanish UI labels: "Female", "Male"
    sex_idx = int(torch.argmax(sex_prob).detach().cpu().item())
    sex_pred = sex_labels_es[sex_idx]

    # Spanish age-group labels: young (<20), adult (20-35), middle-aged (35-60), older (60+)
    try:
        age_value = int(round(float(age_pred[0])))
        if age_value < 20:
            age_group = "joven (menor de 20)"
        elif age_value < 35:
            age_group = "adulto (20–35)"
        elif age_value < 60:
            age_group = "mediana edad (35–60)"
        else:
            age_group = "mayor (60+)"
    except Exception:
        age_value = None
        age_group = "desconocido"  # "unknown"
    return str(age_value) if age_value is not None else "N/A", sex_pred, age_group
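

# --- Illustrative usage sketch, not part of the original module ---
# Assumes a checkpoint compatible with WavLMWrapper is available on the Hugging Face Hub;
# "<org>/<wavlm-age-sex-checkpoint>" is a placeholder, not a real repo id, and the
# random waveform merely stands in for 1 s of 16 kHz speech.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = WavLMWrapper.from_pretrained("<org>/<wavlm-age-sex-checkpoint>")  # placeholder repo id
    model = model.to(device).eval()
    waveform = np.random.randn(16000).astype(np.float32)
    age, sex, age_group = age_gender(waveform, model, device)
    print(age, sex, age_group)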