from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from models.scnet_unofficial.utils import create_intervals
class Downsample(nn.Module):
    """
    Downsamples the frequency axis of a spectrogram-like tensor with a strided 1x1 convolution.

    Args:
    - input_dim (int): Number of input channels.
    - output_dim (int): Number of output channels.
    - stride (int): Stride applied along the frequency axis.

    Shapes:
    - Input: (B, C_in, F, T) where
        B is batch size,
        C_in is the number of input channels,
        F is the frequency dimension,
        T is the time dimension.
    - Output: (B, C_out, F // stride, T) where
        B is batch size,
        C_out is the number of output channels,
        F // stride is the downsampled frequency dimension.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        stride: int,
    ):
        """
        Builds the strided 1x1 convolution used for downsampling.
        """
        super().__init__()
        # Kernel size 1: resolution is reduced only by the stride on the
        # frequency axis; the time axis keeps stride 1.
        self.conv = nn.Conv2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=(stride, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the strided convolution.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).

        Returns:
        - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T).
        """
        return self.conv(x)
class ConvolutionModule(nn.Module):
    """
    Conformer-style convolution block applied along the time axis, with a residual connection.

    Args:
    - input_dim (int): Dimensionality of the input features.
    - hidden_dim (int): Dimensionality of the hidden features.
    - kernel_sizes (List[int]): Kernel sizes of the three convolutions.
    - bias (bool, optional): If True, the convolutions use a learnable bias. Default is False.

    Shapes:
    - Input: (B, T, D) where
        B is batch size,
        T is sequence length,
        D is input dimensionality.
    - Output: (B, T, D) where
        B is batch size,
        T is sequence length,
        D is input dimensionality.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        kernel_sizes: List[int],
        bias: bool = False,
    ) -> None:
        """
        Builds the normalization/convolution stack.
        """
        super().__init__()
        # "same" padding for each (odd) kernel size.
        pads = [(k - 1) // 2 for k in kernel_sizes]
        layers = [
            nn.GroupNorm(num_groups=1, num_channels=input_dim),
            # Produces 2 * hidden_dim channels so GLU can gate one half
            # with a sigmoid of the other half.
            nn.Conv1d(
                input_dim,
                2 * hidden_dim,
                kernel_sizes[0],
                stride=1,
                padding=pads[0],
                bias=bias,
            ),
            nn.GLU(dim=1),
            # Depthwise convolution: one group per channel.
            nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_sizes[1],
                stride=1,
                padding=pads[1],
                groups=hidden_dim,
                bias=bias,
            ),
            nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
            nn.SiLU(),
            # Project back to the input dimensionality for the residual sum.
            nn.Conv1d(
                hidden_dim,
                input_dim,
                kernel_sizes[2],
                stride=1,
                padding=pads[2],
                bias=bias,
            ),
        ]
        self.sequential = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the convolution stack with a residual connection.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, T, D).

        Returns:
        - torch.Tensor: Output tensor of shape (B, T, D).
        """
        # Conv1d expects channels-first, so work in (B, D, T).
        y = x.transpose(1, 2)
        y = y + self.sequential(y)
        return y.transpose(1, 2)
class SDLayer(nn.Module):
    """
    Subband decomposition layer: slices out one frequency band, downsamples it,
    and refines it with a stack of convolutional modules.

    Args:
    - subband_interval (Tuple[float, float]): Fractional (start, end) of the frequency axis to keep.
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels after downsampling.
    - downsample_stride (int): Stride value for the downsampling operation.
    - n_conv_modules (int): Number of convolutional modules.
    - kernel_sizes (List[int]): Kernel sizes for the convolutional layers.
    - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True.

    Shapes:
    - Input: (B, Fi, T, Ci) where
        B is batch size,
        Fi is the number of input subbands,
        T is sequence length, and
        Ci is the number of input channels.
    - Output: (B, Fi+1, T, Ci+1) where
        B is batch size,
        Fi+1 is the number of output subbands,
        T is sequence length,
        Ci+1 is the number of output channels.
    """

    def __init__(
        self,
        subband_interval: Tuple[float, float],
        input_dim: int,
        output_dim: int,
        downsample_stride: int,
        n_conv_modules: int,
        kernel_sizes: List[int],
        bias: bool = True,
    ):
        """
        Builds the downsampler, activation, and convolutional-module stack for one subband.
        """
        super().__init__()
        self.subband_interval = subband_interval
        self.downsample = Downsample(input_dim, output_dim, downsample_stride)
        self.activation = nn.GELU()
        self.conv_modules = nn.Sequential(
            *(
                ConvolutionModule(
                    input_dim=output_dim,
                    hidden_dim=output_dim // 4,
                    kernel_sizes=kernel_sizes,
                    bias=bias,
                )
                for _ in range(n_conv_modules)
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SDLayer.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).

        Returns:
        - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1).
        """
        n_freq = x.shape[1]
        lo, hi = self.subband_interval
        # Slice out this layer's fractional frequency band.
        band = x[:, int(lo * n_freq) : int(hi * n_freq)]
        # (B, F, T, C) -> (B, C, F, T) for the 2D convolution.
        band = self.downsample(band.permute(0, 3, 1, 2))
        band = self.activation(band).permute(0, 2, 3, 1)
        batch, n_sub, n_time, n_chan = band.shape
        # Fold the subband axis into the batch so the conv modules see (B*F, T, C).
        flat = self.conv_modules(band.reshape(batch * n_sub, n_time, n_chan))
        return flat.reshape(batch, n_sub, n_time, n_chan)
class SDBlock(nn.Module):
    """
    SDBlock class implements a block with subband decomposition layers and global convolution.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands; must sum to 1.
    - downsample_strides (List[int]): List of stride values for downsampling in each subband layer.
    - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer.
    - kernel_sizes (Optional[List[int]]): Kernel sizes for the convolutional layers. Defaults to [3, 3, 1].

    Shapes:
    - Input: (B, Fi, T, Ci) where
        B is batch size,
        Fi is the number of input subbands,
        T is sequence length,
        Ci is the number of input channels.
    - Output: (B, Fi+1, T, Ci+1) where
        B is batch size,
        Fi+1 is the number of output subbands,
        T is sequence length,
        Ci+1 is the number of output channels.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        bandsplit_ratios: List[float],
        downsample_strides: List[int],
        n_conv_modules: List[int],
        kernel_sizes: Optional[List[int]] = None,
    ):
        """
        Initializes SDBlock with input dimension, output dimension, band split ratios,
        downsample strides, number of convolutional modules, and kernel sizes.
        """
        super().__init__()
        if kernel_sizes is None:
            kernel_sizes = [3, 3, 1]
        # Compare with a tolerance: float ratios such as [0.3, 0.3, 0.4] sum to
        # 0.9999999999999999, so an exact `== 1` check rejects valid inputs.
        assert abs(sum(bandsplit_ratios) - 1.0) < 1e-6, "The split ratios must sum up to 1."
        subband_intervals = create_intervals(bandsplit_ratios)
        # One SDLayer per subband, each with its own stride and module count.
        self.sd_layers = nn.ModuleList(
            SDLayer(
                input_dim=input_dim,
                output_dim=output_dim,
                subband_interval=sbi,
                downsample_stride=dss,
                n_conv_modules=ncm,
                kernel_sizes=kernel_sizes,
            )
            for sbi, dss, ncm in zip(
                subband_intervals, downsample_strides, n_conv_modules
            )
        )
        # 1x1 convolution mixing channels across the concatenated subbands.
        self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Performs forward pass through the SDBlock.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).

        Returns:
        - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor,
          both of shape (B, Fi+1, T, Ci+1).
        """
        # Each layer processes its own band; concatenate along the frequency axis.
        x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1)
        # Channels-first for Conv2d, then back to (B, F, T, C).
        x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return x, x_skip