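"""Wav2Vec2-based speech emotion classifier.

Defines a ClassificationHead and an EmotionModel that pools wav2vec2 hidden
states (optionally as a learned weighted sum over all layers) and supports a
class-weighted cross-entropy loss derived from per-class sample counts.
"""
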
from typing import Optional

import torch
import torch.nn as nn
from transformers.activations import get_activation
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

# Index of the hidden-states tuple in the wav2vec2 outputs
# (last_hidden_state, extract_features, hidden_states, ...).
_HIDDEN_STATES_START_POSITION = 2


class ClassificationHead(nn.Module):
    """Head that maps pooled wav2vec2 features to class logits."""

    def __init__(self, config):
        super().__init__()
        print(f"classifier_proj_size: {config.classifier_proj_size}")
        self.dense = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.layer_norm = nn.LayerNorm(config.classifier_proj_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.classifier_proj_size, config.num_labels)
        print(f"Head activation: {config.head_activation}")
        self.activation = get_activation(config.head_activation)

    def forward(self, features, **kwargs):
        x = features
        x = self.dense(x)
        x = self.layer_norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class EmotionModel(Wav2Vec2PreTrainedModel):
    """Speech emotion classifier built on a wav2vec2 backbone."""

    def __init__(self, config, counts: Optional[dict[int, int]] = None):
        super().__init__(config)

        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ClassificationHead(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.init_weights()

        # If per-class sample counts are given, derive inverse-frequency class weights.
        if counts is not None:
            print(f"Computing class weights from counts: {counts}")
            counts_list = [counts[i] for i in range(config.num_labels)]
            counts_tensor = torch.tensor(counts_list, dtype=torch.float)
            total_samples = counts_tensor.sum()
            class_weights = total_samples / (config.num_labels * counts_tensor)
            # Normalize so the weights sum to num_labels (optional).
            class_weights = class_weights / class_weights.sum() * config.num_labels
            self.class_weights = class_weights
        else:
            self.class_weights = None  # no counts given, so no class weighting

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        output_hidden_states = (
            True if self.config.use_weighted_layer_sum else output_hidden_states
        )

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Combine all hidden states with learned, softmax-normalized layer weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # Zero out padded frames before average pooling over time.
            padding_mask = self._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(
                -1, 1
            )

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Apply class weights in CrossEntropyLoss; weight=None means no weighting.
            weight = (
                self.class_weights.to(logits.device)
                if self.class_weights is not None
                else None
            )
            loss_fct = nn.CrossEntropyLoss(weight=weight)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def freeze_base_model(self):
        r"""Freeze the wav2vec2 backbone so only the classification head is trained."""
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def freeze_feature_encoder(self):
        r"""Freeze the convolutional feature encoder."""
        self.wav2vec2.freeze_feature_encoder()
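
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): builds a
# randomly initialized model from a default Wav2Vec2Config and runs one dummy
# utterance through it. The label count, head activation, and per-class
# counts below are assumptions chosen only for this example.
if __name__ == "__main__":
    from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor

    config = Wav2Vec2Config(
        num_labels=4,                 # assumed number of emotion classes
        classifier_proj_size=256,
        final_dropout=0.1,
        use_weighted_layer_sum=True,
        head_activation="tanh",       # custom field read by ClassificationHead
    )

    # Hypothetical per-class sample counts used to derive class weights.
    counts = {0: 1000, 1: 500, 2: 250, 3: 250}
    model = EmotionModel(config, counts=counts)
    model.eval()

    # One second of dummy 16 kHz audio through the default feature extractor.
    feature_extractor = Wav2Vec2FeatureExtractor()
    inputs = feature_extractor(
        torch.randn(16000).numpy(), sampling_rate=16000, return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)  # torch.Size([1, 4])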