Spaces:
Running
on
Zero
Running
on
Zero
| # pylint: disable=R0901 | |
| # src/models/wav2vec.py | |
| """ | |
| This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding. | |
| It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities | |
| such as feature extraction and encoding. | |
| Classes: | |
| Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding. | |
| Functions: | |
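| linear_interpolation: Interpolates the features to a given sequence length. | |
| linear_interpolation_fps: Interpolates the features from an input frame rate to an output frame rate. | |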
| """ | |
| import torch.nn.functional as F | |
| from transformers import Wav2Vec2Model | |
| from transformers.modeling_outputs import BaseModelOutput | |
class Wav2VecModel(Wav2Vec2Model):
    """
    Wav2Vec2Model subclass whose convolutional features are resampled to a caller-chosen length.

    The feature-extractor output is linearly interpolated to ``seq_len`` time steps before
    it is projected and passed through the transformer encoder.

    Methods:
        forward(input_values, seq_len, ...):
            Runs ``feature_extract`` followed by ``encode``.
        feature_extract(input_values, seq_len):
            Runs the feature extractor and interpolates its output to ``seq_len`` steps.
        encode(extract_features, ...):
            Projects, optionally masks, and encodes already-extracted features.
    """

    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Forward pass: extract features from the waveform, resample them, then encode them.

        Args:
            input_values: The input values (waveform) passed to the feature extractor.
            seq_len: Number of time steps the extracted features are interpolated to.
            attention_mask: Optional attention mask; see ``encode`` for how it is reduced.
            mask_time_indices: Optional mask indices forwarded to ``_mask_hidden_states``.
            output_attentions: Forwarded unchanged to the encoder.
            output_hidden_states: Falls back to ``config.output_hidden_states`` when None.
            return_dict: Falls back to ``config.use_return_dict`` when None.

        Returns:
            A ``BaseModelOutput`` when ``return_dict`` resolves to true, otherwise a tuple
            whose first element is the last hidden state.
        """
        # forward is exactly feature_extract + encode; delegating keeps one copy of the pipeline.
        extract_features = self.feature_extract(input_values, seq_len)
        return self.encode(
            extract_features,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        """
        Run the feature extractor and resample its output to ``seq_len`` time steps.

        Args:
            input_values (torch.Tensor): The input values passed to the feature extractor.
            seq_len (int): Target number of time steps; used as the interpolation size.

        Returns:
            torch.Tensor: The interpolated features, with dims 1 and 2 of the extractor
            output swapped (presumably [batch, seq_len, channels] - confirm against callers).
        """
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        return linear_interpolation(extract_features, seq_len=seq_len)

    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Project, optionally mask, and encode already-extracted features.

        Args:
            extract_features (torch.Tensor): Features as returned by ``feature_extract``.
            attention_mask (torch.Tensor, optional): Mask that is reduced with
                ``_get_feature_vector_attention_mask`` before use.
            mask_time_indices (torch.Tensor, optional): Forwarded to ``_mask_hidden_states``.
            output_attentions (bool, optional): Forwarded unchanged to the encoder.
            output_hidden_states (bool, optional): Falls back to
                ``config.output_hidden_states`` when None.
            return_dict (bool, optional): Falls back to ``config.use_return_dict`` when None.

        Returns:
            A ``BaseModelOutput`` when ``return_dict`` resolves to true, otherwise a tuple
            ``(last_hidden_state, *encoder_outputs[1:])``.
        """
        # NOTE(review): this mutates the shared config on every call, while the
        # `output_attentions` argument itself is not resolved against the config and
        # reaches the encoder as given (possibly None). Kept as-is; confirm the intent.
        self.config.output_attentions = True

        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            # NOTE(review): the target length is taken from the interpolated features;
            # verify the inherited helper yields a mask consistent with that length.
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]
        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
def linear_interpolation(features, seq_len):
    """
    Linearly resample features to ``seq_len`` steps along dim 1.

    Dims 1 and 2 are swapped so that the interpolation runs over the last
    dimension, then swapped back before returning.

    Args:
        features (torch.Tensor): The extracted features to be interpolated
            (3-D, with the axis to resample on dim 1).
        seq_len (int): The number of steps in the output; passed to
            ``F.interpolate`` as ``size``.

    Returns:
        torch.Tensor: The interpolated features, with dim 1 of length ``seq_len``.
    """
    channels_first = features.transpose(1, 2)
    resampled = F.interpolate(
        channels_first,
        size=seq_len,
        mode='linear',
        align_corners=True,
    )
    return resampled.transpose(1, 2)
def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
    """
    Linearly resample features from one frame rate to another along dim 1.

    Args:
        features (torch.Tensor): 3-D features with the time axis on dim 1.
        input_fps (float): Frame rate of ``features``.
        output_fps (float): Target frame rate.
        output_len (int, optional): Explicit number of output frames. When given,
            it is used directly and the frame rates are not consulted.

    Returns:
        torch.Tensor: The interpolated features, with dim 1 of length ``output_len``
        (or ``floor(T * output_fps / input_fps)`` when ``output_len`` is None).
    """
    features = features.transpose(1, 2)  # [B, C, T]
    if output_len is None:
        # Multiply before dividing: (T / input_fps) * output_fps can land just below
        # an integer (e.g. 29 / 50 * 50 == 28.999999999999996) and int() would then
        # drop a frame.
        output_len = int(features.shape[2] * output_fps / float(input_fps))
    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)