Spaces:
Running
on
Zero
Running
on
Zero
| """SpecAugment module.""" | |
| from typing import Optional | |
| from typing import Sequence | |
| from typing import Union | |
| from funasr_detach.models.specaug.mask_along_axis import MaskAlongAxis | |
| from funasr_detach.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth | |
| from funasr_detach.models.specaug.mask_along_axis import MaskAlongAxisLFR | |
| from funasr_detach.models.specaug.time_warp import TimeWarp | |
| from funasr_detach.register import tables | |
| import torch.nn as nn | |
class SpecAug(nn.Module):
    """SpecAugment pipeline: optional time warp, frequency mask and time mask.

    Reference:
        Daniel S. Park et al.
        "SpecAugment: A Simple Data
        Augmentation Method for Automatic Speech Recognition"

    .. warning::
        When using cuda mode, time_warp doesn't have reproducibility
        due to `torch.nn.functional.interpolate`.

    Args:
        apply_time_warp: Build a ``TimeWarp`` stage when true.
        time_warp_window: Forwarded to ``TimeWarp`` as ``window``.
        time_warp_mode: Forwarded to ``TimeWarp`` as ``mode``.
        apply_freq_mask: Build a ``MaskAlongAxis(dim="freq")`` stage when true.
        freq_mask_width_range: Forwarded to the frequency mask as
            ``mask_width_range``.
        num_freq_mask: Forwarded to the frequency mask as ``num_mask``.
        apply_time_mask: Build a time-mask stage when true.
        time_mask_width_range: If given, the time mask is a
            ``MaskAlongAxis(dim="time")`` with this ``mask_width_range``.
        time_mask_width_ratio_range: If given, the time mask is a
            ``MaskAlongAxisVariableMaxWidth(dim="time")`` with this
            ``mask_width_ratio_range``.
        num_time_mask: Forwarded to the time mask as ``num_mask``.

    Raises:
        ValueError: If all three stages are disabled, if both time-mask
            range arguments are supplied while the time mask is enabled, or
            if the time mask is enabled and neither range argument is given.
    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
    ):
        # At least one augmentation stage has to be switched on.
        if not (apply_time_warp or apply_time_mask or apply_freq_mask):
            raise ValueError(
                "Either one of time_warp, time_mask, or freq_mask should be applied"
            )

        has_abs_width = time_mask_width_range is not None
        has_ratio_width = time_mask_width_ratio_range is not None

        # The two ways of sizing the time mask are mutually exclusive.
        if apply_time_mask and has_abs_width and has_ratio_width:
            raise ValueError(
                'Either one of "time_mask_width_range" or '
                '"time_mask_width_ratio_range" can be used'
            )

        super().__init__()

        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        # Stages are assigned in the order time_warp -> freq_mask -> time_mask,
        # which is also the order in which forward() applies them.
        self.time_warp = (
            TimeWarp(window=time_warp_window, mode=time_warp_mode)
            if apply_time_warp
            else None
        )

        self.freq_mask = (
            MaskAlongAxis(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
            )
            if apply_freq_mask
            else None
        )

        if not apply_time_mask:
            self.time_mask = None
        elif has_abs_width:
            self.time_mask = MaskAlongAxis(
                dim="time",
                mask_width_range=time_mask_width_range,
                num_mask=num_time_mask,
            )
        elif has_ratio_width:
            self.time_mask = MaskAlongAxisVariableMaxWidth(
                dim="time",
                mask_width_ratio_range=time_mask_width_ratio_range,
                num_mask=num_time_mask,
            )
        else:
            # Time masking was requested but no width specification was given.
            raise ValueError(
                'Either one of "time_mask_width_range" or '
                '"time_mask_width_ratio_range" should be used.'
            )

    def forward(self, x, x_lengths=None):
        """Run every enabled stage on ``(x, x_lengths)`` and return the result.

        Stages are applied in the order time warp, frequency mask, time mask;
        a stage set to ``None`` is skipped. With no stage enabled the inputs
        are returned unchanged.
        """
        for stage_name in ("time_warp", "freq_mask", "time_mask"):
            stage = getattr(self, stage_name)
            if stage is None:
                continue
            x, x_lengths = stage(x, x_lengths)
        return x, x_lengths
class SpecAugLFR(nn.Module):
    """SpecAug variant whose fixed-width masks are ``MaskAlongAxisLFR`` stages.

    The constructor mirrors a plain SpecAug pipeline (optional time warp,
    frequency mask, time mask) but builds the frequency mask, and the
    fixed-width time mask, from ``MaskAlongAxisLFR``, passing
    ``lfr_rate + 1`` as its ``lfr_rate`` argument.

    Args:
        apply_time_warp: Build a ``TimeWarp`` stage when true.
        time_warp_window: Forwarded to ``TimeWarp`` as ``window``.
        time_warp_mode: Forwarded to ``TimeWarp`` as ``mode``.
        apply_freq_mask: Build a ``MaskAlongAxisLFR(dim="freq")`` stage
            when true.
        freq_mask_width_range: Forwarded to the frequency mask as
            ``mask_width_range``.
        num_freq_mask: Forwarded to the frequency mask as ``num_mask``.
        lfr_rate: Low frame rate setting; ``lfr_rate + 1`` is what the
            ``MaskAlongAxisLFR`` stages receive.
        apply_time_mask: Build a time-mask stage when true.
        time_mask_width_range: If given, the time mask is a
            ``MaskAlongAxisLFR(dim="time")`` with this ``mask_width_range``.
        time_mask_width_ratio_range: If given, the time mask is a
            ``MaskAlongAxisVariableMaxWidth(dim="time")`` with this
            ``mask_width_ratio_range``; ``lfr_rate`` is not passed to it.
        num_time_mask: Forwarded to the time mask as ``num_mask``.

    Raises:
        ValueError: If all three stages are disabled, if both time-mask
            range arguments are supplied while the time mask is enabled, or
            if the time mask is enabled and neither range argument is given.
    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        lfr_rate: int = 0,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
    ):
        # At least one augmentation stage has to be switched on.
        if not (apply_time_warp or apply_time_mask or apply_freq_mask):
            raise ValueError(
                "Either one of time_warp, time_mask, or freq_mask should be applied"
            )

        has_abs_width = time_mask_width_range is not None
        has_ratio_width = time_mask_width_ratio_range is not None

        # The two ways of sizing the time mask are mutually exclusive.
        if apply_time_mask and has_abs_width and has_ratio_width:
            raise ValueError(
                'Either one of "time_mask_width_range" or '
                '"time_mask_width_ratio_range" can be used'
            )

        super().__init__()

        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        # Stages are assigned in the order time_warp -> freq_mask -> time_mask,
        # which is also the order in which forward() applies them.
        self.time_warp = (
            TimeWarp(window=time_warp_window, mode=time_warp_mode)
            if apply_time_warp
            else None
        )

        self.freq_mask = (
            MaskAlongAxisLFR(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
                lfr_rate=lfr_rate + 1,
            )
            if apply_freq_mask
            else None
        )

        if not apply_time_mask:
            self.time_mask = None
        elif has_abs_width:
            self.time_mask = MaskAlongAxisLFR(
                dim="time",
                mask_width_range=time_mask_width_range,
                num_mask=num_time_mask,
                lfr_rate=lfr_rate + 1,
            )
        elif has_ratio_width:
            self.time_mask = MaskAlongAxisVariableMaxWidth(
                dim="time",
                mask_width_ratio_range=time_mask_width_ratio_range,
                num_mask=num_time_mask,
            )
        else:
            # Time masking was requested but no width specification was given.
            raise ValueError(
                'Either one of "time_mask_width_range" or '
                '"time_mask_width_ratio_range" should be used.'
            )

    def forward(self, x, x_lengths=None):
        """Run every enabled stage on ``(x, x_lengths)`` and return the result.

        Stages are applied in the order time warp, frequency mask, time mask;
        a stage set to ``None`` is skipped. With no stage enabled the inputs
        are returned unchanged.
        """
        for stage_name in ("time_warp", "freq_mask", "time_mask"):
            stage = getattr(self, stage_name)
            if stage is None:
                continue
            x, x_lengths = stage(x, x_lengths)
        return x, x_lengths