Spaces:
Running
on
Zero
Running
on
Zero
| """Time warp module.""" | |
| import torch | |
| from funasr_detach.models.transformer.utils.nets_utils import pad_list | |
DEFAULT_TIME_WARP_MODE = "bicubic"

# Interpolation modes that accept ``align_corners``. Passing ``align_corners``
# together with any other mode (e.g. "nearest", "area") makes
# ``torch.nn.functional.interpolate`` raise ValueError.
_ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")


def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Time warping using torch.interpolate.

    A random pivot frame ``center`` is drawn from ``[window, T - window)``.
    It is moved to a random position ``warped`` in
    ``[center - window + 1, center + window]``. The frames before the pivot
    are resized to ``warped`` frames and the frames after it to
    ``T - warped`` frames, so the total length ``T`` is preserved. A single
    warp is drawn per call and applied to every sample in the batch.

    Args:
        x: (Batch, Time, Freq) or (Batch, Channel, Time, Freq).
        window: time warp parameter, i.e. the maximum displacement of the
            pivot frame. Inputs with ``Time <= 2 * window`` are returned
            unchanged, as is any input when ``window <= 0``.
        mode: Interpolate mode passed to ``torch.nn.functional.interpolate``
            (e.g. "bicubic", "bilinear", "nearest").

    Returns:
        Tensor with the same shape as ``x``. If ``x`` does not require grad,
        the warped frames are written into ``x`` in place and the returned
        tensor is a view of the input. Otherwise a new tensor is built.
    """
    org_size = x.size()
    if x.dim() == 3:
        # bicubic supports 4D or more dimension tensor:
        # (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
        x = x[:, None]

    t = x.shape[2]
    # There must be room for a pivot at least ``window`` frames from both
    # edges. A non-positive window would give the second torch.randint below
    # an empty or inverted range and raise, so treat it as "no warp".
    if window <= 0 or t - window <= window:
        return x.view(*org_size)

    center = torch.randint(window, t - window, (1,))[0]
    warped = torch.randint(center - window, center + window, (1,))[0] + 1

    align_corners = False if mode in _ALIGN_CORNERS_MODES else None
    # left: (Batch, Channel, warped, Freq)
    # right: (Batch, Channel, time - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center],
        (warped, x.shape[3]),
        mode=mode,
        align_corners=align_corners,
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:],
        (t - warped, x.shape[3]),
        mode=mode,
        align_corners=align_corners,
    )

    if x.requires_grad:
        # Avoid in-place writes on a tensor tracked by autograd; build a new one.
        x = torch.cat([left, right], dim=-2)
    else:
        # Both halves are already materialised from the original frames, so
        # overwriting x in place is safe and saves an allocation.
        x[:, :, :warped] = left
        x[:, :, warped:] = right
    return x.view(*org_size)
class TimeWarp(torch.nn.Module):
    """``torch.nn.Module`` wrapper around :func:`time_warp`.

    Args:
        window: time warp parameter forwarded to ``time_warp``.
        mode: Interpolate mode forwarded to ``time_warp``.
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window = window
        self.mode = mode

    def extra_repr(self):
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Time-warp a batch of features.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,) valid length of each sample, or None.

        Returns:
            Tuple ``(warped, x_lengths)``; ``x_lengths`` is returned as given.
        """
        same_length = x_lengths is None or all(
            length == x_lengths[0] for length in x_lengths
        )
        if same_length:
            # One call warps the whole batch, so every sample gets the same warp.
            return time_warp(x, window=self.window, mode=self.mode), x_lengths

        # NOTE(kamo): no batched variant exists for ragged lengths, so each
        # sample is warped over its own valid frames and the results are
        # re-padded with pad_list (pad value 0.0).
        per_sample = [
            time_warp(
                x[idx][None, : x_lengths[idx]],
                window=self.window,
                mode=self.mode,
            )[0]
            for idx in range(x.size(0))
        ]
        return pad_list(per_sample, 0.0), x_lengths