Spaces:
Running
on
Zero
Running
on
Zero
from typing import Optional, Tuple

import torch
from pytorch_wpe import wpe_one_iteration
from torch_complex.tensor import ComplexTensor

from funasr_detach.frontends.utils.mask_estimator import MaskEstimator
from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
class DNN_WPE(torch.nn.Module):
    """WPE front-end whose power estimate can be weighted by a DNN mask.

    On each iteration the module computes the per-bin power of the current
    estimate, optionally multiplies it by a mask produced by
    ``MaskEstimator`` (first iteration only), averages it over channels and
    hands it to ``wpe_one_iteration`` together with the observed signal.

    Args:
        wtype: Network type forwarded to ``MaskEstimator``.
        widim: Input dimension forwarded to ``MaskEstimator``.
        wlayers: Number of layers forwarded to ``MaskEstimator``.
        wunits: Number of units forwarded to ``MaskEstimator``.
        wprojs: Number of projection units forwarded to ``MaskEstimator``.
        dropout_rate: Dropout rate forwarded to ``MaskEstimator``.
        taps: ``taps`` argument forwarded to ``wpe_one_iteration``.
        delay: ``delay`` argument forwarded to ``wpe_one_iteration``.
        use_dnn_mask: If True, build a ``MaskEstimator`` and weight the
            power with its mask on the first iteration.
        iterations: Number of ``wpe_one_iteration`` passes in ``forward``.
        normalization: If True, divide the mask by its sum along the time
            axis before applying it.
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay

        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        # Always forwarded as ``inverse_power=True`` to wpe_one_iteration.
        self.inverse_power = True

        if self.use_dnn_mask:
            # A single mask (nmask=1) is estimated.
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        """Run the configured number of WPE iterations on ``data``.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)

        Returns:
            enhanced: (B, T, C, F), with frames beyond ``ilens`` set to 0
            ilens: (B,), returned unchanged
            mask: (B, T, C, F) mask from ``MaskEstimator`` (after the
                optional normalization), or None when ``use_dnn_mask`` is
                False or no iteration ran.
                NOTE(review): the annotation says ComplexTensor; the actual
                type is whatever ``MaskEstimator`` returns - confirm.
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        # The observed signal is the same for every iteration, so make it
        # contiguous once instead of once per pass.
        observed = data.contiguous()
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                (mask,), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            enhanced = wpe_one_iteration(
                observed,
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            # Zero the padded frames. Out-of-place on purpose: an in-place
            # fill could write through a view into the caller's tensor.
            enhanced = enhanced.masked_fill(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            # (B, F, C, T) -> (B, T, C, F)
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask