"""helper function for activation checkpointing"""

from typing import Union, Dict, Callable
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    offload_wrapper,
    CheckpointImpl,
)


"""cascade basic blocks"""

import math
import backoff
import random
import numpy as np
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F


"""ConformerEncoder Module"""

from typing import Optional, Tuple, List, Literal
import abc
import math
import numpy as np

import torch
from torch import nn, Tensor

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel


def validate_checkpointing_config(activation_checkpointing):
    """validate activation checkpointing configuration"""
    if isinstance(activation_checkpointing, str):
        assert activation_checkpointing in (
            "",
            "checkpoint",
            "offload",
        ), "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')."
    elif isinstance(activation_checkpointing, dict):
        assert activation_checkpointing.get("module", "transformer") in (
            "transformer",
            "attention",
        ), "module in activation_checkpointing has to be in ('transformer', 'attention')."
    else:
        raise ValueError("activation_checkpointing has to be a str or dict.")


def embedding_checkpoint_wrapper(
    activation_checkpointing: Union[str, Dict],
) -> Callable:
    """return encoder embedding activation checkpoint wrapper"""
    validate_checkpointing_config(activation_checkpointing)

    if isinstance(activation_checkpointing, str):
        if activation_checkpointing:
            if activation_checkpointing == "offload":
                return offload_wrapper
            return partial(checkpoint_wrapper)
        return lambda x: x

    if isinstance(activation_checkpointing, dict):
        enabled = activation_checkpointing.get("embed", False)
        if enabled:
            offloading = activation_checkpointing.get("offload", False)
            if offloading:
                return offload_wrapper
            impl = (
                CheckpointImpl.REENTRANT
                if activation_checkpointing.get("reentrant", False)
                else CheckpointImpl.NO_REENTRANT
            )
            return partial(checkpoint_wrapper, checkpoint_impl=impl)
        return lambda x: x
    raise ValueError("Invalid activation_checkpointing config")


def encoder_checkpoint_wrapper(
    activation_checkpointing: Union[str, Dict],
    layer_cls: type,
    idx: int = 0,
) -> Callable:
    """return encoder activation checkpoint wrapper"""
    validate_checkpointing_config(activation_checkpointing)

    if isinstance(activation_checkpointing, str):
        if activation_checkpointing:
            if activation_checkpointing == "offload":
                return offload_wrapper
            return partial(checkpoint_wrapper)
        return lambda x: x

    if isinstance(activation_checkpointing, dict):
        target_layer_cls = activation_checkpointing.get("module", "transformer")
        if target_layer_cls.lower() == "transformer":
            target_layer_cls = (
                "EncoderLayer",
                "ConformerEncoderLayer",
            )
        elif target_layer_cls.lower() == "attention":
            target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention")
        checkpointing_interval = activation_checkpointing.get("interval", 1)
        offloading = activation_checkpointing.get("offload", False)
        impl = (
            CheckpointImpl.REENTRANT
            if activation_checkpointing.get("reentrant", True)
            else CheckpointImpl.NO_REENTRANT
        )

        if idx % checkpointing_interval == 0 and layer_cls.__name__ in target_layer_cls:
            if offloading:
                return offload_wrapper
            return partial(checkpoint_wrapper, checkpoint_impl=impl)
        return lambda x: x

    raise ValueError("Invalid activation_checkpointing config")


def attn_checkpointing(activation_checkpointing: Union[str, Dict], i) -> Union[str, Dict]:
    """return activation checkpointing config for attention layer"""
    if isinstance(activation_checkpointing, str):
        return ""

    if isinstance(activation_checkpointing, dict):
        target_layer_cls = activation_checkpointing.get("module", "transformer")
        checkpointing_interval = activation_checkpointing.get("interval", 1)
        if target_layer_cls == "attention" and i % checkpointing_interval == 0:
            return activation_checkpointing
        return ""

    raise ValueError("Invalid activation_checkpointing config")
					
						
class Block(nn.Module):
    """Block abstract module"""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size


def get_activation(name="relu"):
    """Select an activation function by name

    Args:
        name: str
            activation function name,
            one of ["relu", "gelu", "swish", "sigmoid"],
            default "relu".
    """
    name = name.lower()
    if name == "relu":
        return nn.ReLU(inplace=True)
    if name == "gelu":
        return nn.GELU()
    if name == "swish":
        return Swish()
    if name == "sigmoid":
        return torch.nn.Sigmoid()
    return nn.Identity()


def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
    """
    This function is very important for the Transformer Transducer streaming mode.
    Args:
        x_len (int): sequence length
        chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48].
            It also supports adaptive chunk size [0,10,15,45]
        left_window (int): how many left chunks can be seen
        right_window (int): how many right chunks can be seen. It is used for chunk overlap model.
    Returns:
        mask (torch.Tensor): a mask tensor for streaming model
            Torch 1.0.1
            tensor([[1., 1., 0., 0.],
                    [0., 1., 1., 0.],
                    [0., 0., 1., 1.]])
            Torch 1.4.1
            tensor([[True, True, False, False],
                    [False, True, True, False],
                    [False, False, True, True]])
    """
    chunk_start_idx = torch.Tensor(chunk_start_idx).long()
    # prepend a chunk starting at 0 and append one ending at x_len
    start_pad = torch.nn.functional.pad(chunk_start_idx, (1, 0))
    end_pad = torch.nn.functional.pad(chunk_start_idx, (0, 1), value=x_len)
    seq_range = torch.arange(0, x_len).unsqueeze(-1)
    # chunk index of each frame
    idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
    boundary = end_pad[idx]
    seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
    idx_left = idx - left_window
    idx_left[idx_left < 0] = 0
    boundary_left = start_pad[idx_left]
    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
    idx_right = idx + right_window
    idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
    boundary_right = end_pad[idx_right]
    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
    return mask_left & mask_right
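

# Illustrative sketch (not part of the original module): a tiny chunked mask.
# With x_len=4, a single chunk boundary at index 2, and no extra left/right
# context, positions 0-1 may only attend to 0-1 and positions 2-3 to 2-3,
# i.e. the mask is block-diagonal over the chunks.
def _example_adaptive_enc_mask():
    mask = adaptive_enc_mask(4, [2], left_window=0, right_window=0)
    expected = torch.tensor(
        [
            [True, True, False, False],
            [True, True, False, False],
            [False, False, True, True],
            [False, False, True, True],
        ]
    )
    assert torch.equal(mask, expected)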
					
						
class Swish(nn.Module):
    """Implement Swish activation module.
    From https://arxiv.org/pdf/2005.03191.pdf
    """

    def __init__(self) -> None:
        super().__init__()
        self.act_fn = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Apply Swish function

        Args:
            x: torch.Tensor
                Input.
        """
        return x * self.act_fn(x)


class GLU(nn.Module):
    """Implement Gated Linear Unit (GLU) module"""

    def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
        super().__init__()
        self.dim = dim
        self.act_name = act_name.lower()

        if self.act_name == "relu":
            self.act_fn = nn.ReLU(inplace=True)
        elif self.act_name == "gelu":
            self.act_fn = nn.GELU()
        elif self.act_name == "swish":
            self.act_fn = Swish()
        elif self.act_name == "sigmoid":
            self.act_fn = nn.Sigmoid()
        else:
            self.act_fn = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """GLU forward
        Split the input in two halves along `dim` and gate the first half
        with the activated second half.

        Args:
            x: torch.Tensor
                Input.
        """
        half_x, gate = x.chunk(2, dim=self.dim)
        return half_x * self.act_fn(gate)
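

# Illustrative sketch (not part of the original module): GLU halves the size of
# the gated dimension. The sizes below are arbitrary assumptions.
def _example_glu_shapes():
    glu = GLU(dim=-1, act_name="swish")
    x = torch.randn(2, 5, 8)   # (batch, time, 2 * feature)
    y = glu(x)                 # first half multiplied by Swish(second half)
    assert y.shape == (2, 5, 4)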
					
						
class GLUPointWiseConv(nn.Module):
    """GLUPointWiseConv module
    used for conformer architecture,
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        output_dim: int
            output channel size.
        kernel_size: int
            kernel size
        glu_type: str, optional
            activation function, one of
            ["sigmoid", "relu", "gelu", "swish"],
            default "sigmoid".
        bias_in_glu: bool, optional
            use additive bias in glu
        causal: bool, optional
            if set to True, padding is set to kernel_size - 1 (the extra frames
            are trimmed by the caller), i.e., the convolution can't see future frames.
            default False.
    """

    def __init__(
        self, input_dim, output_dim, kernel_size, glu_type="sigmoid", bias_in_glu=True, causal=False
    ):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim
        self.bias_in_glu = bias_in_glu
        if causal:
            self.ext_pw_conv_1d = nn.Conv1d(
                input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1)
            )
        else:
            self.ext_pw_conv_1d = nn.Conv1d(
                input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1) // 2
            )

        if glu_type == "sigmoid":
            self.glu_act = nn.Sigmoid()
        elif glu_type == "relu":
            self.glu_act = nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = nn.GELU()
        elif glu_type == "swish":
            self.glu_act = Swish()
        else:
            raise ValueError(f"Unsupported activation type {glu_type}")

        if bias_in_glu:
            self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
            self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

    def forward(self, x):
        """
        Args:
            x: torch.Tensor
                input tensor
        """
        # (batch, time, channel) -> (batch, channel, time) for Conv1d
        x = x.permute([0, 2, 1])
        x = self.ext_pw_conv_1d(x)
        if self.glu_type == "bilinear":
            if self.bias_in_glu:
                x = (x[:, 0 : self.output_dim, :] + self.b1) * (
                    x[:, self.output_dim : self.output_dim * 2, :] + self.b2
                )
            else:
                x = (x[:, 0 : self.output_dim, :]) * (
                    x[:, self.output_dim : self.output_dim * 2, :]
                )
        else:
            if self.bias_in_glu:
                x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act(
                    x[:, self.output_dim : self.output_dim * 2, :] + self.b2
                )
            else:
                x = (x[:, 0 : self.output_dim, :]) * self.glu_act(
                    x[:, self.output_dim : self.output_dim * 2, :]
                )

        x = x.permute([0, 2, 1])
        return x
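

# Illustrative sketch (not part of the original module): the pointwise-conv GLU
# maps (batch, time, input_dim) to (batch, time, output_dim); the conv produces
# 2 * output_dim channels and the gate halves them. Sizes are assumptions.
def _example_glu_pointwise_conv_shapes():
    m = GLUPointWiseConv(input_dim=16, output_dim=32, kernel_size=3, glu_type="sigmoid")
    x = torch.randn(2, 10, 16)
    y = m(x)
    assert y.shape == (2, 10, 32)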
					
						
class DepthWiseSeperableConv1d(nn.Module):
    """DepthWiseSeperableConv1d module used in Convnet module
    for the conformer, for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        depthwise_seperable_out_channel: int
            if non-zero, used as the channel_out of the second (pointwise) conv1d layer.
            if 0, the second conv1d layer is skipped.
        kernel_size: int
            kernel_size
        depthwise_multiplier: int
            number of input_dim channels duplication. this value
            will be used to compute the hidden channels of the Conv1D.
        padding: int, optional
            padding for the conv1d,
            default: 0.
    """

    def __init__(
        self,
        input_dim,
        depthwise_seperable_out_channel,
        kernel_size,
        depthwise_multiplier,
        padding=0,
    ):
        super().__init__()

        self.dw_conv = nn.Conv1d(
            input_dim,
            input_dim * depthwise_multiplier,
            kernel_size,
            1,
            padding=padding,
            groups=input_dim,
        )

        if depthwise_seperable_out_channel != 0:
            self.pw_conv = nn.Conv1d(
                input_dim * depthwise_multiplier, depthwise_seperable_out_channel, 1, 1, 0
            )
        else:
            self.pw_conv = nn.Identity()
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

    def forward(self, x):
        """
        Args:
            x: torch.Tensor
                input tensor
        """
        x = self.dw_conv(x)
        if self.depthwise_seperable_out_channel != 0:
            x = self.pw_conv(x)
        return x
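

# Illustrative sketch (not part of the original module): the depthwise conv
# expands channels by depthwise_multiplier, then the pointwise conv projects to
# depthwise_seperable_out_channel. Input is channel-first (batch, channel, time).
# Sizes are assumptions.
def _example_depthwise_seperable_conv1d_shapes():
    m = DepthWiseSeperableConv1d(
        input_dim=16,
        depthwise_seperable_out_channel=24,
        kernel_size=3,
        depthwise_multiplier=2,
        padding=1,
    )
    x = torch.randn(2, 16, 10)
    y = m(x)
    assert y.shape == (2, 24, 10)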
					
						
class ConvModule(nn.Module):
    """ConvModule Module for the conformer block.
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        ext_pw_out_channel: int
            if > 0, ext_pw_out_channel is a dim channel size
            for the last pointwise conv after swish activation.
        depthwise_seperable_out_channel: int
            if non-zero, used as the channel_out of the second (pointwise) conv1d layer.
            if 0, the second conv1d layer is skipped.
        ext_pw_kernel_size: int
            kernel size of the conv pointwise of the conformer.
        kernel_size: int
            kernel size.
        depthwise_multiplier: int
            number of input_dim channels duplication. this value
            will be used to compute the hidden channels of the Conv1D.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, the convolution has no access
            to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation.
            default False
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
            by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
            by only the current chunk.
        chunk_size: int, optional
            chunk size for cnn. default 18
        activation: str, optional
            activation function used in ConvModule,
            default: "relu".
        glu_type: str, optional
            activation function used for the glu,
            default: "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
            before GLU.
        linear_glu_in_convm: bool, optional
            if set to True, use the GLULinear module,
            otherwise, use the GLUPointWiseConv module.
            default to False.
        export: bool, optional,
            if set to True, padding is equal to 0.  This is for inference,
            or onnx export.  Typically this is set by the export program or
            the decoder program, and it isn't present in your config file.
            default False
    """

    def __init__(
        self,
        input_dim,
        ext_pw_out_channel,
        depthwise_seperable_out_channel,
        ext_pw_kernel_size,
        kernel_size,
        depthwise_multiplier,
        dropout_rate,
        causal=False,
        batch_norm=False,
        chunk_se=0,
        chunk_size=18,
        activation="relu",
        glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        export=False,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.input_dim = input_dim
        self.ext_pw_out_channel = ext_pw_out_channel
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.glu_type = glu_type
        self.bias_in_glu = bias_in_glu
        self.linear_glu_in_convm = linear_glu_in_convm
        self.causal = causal

        self._add_ext_pw_layer()

        self.batch_norm = batch_norm
        self.kernel_size = kernel_size

        if batch_norm:
            self.bn_layer = nn.BatchNorm1d(input_dim)

        self.act = get_activation(activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.export = export

        if causal:
            if export:
                padding = 0
            else:
                padding = kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2

        self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
            input_dim,
            depthwise_seperable_out_channel,
            kernel_size,
            depthwise_multiplier,
            padding=padding,
        )

        if depthwise_seperable_out_channel != 0:
            if input_dim != depthwise_seperable_out_channel:
                self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim)
        else:
            if depthwise_multiplier != 1:
                self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim)

    def _add_ext_pw_layer(self):
        """
        This function is an extension of __init__ function
        and dedicated to the convolution module creation
        of the conformer.
        """
        self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = nn.Identity()
        self.squeeze_excitation = nn.Identity()
        self.apply_ln1 = self.fix_len1 = False

        if self.ext_pw_out_channel != 0:
            if self.causal:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1),
                )
                if self.ext_pw_kernel_size > 1:
                    self.fix_len1 = True
                else:
                    self.fix_len1 = False
            else:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1) // 2,
                )
                self.fix_len1 = False

            if self.linear_glu_in_convm:
                self.glu = GLULinear(
                    self.input_dim, self.ext_pw_out_channel, self.glu_type, self.bias_in_glu
                )
            else:
                self.glu = GLUPointWiseConv(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    self.glu_type,
                    self.bias_in_glu,
                    self.causal,
                )

            if self.input_dim != self.ext_pw_out_channel:
                self.apply_ln1 = True
                self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
            else:
                self.apply_ln1 = False
        else:
            self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
            self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x):
        """ConvModule Forward.

        Args:
            x: torch.Tensor
                input tensor.
        """
        x = self.layer_norm(x)

        if self.ext_pw_out_channel != 0:
            x = self.glu(x)
            if self.causal and self.ext_pw_kernel_size > 1:
                x = x[:, : -(self.ext_pw_kernel_size - 1), :]
            if self.apply_ln1:
                x = self.ln1(x)
        else:
            x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
            x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
            x = x_0 + x_1

        x = x.permute([0, 2, 1])

        x = self.dw_sep_conv_1d(x)
        if self.causal and self.kernel_size > 1:
            x = x[:, :, : -(self.kernel_size - 1)]
        if hasattr(self, "ln2"):
            x = x.permute([0, 2, 1])
            x = self.ln2(x)
            x = x.permute([0, 2, 1])
        if self.batch_norm:
            x = self.bn_layer(x)
        x = self.act(x)

        if self.ext_pw_out_channel != 0:
            x = self.ext_pw_conv_1d(x)
            if self.fix_len1:
                x = x[:, :, : -(self.ext_pw_kernel_size - 1)]

            if self.apply_ln1:
                x = x.permute([0, 2, 1])
                x = self.ln1(x)
                x = x.permute([0, 2, 1])

            x = x.permute([0, 2, 1])
        else:
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
            x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
            x = x.squeeze(1)

        x = self.dropout(x)
        return x
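

# Illustrative sketch (not part of the original module): a non-causal ConvModule
# configured so every internal projection keeps the feature size, mapping
# (batch, time, input_dim) back to the same shape. All sizes are assumptions.
def _example_conv_module_shapes():
    m = ConvModule(
        input_dim=16,
        ext_pw_out_channel=16,
        depthwise_seperable_out_channel=16,
        ext_pw_kernel_size=1,
        kernel_size=3,
        depthwise_multiplier=1,
        dropout_rate=0.0,
        causal=False,
    )
    x = torch.randn(2, 10, 16)
    y = m(x)
    assert y.shape == (2, 10, 16)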
					
						
class GLULinear(nn.Module):
    """Linear + GLU module

    Args:
        input_dim: int
            input size
        output_dim: int
            output size.
        glu_type:
            activation function name used in the glu module,
            default "sigmoid".
        bias_in_glu: bool, optional
            If True, the additive bias is added. Default True.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        glu_type="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
        self.glu_act = GLU(-1, glu_type)

    def forward(self, x):
        """GLULinear forward

        Args:
            x: torch.Tensor
                input tensor.
        """
        x = self.linear(x)
        return self.glu_act(x)
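

# Illustrative sketch (not part of the original module): the linear layer
# produces 2 * output_dim features and the GLU gate halves them back to
# output_dim. Sizes are assumptions.
def _example_glu_linear_shapes():
    m = GLULinear(input_dim=16, output_dim=32, glu_type="swish")
    y = m(torch.randn(2, 10, 16))
    assert y.shape == (2, 10, 32)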
					
						
class FeedForward(nn.Module):
    """FeedForward Module.
    For more details see Conformer paper:
    https://arxiv.org/pdf/2005.08100.pdf

    Args:
        d_model: int
            input (and output) size.
        d_inner: int
            hidden (inner) size of the gated projection.
        dropout_rate: float,
            dropout rate.
        activation: str,
            activation function name,
            one of ["relu", "swish", "sigmoid"],
            sigmoid activation is only used with "glu_in_fnn=True",
            default "sigmoid".
        bias_in_glu: bool, optional
    """

    def __init__(
        self,
        d_model,
        d_inner,
        dropout_rate,
        activation="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        module = GLULinear(d_model, d_inner, activation, bias_in_glu)
        self.net = nn.Sequential(
            module,
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """FeedForward forward function.

        Args:
            x: torch.Tensor
                input tensor.
        """
        out = self.net(self.layer_norm(x))

        return out
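

# Illustrative sketch (not part of the original module): the feed-forward block
# is shape-preserving; d_inner only controls the hidden width of the gated
# projection. Sizes are assumptions.
def _example_feed_forward_shapes():
    ff = FeedForward(d_model=16, d_inner=64, dropout_rate=0.0, activation="swish")
    y = ff(torch.randn(2, 10, 16))
    assert y.shape == (2, 10, 16)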
					
						
def _pre_hook(
    state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
    """Perform pre-hook in load_state_dict for backward compatibility.

    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)


class T5RelativeAttentionLogitBias(nn.Module):
    """
    This module implements the relative position bias described in Section 2.1 of
    the T5 paper: https://arxiv.org/pdf/1910.10683.pdf

    The Huggingface implementation is used as a reference
    https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435

    Modifies attention as Q*K^T + B, where B is a learned scalar bias based on relative position
    of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length.

    I've made these modifications to the original T5 bias:
    - Skipping of the bucketing step. Original T5 bias converted rel position distances into
      logarithmically increasing buckets. This is supposed to help with length generalization.
    - I just directly use rel position index as bias values, as we don't need length
      generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple.
    - I've also extended it so that biases can be asymmetric, the default implementation treats
      L->R and R->L the same. Asymmetric was found to yield better results in my experiments.

    Args:
        num_heads: int
            Number of attention heads
        num_buckets: int
            Number of buckets to use for relative attention bias. This is the size of the learnable
            bias parameter. Bucketing is not yet supported, so this defaults to -1 which means
            no bucketing is used (max_distance determines size of bias param).
        max_distance: int
            Maximum distance to use for relative attention bias. With num_buckets=-1, this directly
            controls the max size of the bias parameter. When num_buckets > 0 is supported, this
            will control the maximum distance for logarithmic bucketing after which all positions
            are in the same bucket.
        symmetric: bool
            Whether to use symmetric or asymmetric biases. symmetric=False uses 2x number of bias
            params to distinguish L->R from R->L. This was found to be better for the encoder.
    """

    def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.symmetric = symmetric
        self._skip_bucketing = self.num_buckets < 0
        if self._skip_bucketing:
            self.num_buckets = max_distance
        else:
            raise NotImplementedError("T5 attention bias with bucketed positions is not yet tested")
        if not self.symmetric:
            self.num_buckets *= 2
        self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(self, x):
        # compute the grid of relative positions for the current sequence length
        maxpos = x.size(1)
        context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[:, None]
        memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position

        # clip relative positions to [-max_distance, max_distance - 1]
        relative_position = relative_position.masked_fill(
            relative_position < -self.max_distance, -self.max_distance
        )
        relative_position = relative_position.masked_fill(
            relative_position > self.max_distance - 1, self.max_distance - 1
        )

        # map each relative position to an index into the bias embedding table
        if self._skip_bucketing:
            bias_idx = relative_position
        else:
            bias_idx = self._bucket_relative_position(relative_position)
        if self.symmetric:
            bias_idx = bias_idx.abs()
        else:
            bias_idx += self.num_buckets // 2

        t5_rel_att_bias = self.bias_values(bias_idx)  # (seq, seq, heads)
        t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0)  # (1, heads, seq, seq)

        return t5_rel_att_bias
					
						
    def _bucket_relative_position(self, relative_position):
        # Bucketing adapted from the Huggingface T5 implementation referenced in the class
        # docstring. Note that the constructor currently raises NotImplementedError whenever
        # num_buckets > 0, so this path is effectively unreachable until bucketing is validated.
        relative_buckets = 0
        num_buckets = self.num_buckets
        if not self.symmetric:
            # use separate halves of the bucket range to distinguish L->R from R->L
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

        # half of the buckets are for exact relative positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # the other half grow logarithmically up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(self.max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
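

# Illustrative sketch (not part of the original module): with bucketing skipped
# (num_buckets=-1) the bias table has 2 * max_distance rows for asymmetric
# biases, and the forward pass returns one bias map per head, shaped
# (1, num_heads, seq_len, seq_len). Sizes are assumptions.
def _example_t5_relative_attention_bias_shapes():
    bias = T5RelativeAttentionLogitBias(num_heads=4, num_buckets=-1, max_distance=100)
    x = torch.randn(2, 10, 16)  # only x.size(1), the sequence length, is used
    b = bias(x)
    assert b.shape == (1, 4, 10, 10)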
					
						
class AbsolutePositionalEncoding(nn.Module):
    """Absolute Positional encoding module.
    This module implements absolute sinusoidal positional encoding
    from: https://arxiv.org/pdf/1706.03762.pdf

    Args:
        d_model: int
            Input embedding size.
        dropout_rate: float
            dropout rate
        max_len: int, optional
            Maximum input length sequence, Default 5000
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings.

        Args:
            x: torch.Tensor
        """
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x: torch.Tensor
                Input tensor. shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)
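

# Illustrative sketch (not part of the original module): the encoding scales the
# input by sqrt(d_model) and adds the first seq_len rows of the sinusoidal
# table, so the output keeps the input shape. d_model must be even. Sizes are
# assumptions.
def _example_absolute_positional_encoding_shapes():
    pos_enc = AbsolutePositionalEncoding(d_model=16, dropout_rate=0.0, max_len=100)
    y = pos_enc(torch.randn(2, 10, 16))
    assert y.shape == (2, 10, 16)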
					
						
@backoff.on_exception(backoff.expo, Exception, max_tries=10)
def np_loadtxt_with_retry(filepath):
    """np.loadtxt with retry

    Args:
        filepath: str
            file path to the numpy array.
    """
    result = np.loadtxt(filepath, dtype="f")
    return result


class MeanVarianceNormLayer(nn.Module):
    """Mean/variance normalization layer.

    Will subtract the mean and multiply the input by the inverted standard deviation.
    Typically used as a very first layer in a model.

    Args:
        input_size: int
            layer input size.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.register_buffer("global_mean", torch.zeros(input_size))
        self.register_buffer("global_invstd", torch.ones(input_size))
        self.global_mean: Optional[Tensor]
        self.global_invstd: Optional[Tensor]

    def forward(self, input_: Tensor) -> Tensor:
        """MeanVarianceNormLayer Forward

        Args:
            input_: torch.Tensor
                input tensor.
        """
        return (input_ - self.global_mean) * self.global_invstd

    def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False):
        """Load feature mean and variance used for normalization.

        Args:
            mean_file: str
                path to the feature mean statistics file.
            invstd_file: str
                path to the features inverted standard deviation
                statistics file.
            cuside_features: bool
                Boolean that indicates CUSIDE is being used.
                The statistics of CUSIDE features are copied
                from the normal features
        """
        self.global_mean.data = torch.from_numpy(np_loadtxt_with_retry(mean_file))
        self.global_invstd.data = torch.from_numpy(np_loadtxt_with_retry(invstd_file))

        if cuside_features:
            self.global_mean.data = torch.cat((self.global_mean.data, self.global_mean.data), 0)
            self.global_invstd.data = torch.cat(
                (self.global_invstd.data, self.global_invstd.data), 0
            )
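

# Illustrative sketch (not part of the original module): with the default
# buffers (zero mean, unit inverse std) the layer is the identity; in practice
# load_mean_invstd overwrites them from precomputed statistics files. Sizes are
# assumptions.
def _example_mean_variance_norm():
    norm = MeanVarianceNormLayer(input_size=8)
    x = torch.randn(2, 10, 8)
    assert torch.allclose(norm(x), x)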
					
						
class CausalConv1D(nn.Conv1d):
    """
    A causal version of nn.Conv1d where each step would have limited access to locations
    on its right or left.
    All arguments are the same as nn.Conv1d except padding.

    If padding is set None, then paddings are set automatically to make it a causal convolution
    where each location would not see any steps on its right.

    If padding is set as a list (size of 2), then padding[0] would be used as left padding and
    padding[1] as right padding.
    It would make it possible to control the number of steps to be accessible on the right and left.
    This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[str, int] = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        self.cache_drop_size = None
        if padding is None:
            self._left_padding = kernel_size - 1
            self._right_padding = stride - 1
        else:
            if stride != 1 and padding != kernel_size - 1:
                raise ValueError("No striding allowed for non-symmetric convolutions!")
            if isinstance(padding, int):
                self._left_padding = padding
                self._right_padding = padding
            elif (
                isinstance(padding, list)
                and len(padding) == 2
                and padding[0] + padding[1] == kernel_size - 1
            ):
                self._left_padding = padding[0]
                self._right_padding = padding[1]
            else:
                raise ValueError(f"Invalid padding param: {padding}!")

        self._max_cache_len = self._left_padding

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def update_cache(self, x, cache=None):
        if cache is None:
            new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
            next_cache = cache
        else:
            new_x = F.pad(x, pad=(0, self._right_padding))
            new_x = torch.cat([cache, new_x], dim=-1)
            if self.cache_drop_size > 0:
                next_cache = new_x[:, :, : -self.cache_drop_size]
            else:
                next_cache = new_x
            next_cache = next_cache[:, :, -cache.size(-1) :]
        return new_x, next_cache

    def forward(self, x, cache=None):
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        if cache is None:
            return x
        else:
            return x, cache
					
						
						|  |  | 
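# --- Example (illustrative sketch, not part of the original module) ----------
# Why left-padding by (kernel_size - 1) makes a 1-D convolution causal:
# output[t] then depends only on input[t - kernel_size + 1 .. t], so editing a
# "future" frame cannot change earlier outputs.  Plain F.conv1d is used here
# instead of CausalConv1D purely to keep the demo self-contained; all names
# and sizes below are arbitrary.
def _demo_causal_left_padding(kernel_size: int = 3, seq_len: int = 10):
    torch.manual_seed(0)
    weight = torch.randn(1, 1, kernel_size)          # (out_ch, in_ch, kernel)
    x = torch.randn(1, 1, seq_len)
    y = F.conv1d(F.pad(x, (kernel_size - 1, 0)), weight)

    # Perturb a future frame; outputs strictly before it must not change.
    t_future = 7
    x_perturbed = x.clone()
    x_perturbed[..., t_future] += 1.0
    y_perturbed = F.conv1d(F.pad(x_perturbed, (kernel_size - 1, 0)), weight)
    assert torch.allclose(y[..., :t_future], y_perturbed[..., :t_future])
    return y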
					
						
						|  |  | 
					
						
						|  | class CausalConv2D(nn.Conv2d): | 
					
						
						|  | """ | 
					
						
|  | A causal version of nn.Conv2d in which each location in the 2D matrix has no access to locations to its right or below. | 
|  | All arguments are the same as for nn.Conv2d, except padding, which must be set to None. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_channels: int, | 
					
						
						|  | out_channels: int, | 
					
						
						|  | kernel_size: int, | 
					
						
						|  | stride: int = 1, | 
					
						
						|  | padding: Union[str, int] = 0, | 
					
						
						|  | dilation: int = 1, | 
					
						
						|  | groups: int = 1, | 
					
						
						|  | bias: bool = True, | 
					
						
						|  | padding_mode: str = "zeros", | 
					
						
						|  | device=None, | 
					
						
						|  | dtype=None, | 
					
						
						|  | ) -> None: | 
					
						
						|  | if padding is not None: | 
					
						
						|  | raise ValueError("Argument padding should be set to None for CausalConv2D.") | 
					
						
						|  | self._left_padding = kernel_size - 1 | 
					
						
						|  | self._right_padding = stride - 1 | 
					
						
						|  |  | 
					
						
						|  | padding = 0 | 
					
						
						|  | super().__init__( | 
					
						
						|  | in_channels, | 
					
						
						|  | out_channels, | 
					
						
						|  | kernel_size, | 
					
						
						|  | stride, | 
					
						
						|  | padding, | 
					
						
						|  | dilation, | 
					
						
						|  | groups, | 
					
						
						|  | bias, | 
					
						
						|  | padding_mode, | 
					
						
						|  | device, | 
					
						
						|  | dtype, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def forward( | 
					
						
						|  | self, | 
					
						
						|  | x, | 
					
						
						|  | ): | 
					
						
						|  | if self.training: | 
					
						
						|  | x = F.pad( | 
					
						
						|  | x, | 
					
						
						|  | pad=( | 
					
						
						|  | self._left_padding, | 
					
						
						|  | self._right_padding, | 
					
						
						|  | self._left_padding, | 
					
						
						|  | self._right_padding, | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | x = F.pad( | 
					
						
						|  | x, | 
					
						
						|  | pad=(self._left_padding, self._right_padding, 0, 0), | 
					
						
						|  | ) | 
					
						
						|  | x = super().forward(x) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class NemoConvSubsampling(torch.nn.Module): | 
					
						
|  | """Convolutional subsampling module, taken from NeMo ASR | 
					
						
						|  | (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) | 
					
						
						|  |  | 
					
						
						|  | Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for | 
					
						
						|  | Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
|  | Compared with the EncoderConv2D (`input_layer: custom`), this is a much simpler approach | 
|  | that uses no LayerNorm and far fewer Conv2Ds.  Moreover, depthwise convolutions are used to reduce | 
					
						
						|  | FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy. | 
					
						
						|  |  | 
					
						
						|  | `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions | 
					
						
						|  | after the first layer, whereas the former does not. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | subsampling_factor (int): Time reduction factor | 
					
						
						|  | feat_in (int): size of the input features | 
					
						
						|  | feat_out (int): size of the output features | 
					
						
						|  | subsampling (str): The subsampling technique, choose from | 
					
						
|  | {"striding", "dw_striding", "striding_conv1d", "dw_striding_conv1d"} | 
					
						
						|  | conv_channels (int): Number of channels for the convolution layers, default is 256. | 
					
						
|  | subsampling_conv_chunking_factor (int): Input chunking factor, which can be -1 (no chunking), | 
|  | 1 (auto), or a power of 2. Default is 1. | 
					
						
						|  | activation (Module): activation function, default is nn.ReLU() | 
					
						
						|  | is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access | 
					
						
						|  | to locations on its right or left | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | feat_in, | 
					
						
						|  | feat_out, | 
					
						
						|  | subsampling_factor=4, | 
					
						
						|  | subsampling="dw_striding", | 
					
						
						|  | conv_channels=256, | 
					
						
						|  | subsampling_conv_chunking_factor=1, | 
					
						
						|  | activation=nn.ReLU(), | 
					
						
						|  | is_causal=False, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self._subsampling = subsampling | 
					
						
						|  | self._conv_channels = conv_channels | 
					
						
						|  | self._feat_in = feat_in | 
					
						
						|  | self._feat_out = feat_out | 
					
						
						|  |  | 
					
						
						|  | if subsampling_factor % 2 != 0: | 
					
						
|  | raise ValueError("Sampling factor should be a multiple of 2!") | 
					
						
						|  | self._sampling_num = int(math.log(subsampling_factor, 2)) | 
					
						
						|  | self.subsampling_factor = subsampling_factor | 
					
						
						|  | self.is_causal = is_causal | 
					
						
						|  | self.subsampling_causal_cond = subsampling in ("dw_striding", "striding", "striding_conv1d") | 
					
						
						|  |  | 
					
						
						|  | if ( | 
					
						
						|  | subsampling_conv_chunking_factor != -1 | 
					
						
						|  | and subsampling_conv_chunking_factor != 1 | 
					
						
						|  | and subsampling_conv_chunking_factor % 2 != 0 | 
					
						
						|  | ): | 
					
						
						|  | raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") | 
					
						
						|  | self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor | 
					
						
						|  |  | 
					
						
						|  | in_channels = 1 | 
					
						
						|  | layers = [] | 
					
						
						|  |  | 
					
						
						|  | if subsampling == "dw_striding": | 
					
						
						|  | self._stride = 2 | 
					
						
						|  | self._kernel_size = 3 | 
					
						
						|  | self._ceil_mode = False | 
					
						
						|  |  | 
					
						
						|  | if self.is_causal: | 
					
						
						|  | self._left_padding = self._kernel_size - 1 | 
					
						
						|  | self._right_padding = self._stride - 1 | 
					
						
						|  | self._max_cache_len = subsampling_factor + 1 | 
					
						
						|  | else: | 
					
						
						|  | self._left_padding = (self._kernel_size - 1) // 2 | 
					
						
						|  | self._right_padding = (self._kernel_size - 1) // 2 | 
					
						
						|  | self._max_cache_len = 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.is_causal: | 
					
						
						|  | layers.append( | 
					
						
						|  | CausalConv2D( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=conv_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=None, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | layers.append( | 
					
						
						|  | torch.nn.Conv2d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=conv_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=self._left_padding, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | in_channels = conv_channels | 
					
						
						|  | layers.append(activation) | 
					
						
						|  |  | 
					
						
						|  | for i in range(self._sampling_num - 1): | 
					
						
						|  | if self.is_causal: | 
					
						
						|  | layers.append( | 
					
						
						|  | CausalConv2D( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=in_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=None, | 
					
						
						|  | groups=in_channels, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | layers.append( | 
					
						
						|  | torch.nn.Conv2d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=in_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=self._left_padding, | 
					
						
						|  | groups=in_channels, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | layers.append( | 
					
						
						|  | torch.nn.Conv2d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=conv_channels, | 
					
						
						|  | kernel_size=1, | 
					
						
						|  | stride=1, | 
					
						
						|  | padding=0, | 
					
						
						|  | groups=1, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | layers.append(activation) | 
					
						
						|  | in_channels = conv_channels | 
					
						
						|  |  | 
					
						
						|  | elif subsampling == "striding": | 
					
						
						|  | self._stride = 2 | 
					
						
						|  | self._kernel_size = 3 | 
					
						
						|  | self._ceil_mode = False | 
					
						
						|  |  | 
					
						
						|  | if self.is_causal: | 
					
						
						|  | self._left_padding = self._kernel_size - 1 | 
					
						
						|  | self._right_padding = self._stride - 1 | 
					
						
						|  | self._max_cache_len = subsampling_factor + 1 | 
					
						
						|  | else: | 
					
						
						|  | self._left_padding = (self._kernel_size - 1) // 2 | 
					
						
						|  | self._right_padding = (self._kernel_size - 1) // 2 | 
					
						
						|  | self._max_cache_len = 0 | 
					
						
						|  |  | 
					
						
						|  | for i in range(self._sampling_num): | 
					
						
						|  | if self.is_causal: | 
					
						
						|  | layers.append( | 
					
						
						|  | CausalConv2D( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=conv_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=None, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | layers.append( | 
					
						
						|  | torch.nn.Conv2d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=conv_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=self._left_padding, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | layers.append(activation) | 
					
						
						|  | in_channels = conv_channels | 
					
						
						|  |  | 
					
						
						|  | elif subsampling == "striding_conv1d": | 
					
						
						|  | in_channels = feat_in | 
					
						
						|  |  | 
					
						
						|  | self._stride = 2 | 
					
						
						|  | self._kernel_size = 5 | 
					
						
						|  | self._ceil_mode = False | 
					
						
						|  |  | 
					
						
						|  | if self.is_causal: | 
					
						
						|  | self._left_padding = self._kernel_size - 1 | 
					
						
						|  | self._right_padding = self._stride - 1 | 
					
						
						|  | self._max_cache_len = subsampling_factor + 1 | 
					
						
						|  | else: | 
					
						
						|  | self._left_padding = (self._kernel_size - 1) // 2 | 
					
						
						|  | self._right_padding = (self._kernel_size - 1) // 2 | 
					
						
						|  | self._max_cache_len = 0 | 
					
						
						|  |  | 
					
						
						|  | for i in range(self._sampling_num): | 
					
						
						|  | if self.is_causal: | 
					
						
						|  | layers.append( | 
					
						
						|  | CausalConv1D( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=None, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | layers.append( | 
					
						
						|  | torch.nn.Conv1d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=self._left_padding, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | layers.append(activation) | 
					
						
						|  | in_channels = conv_channels | 
					
						
						|  |  | 
					
						
						|  | elif subsampling == "dw_striding_conv1d": | 
					
						
						|  | in_channels = feat_in | 
					
						
						|  |  | 
					
						
						|  | self._stride = 2 | 
					
						
						|  | self._kernel_size = 5 | 
					
						
						|  | self._ceil_mode = False | 
					
						
						|  |  | 
					
						
						|  | self._left_padding = (self._kernel_size - 1) // 2 | 
					
						
						|  | self._right_padding = (self._kernel_size - 1) // 2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | layers.extend( | 
					
						
						|  | [ | 
					
						
						|  | torch.nn.Conv1d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=in_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=self._left_padding, | 
					
						
						|  | groups=in_channels, | 
					
						
						|  | ), | 
					
						
						|  | torch.nn.Conv1d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=feat_out if self._sampling_num == 1 else conv_channels, | 
					
						
						|  | kernel_size=1, | 
					
						
						|  | stride=1, | 
					
						
						|  | padding=0, | 
					
						
						|  | groups=1, | 
					
						
						|  | ), | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  | in_channels = conv_channels | 
					
						
						|  | layers.append(activation) | 
					
						
						|  |  | 
					
						
						|  | for i in range(self._sampling_num - 1): | 
					
						
						|  | layers.extend( | 
					
						
						|  | [ | 
					
						
						|  | torch.nn.Conv1d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=in_channels, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=self._left_padding, | 
					
						
						|  | groups=in_channels, | 
					
						
						|  | ), | 
					
						
						|  | torch.nn.Conv1d( | 
					
						
						|  | in_channels=in_channels, | 
					
						
						|  | out_channels=feat_out if self._sampling_num == i + 2 else conv_channels, | 
					
						
						|  | kernel_size=1, | 
					
						
						|  | stride=1, | 
					
						
						|  | padding=0, | 
					
						
						|  | groups=1, | 
					
						
						|  | ), | 
					
						
						|  | ] | 
					
						
						|  | ) | 
					
						
						|  | layers.append(activation) | 
					
						
						|  | in_channels = conv_channels | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
|  | raise ValueError(f"Invalid subsampling option: {subsampling}!") | 
					
						
						|  |  | 
					
						
						|  | if subsampling in ["dw_striding", "striding"]: | 
					
						
						|  | in_length = torch.tensor(feat_in, dtype=torch.float) | 
					
						
						|  | out_length = calc_length( | 
					
						
						|  | lengths=in_length, | 
					
						
						|  | all_paddings=self._left_padding + self._right_padding, | 
					
						
						|  | kernel_size=self._kernel_size, | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | ceil_mode=self._ceil_mode, | 
					
						
						|  | repeat_num=self._sampling_num, | 
					
						
						|  | ) | 
					
						
						|  | self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) | 
					
						
						|  | self.conv2d_subsampling = True | 
					
						
						|  | elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: | 
					
						
						|  | self.out = None | 
					
						
						|  | self.conv2d_subsampling = False | 
					
						
						|  | else: | 
					
						
|  | raise ValueError(f"Invalid subsampling option: {subsampling}!") | 
					
						
						|  |  | 
					
						
						|  | self.conv = torch.nn.Sequential(*layers) | 
					
						
						|  |  | 
					
						
						|  | def get_sampling_frames(self): | 
					
						
						|  | return [1, self.subsampling_factor] | 
					
						
						|  |  | 
					
						
						|  | def get_streaming_cache_size(self): | 
					
						
						|  | return [0, self.subsampling_factor + 1] | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x, mask): | 
					
						
						|  | """ | 
					
						
						|  | Forward method for NeMo subsampling. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | x[Batch, Time, Filters]: torch.Tensor | 
					
						
						|  | input tensor | 
					
						
|  | mask: torch.Tensor | 
					
						
						|  | input mask | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | x: torch.Tensor | 
					
						
						|  | Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) | 
					
						
						|  | pad_mask: torch.Tensor | 
					
						
						|  | tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | if self.conv2d_subsampling: | 
					
						
						|  | x = x.unsqueeze(1) | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: | 
					
						
						|  | if self.subsampling_conv_chunking_factor == 1: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x_ceil = 2**31 / self._conv_channels * self._stride * self._stride | 
					
						
						|  | if torch.numel(x) > x_ceil: | 
					
						
						|  | need_to_split = True | 
					
						
						|  | else: | 
					
						
						|  | need_to_split = False | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | need_to_split = True | 
					
						
						|  |  | 
					
						
						|  | if need_to_split: | 
					
						
						|  | x, success = self.conv_split_by_batch(x) | 
					
						
						|  | if not success: | 
					
						
						|  | if self._subsampling == "dw_striding": | 
					
						
						|  | x = self.conv_split_by_channel(x) | 
					
						
						|  | else: | 
					
						
						|  | x = self.conv(x) | 
					
						
						|  | else: | 
					
						
						|  | x = self.conv(x) | 
					
						
						|  | else: | 
					
						
						|  | x = self.conv(x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if self.conv2d_subsampling: | 
					
						
						|  | b, c, t, f = x.size() | 
					
						
						|  | x = self.out(x.transpose(1, 2).reshape(b, t, -1)) | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | x = x.transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  | if mask is None: | 
					
						
						|  | return x, None | 
					
						
						|  |  | 
					
						
						|  | max_audio_length = x.shape[1] | 
					
						
						|  | feature_lens = mask.sum(1) | 
					
						
						|  | padding_length = torch.ceil(feature_lens / self.subsampling_factor) | 
					
						
						|  | if self.is_causal and self.subsampling_causal_cond: | 
					
						
						|  | feature_lens_remainder = feature_lens % self.subsampling_factor | 
					
						
						|  | padding_length[feature_lens_remainder != 1] += 1 | 
					
						
						|  | pad_mask = ( | 
					
						
						|  | torch.arange(0, max_audio_length, device=x.device).expand(padding_length.size(0), -1) | 
					
						
						|  | < padding_length.unsqueeze(1) | 
					
						
						|  | ) | 
					
						
						|  | return x, pad_mask.unsqueeze(1) | 
					
						
						|  |  | 
					
						
						|  | def reset_parameters(self): | 
					
						
						|  |  | 
					
						
						|  | if self._subsampling == "dw_striding": | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  |  | 
					
						
						|  | scale = 1.0 / self._kernel_size | 
					
						
						|  | dw_max = (self._kernel_size**2) ** -0.5 | 
					
						
						|  | pw_max = self._conv_channels**-0.5 | 
					
						
						|  |  | 
					
						
						|  | torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) | 
					
						
						|  | torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) | 
					
						
						|  |  | 
					
						
						|  | for idx in range(2, len(self.conv), 3): | 
					
						
						|  | torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) | 
					
						
						|  | torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) | 
					
						
						|  | torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) | 
					
						
						|  | torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 | 
					
						
						|  | torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) | 
					
						
						|  | torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) | 
					
						
						|  |  | 
					
						
						|  | def conv_split_by_batch(self, x): | 
					
						
						|  | """Tries to split input by batch, run conv and concat results""" | 
					
						
						|  | b, _, _, _ = x.size() | 
					
						
						|  | if b == 1: | 
					
						
						|  | return x, False | 
					
						
						|  |  | 
					
						
						|  | if self.subsampling_conv_chunking_factor > 1: | 
					
						
						|  | cf = self.subsampling_conv_chunking_factor | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x_ceil = 2**31 / self._conv_channels * self._stride * self._stride | 
					
						
						|  | p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) | 
					
						
						|  | cf = 2**p | 
					
						
						|  |  | 
					
						
						|  | new_batch_size = b // cf | 
					
						
						|  | if new_batch_size == 0: | 
					
						
						|  | return x, False | 
					
						
						|  |  | 
					
						
						|  | return torch.cat([self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]), True | 
					
						
						|  |  | 
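    # --- Worked example (added comments; values chosen arbitrarily) ---------
    # How the automatic chunking factor above is derived when
    # subsampling_conv_chunking_factor == 1:
    #   x_ceil = 2**31 / conv_channels * stride * stride
    #            (presumably a guard against 32-bit indexing limits in some
    #             conv kernels)
    #   cf     = 2 ** ceil(log2(numel(x) / x_ceil))
    # e.g. with conv_channels = 256 and stride = 2:
    #   x_ceil   = 2**31 / 256 * 2 * 2 = 33_554_432
    #   numel(x) = 100_000_000  ->  ratio ~= 2.98  ->  cf = 2**2 = 4
    # so the batch would be split into 4 chunks before running self.conv.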
					
						
						|  | def conv_split_by_channel(self, x): | 
					
						
|  | """For depthwise convs, tries to split the input by channel (and by time for the pointwise convs), run the convs, and concat the results""" | 
					
						
						|  | x = self.conv[0](x) | 
					
						
						|  | x = self.conv[1](x) | 
					
						
						|  |  | 
					
						
						|  | for i in range(self._sampling_num - 1): | 
					
						
						|  | _, c, t, _ = x.size() | 
					
						
						|  |  | 
					
						
						|  | if self.subsampling_conv_chunking_factor > 1: | 
					
						
						|  | cf = self.subsampling_conv_chunking_factor | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) | 
					
						
						|  | cf = 2**p | 
					
						
						|  |  | 
					
						
						|  | new_c = int(c // cf) | 
					
						
						|  | if new_c == 0: | 
					
						
						|  | new_c = 1 | 
					
						
						|  |  | 
					
						
						|  | new_t = int(t // cf) | 
					
						
						|  | if new_t == 0: | 
					
						
						|  | new_t = 1 | 
					
						
						|  |  | 
					
						
						|  | x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | x = torch.cat( | 
					
						
						|  | [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2 | 
					
						
						|  | ) | 
					
						
						|  | x = self.conv[i * 3 + 4](x) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  | def channel_chunked_conv(self, conv, chunk_size, x): | 
					
						
						|  | """Performs channel chunked convolution""" | 
					
						
						|  |  | 
					
						
						|  | ind = 0 | 
					
						
						|  | out_chunks = [] | 
					
						
						|  | for chunk in torch.split(x, chunk_size, 1): | 
					
						
						|  | step = chunk.size()[1] | 
					
						
						|  |  | 
					
						
						|  | if self.is_causal: | 
					
						
						|  | chunk = nn.functional.pad( | 
					
						
						|  | chunk, | 
					
						
						|  | pad=( | 
					
						
						|  | self._kernel_size - 1, | 
					
						
						|  | self._stride - 1, | 
					
						
						|  | self._kernel_size - 1, | 
					
						
						|  | self._stride - 1, | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  | ch_out = nn.functional.conv2d( | 
					
						
						|  | chunk, | 
					
						
						|  | conv.weight[ind : ind + step, :, :, :], | 
					
						
						|  | bias=conv.bias[ind : ind + step], | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=0, | 
					
						
						|  | groups=step, | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | ch_out = nn.functional.conv2d( | 
					
						
						|  | chunk, | 
					
						
						|  | conv.weight[ind : ind + step, :, :, :], | 
					
						
						|  | bias=conv.bias[ind : ind + step], | 
					
						
						|  | stride=self._stride, | 
					
						
						|  | padding=self._left_padding, | 
					
						
						|  | groups=step, | 
					
						
						|  | ) | 
					
						
						|  | out_chunks.append(ch_out) | 
					
						
						|  | ind += step | 
					
						
						|  |  | 
					
						
						|  | return torch.cat(out_chunks, 1) | 
					
						
						|  |  | 
					
						
						|  | def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int): | 
					
						
						|  | if ( | 
					
						
						|  | subsampling_conv_chunking_factor != -1 | 
					
						
						|  | and subsampling_conv_chunking_factor != 1 | 
					
						
						|  | and subsampling_conv_chunking_factor % 2 != 0 | 
					
						
						|  | ): | 
					
						
						|  | raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") | 
					
						
						|  | self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1): | 
					
						
						|  | """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" | 
					
						
						|  | add_pad: float = all_paddings - kernel_size | 
					
						
						|  | one: float = 1.0 | 
					
						
						|  | for i in range(repeat_num): | 
					
						
						|  | lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one | 
					
						
						|  | if ceil_mode: | 
					
						
						|  | lengths = torch.ceil(lengths) | 
					
						
						|  | else: | 
					
						
						|  | lengths = torch.floor(lengths) | 
					
						
						|  | return lengths.to(dtype=torch.int) | 
					
						
						|  |  | 
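# --- Worked example for calc_length (illustrative) ---------------------------
# With the "dw_striding"/"striding" settings used above (kernel_size=3,
# stride=2, symmetric padding 1+1, ceil_mode=False), each application maps
# L -> floor((L + 2 - 3) / 2) + 1, so repeat_num=2 (subsampling_factor 4)
# reduces e.g. 80 input features to 40 and then to 20.
def _demo_calc_length():
    lengths = torch.tensor(80, dtype=torch.float)
    out = calc_length(lengths, all_paddings=2, kernel_size=3, stride=2,
                      ceil_mode=False, repeat_num=2)
    assert int(out) == 20
    return out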
					
						
						|  |  | 
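# --- Usage sketch for NemoConvSubsampling (illustrative sizes) ---------------
# Runs the default "dw_striding" configuration on a dummy batch and checks
# that the time axis is reduced by subsampling_factor and the feature axis is
# projected to feat_out.  All sizes below are arbitrary.
def _demo_nemo_conv_subsampling():
    torch.manual_seed(0)
    batch, time, feat_in, feat_out = 2, 100, 80, 512
    subsampler = NemoConvSubsampling(
        feat_in=feat_in,
        feat_out=feat_out,
        subsampling_factor=4,
        subsampling="dw_striding",
        conv_channels=256,
    )
    x = torch.randn(batch, time, feat_in)
    mask = torch.ones(batch, time, dtype=torch.bool)
    y, pad_mask = subsampler(x, mask)
    assert y.shape == (batch, time // 4, feat_out)
    assert pad_mask.shape == (batch, 1, time // 4)
    return y, pad_mask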
					
						
						|  | class AttModule(nn.Module): | 
					
						
						|  | """Attention abstraction module""" | 
					
						
						|  |  | 
					
						
						|  | def __init__(self): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.export_mode = False | 
					
						
						|  |  | 
					
						
						|  | def set_export(self, mode=True): | 
					
						
						|  | """set the export mode""" | 
					
						
						|  | self.export_mode = mode | 
					
						
						|  |  | 
					
						
						|  | def forward( | 
					
						
						|  | self, | 
					
						
						|  | x: Tensor, | 
					
						
						|  | memory: Optional[Tensor] = None, | 
					
						
						|  | pos_emb: Optional[Tensor] = None, | 
					
						
						|  | att_mask: Optional[Tensor] = None, | 
					
						
						|  | ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: | 
					
						
						|  | """AttModule forward | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | x: torch.Tensor | 
					
						
						|  | input tensor. | 
					
						
						|  | memory: torch.Tensor, optional | 
					
						
						|  | memory tensor. | 
					
						
						|  | pos_emb: torch.Tensor, optional | 
					
						
						|  | positional encoder embedding. | 
					
						
						|  | att_mask: torch.Tensor, optional | 
					
						
						|  | attention mask tensor. | 
					
						
						|  | """ | 
					
						
						|  | return x, memory, pos_emb, att_mask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AttBlock(Block, AttModule): | 
					
						
						|  | """Attention Block module to support both Attention and Block module.""" | 
					
						
						|  |  | 
					
						
						|  | def memory_dims(self, max_len=False): | 
					
						
						|  | """memory dimensions""" | 
					
						
						|  | return (1, self.input_size) | 
					
						
						|  |  | 
					
						
						|  | def masked_softmax( | 
					
						
						|  | scores, | 
					
						
						|  | mask: Optional[Tensor], | 
					
						
						|  | ): | 
					
						
						|  | if mask is not None: | 
					
						
						|  | mask = mask.unsqueeze(1).eq(0) | 
					
						
						|  | scores = scores.masked_fill(mask, -torch.inf) | 
					
						
						|  | attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) | 
					
						
						|  | else: | 
					
						
						|  | attn = torch.softmax(scores, dim=-1) | 
					
						
						|  | return attn | 
					
						
						|  |  | 
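# --- Example for masked_softmax (illustrative) --------------------------------
# In masked_softmax, mask entries equal to 0 are treated as masked out.
# Masked positions receive exactly zero attention weight and the remaining
# weights still sum to 1 along the last axis.
def _demo_masked_softmax():
    torch.manual_seed(0)
    scores = torch.randn(1, 2, 3, 3)          # (batch, head, time1, time2)
    mask = torch.ones(1, 3, 3)
    mask[..., 2] = 0                          # hide the last key position
    attn = masked_softmax(scores, mask)
    assert torch.all(attn[..., 2] == 0)
    assert torch.allclose(attn.sum(dim=-1), torch.ones(1, 2, 3))
    return attn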
					
						
						|  |  | 
					
						
						|  | class MultiHeadedAttention(nn.Module): | 
					
						
						|  | """Multi-Head Attention layer with optional relative position embedding and GLU. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | n_head: int | 
					
						
						|  | the number of heads. | 
					
						
						|  | n_feat: int | 
					
						
						|  | input size features. | 
					
						
						|  | dropout_rate: float | 
					
						
						|  | dropout rate. | 
					
						
						|  | use_LN: bool | 
					
						
						|  | apply layer norm or not | 
					
						
						|  | dropout_at_output: bool | 
					
						
						|  | whether to apply dropout at output | 
					
						
						|  | attention_inner_dim: int, optional | 
					
						
						|  | the attention dimension used in the class, | 
					
						
						|  | it can be different from the input dimension n_feat. | 
					
						
						|  | default: -1 (equal to n_feat). | 
					
						
						|  | use_pt_scaled_dot_product_attention: bool, optional | 
					
						
|  | if set to True, use pytorch scaled dot product attention in training.  NOTE: this will NOT | 
					
						
						|  | be used in ONNX decoding due to a lack of support.  In that case, we use the original | 
					
						
						|  | attention implementation, which shows no regression. | 
					
						
						|  | default: False. | 
					
						
						|  | n_value: int, optional | 
					
						
						|  | if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. | 
					
						
						|  | group_size: int, optional. must divide `n_head` | 
					
						
						|  | if group_size > 1:       GQA | 
					
						
						|  | if group_size = 1:       MHA | 
					
						
						|  | if group_size = n_head:  MQA | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | inv_sqrt_d_k: torch.jit.Final[float] | 
					
						
						|  | h: torch.jit.Final[int] | 
					
						
						|  | h_k: torch.jit.Final[int] | 
					
						
						|  | g: torch.jit.Final[int] | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | n_head, | 
					
						
						|  | n_feat, | 
					
						
						|  | dropout_rate, | 
					
						
						|  | attention_inner_dim=-1, | 
					
						
						|  | glu_type="swish", | 
					
						
						|  | bias_in_glu=True, | 
					
						
						|  | use_pt_scaled_dot_product_attention=False, | 
					
						
						|  | n_value=-1, | 
					
						
						|  | group_size: int = 1, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | if n_value == -1: | 
					
						
						|  | n_value = n_feat | 
					
						
						|  | if attention_inner_dim == -1: | 
					
						
						|  | attention_inner_dim = n_feat | 
					
						
						|  | assert attention_inner_dim % n_head == 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.d_k = attention_inner_dim // n_head | 
					
						
						|  | self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) | 
					
						
						|  | self.h = n_head | 
					
						
						|  | assert n_head % group_size == 0, "group_size must divide n_head" | 
					
						
						|  | self.g = group_size | 
					
						
						|  | self.h_k = n_head // group_size | 
					
						
						|  |  | 
					
						
						|  | self.linear_q = nn.Linear(n_feat, attention_inner_dim) | 
					
						
						|  | self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) | 
					
						
						|  | self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) | 
					
						
						|  | self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) | 
					
						
						|  |  | 
					
						
						|  | self.attn = torch.jit.Attribute(None, Optional[Tensor]) | 
					
						
						|  | self.dropout = nn.Dropout(p=dropout_rate) | 
					
						
						|  | self.dropout_rate = dropout_rate | 
					
						
						|  | self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention | 
					
						
						|  |  | 
					
						
						|  | if use_pt_scaled_dot_product_attention and group_size > 1: | 
					
						
						|  | raise ValueError("Cannot use PT Scaled Attention with GQA") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.quant_q = torch.ao.quantization.QuantStub() | 
					
						
						|  | self.quant_x = torch.ao.quantization.QuantStub() | 
					
						
						|  | self.dequant = torch.ao.quantization.DeQuantStub() | 
					
						
						|  | self.ffunc = torch.ao.nn.quantized.FloatFunctional() | 
					
						
						|  |  | 
					
						
						|  | def forward( | 
					
						
						|  | self, | 
					
						
						|  | query: Tensor, | 
					
						
						|  | key: Tensor, | 
					
						
						|  | value: Tensor, | 
					
						
						|  | pos_k: Tensor, | 
					
						
						|  | pos_v: Tensor, | 
					
						
						|  | mask: Optional[Tensor], | 
					
						
						|  | relative_attention_bias: Optional[Tensor] = None, | 
					
						
						|  | ): | 
					
						
						|  | """Compute 'Scaled Dot Product Attention'. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | query: torch.Tensor | 
					
						
						|  | query tensor (batch, time1, size) | 
					
						
						|  | key: torch.Tensor | 
					
						
						|  | key tensor (batch, time2, size) | 
					
						
						|  | value: torch.Tensor | 
					
						
						|  | value tensor (batch, time1, size) | 
					
						
						|  | pos_k: torch.Tensor | 
					
						
						|  | key tensor used for relative positional embedding. | 
					
						
						|  | pos_v: torch.Tensor | 
					
						
						|  | value tensor used for relative positional embedding. | 
					
						
						|  | mask: torch.Tensor | 
					
						
						|  | mask tensor (batch, time1, time2) | 
					
						
						|  | relative_attention_bias: torch.Tensor | 
					
						
						|  | bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) | 
					
						
						|  | """ | 
					
						
						|  | n_batch = query.size(0) | 
					
						
						|  |  | 
					
						
						|  | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) | 
					
						
						|  | k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) | 
					
						
						|  | v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) | 
					
						
						|  | q = ( | 
					
						
						|  | q.transpose(1, 2) | 
					
						
						|  | if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() | 
					
						
						|  | else q.transpose(1, 2) * self.inv_sqrt_d_k | 
					
						
						|  | ) | 
					
						
						|  | k = k.transpose(1, 2) | 
					
						
						|  | v = v.transpose(1, 2) | 
					
						
						|  |  | 
					
						
						|  | if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): | 
					
						
						|  | attn_mask = None | 
					
						
						|  | if mask is not None: | 
					
						
						|  | mask = mask.unsqueeze(1) | 
					
						
						|  | if relative_attention_bias is not None: | 
					
						
						|  | attn_mask = mask + relative_attention_bias | 
					
						
						|  | else: | 
					
						
						|  | attn_mask = mask | 
					
						
						|  | if mask.dtype != q.dtype: | 
					
						
						|  | attn_mask = attn_mask.to(q.dtype) | 
					
						
						|  |  | 
					
						
						|  | with torch.backends.cuda.sdp_kernel( | 
					
						
						|  | enable_flash=True, enable_math=True, enable_mem_efficient=True | 
					
						
						|  | ): | 
					
						
						|  | x = torch.nn.functional.scaled_dot_product_attention( | 
					
						
						|  | q, | 
					
						
						|  | k, | 
					
						
						|  | v, | 
					
						
						|  | attn_mask=attn_mask, | 
					
						
						|  | dropout_p=self.dropout_rate, | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | if self.h != self.h_k: | 
					
						
						|  | q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) | 
					
						
						|  | A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) | 
					
						
						|  | else: | 
					
						
						|  | A = torch.matmul(q, k.transpose(-2, -1)) | 
					
						
						|  | if pos_k is not None: | 
					
						
						|  | if self.h != self.h_k: | 
					
						
						|  | B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) | 
					
						
						|  | else: | 
					
						
						|  | reshape_q = ( | 
					
						
						|  | q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0, 1) | 
					
						
						|  | ) | 
					
						
						|  | B = torch.matmul(reshape_q, pos_k.transpose(-2, -1)) | 
					
						
						|  | B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1)) | 
					
						
						|  | scores = A + B | 
					
						
						|  | else: | 
					
						
						|  | scores = A | 
					
						
						|  |  | 
					
						
						|  | if relative_attention_bias is not None: | 
					
						
						|  | scores = scores + relative_attention_bias | 
					
						
						|  |  | 
					
						
						|  | attn = masked_softmax(scores, mask) | 
					
						
						|  |  | 
					
						
						|  | self.attn = attn | 
					
						
						|  |  | 
					
						
						|  | p_attn = self.dropout(attn) | 
					
						
						|  | x = torch.matmul(p_attn.to(v.dtype), v) | 
					
						
						|  | if pos_v is not None: | 
					
						
						|  | reshape_attn = ( | 
					
						
						|  | p_attn.contiguous() | 
					
						
						|  | .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) | 
					
						
						|  | .transpose(0, 1) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | attn_v = ( | 
					
						
						|  | torch.matmul(reshape_attn, pos_v) | 
					
						
						|  | .transpose(0, 1) | 
					
						
						|  | .contiguous() | 
					
						
						|  | .view(n_batch, self.h, pos_v.size(0), self.d_k) | 
					
						
						|  | ) | 
					
						
						|  | x = x + attn_v | 
					
						
						|  | x = ( | 
					
						
						|  | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return self.linear_out(x) | 
					
						
						|  |  | 
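# --- Usage sketch for MultiHeadedAttention (illustrative sizes) ---------------
# Self-attention without relative position embeddings or masking; pos_k and
# pos_v may be None in that case.  group_size=1 corresponds to ordinary
# multi-head attention; the sizes below are arbitrary.
def _demo_multi_headed_attention():
    torch.manual_seed(0)
    batch, time, n_feat, n_head = 2, 16, 256, 4
    attn = MultiHeadedAttention(n_head=n_head, n_feat=n_feat, dropout_rate=0.0)
    x = torch.randn(batch, time, n_feat)
    out = attn(x, x, x, pos_k=None, pos_v=None, mask=None)
    assert out.shape == (batch, time, n_feat)
    return out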
					
						
						|  |  | 
					
						
						|  | def unfold_tensor(xs_pad, max_seq_len): | 
					
						
						|  | """ | 
					
						
						|  | For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, | 
					
						
|  | this function unfolds it into shape (N*T', max_seq_len, D), where T' = T // max_seq_len. | 
					
						
						|  | Args: | 
					
						
						|  | xs_pad: N, T, D | 
					
						
						|  | """ | 
					
						
						|  | _, _, D = xs_pad.shape | 
					
						
						|  | xs_pad = xs_pad.transpose(-1, -2) | 
					
						
						|  |  | 
					
						
						|  | xs_pad = F.unfold( | 
					
						
						|  | xs_pad[..., None, :], | 
					
						
						|  | kernel_size=(1, max_seq_len), | 
					
						
						|  | stride=(1, max_seq_len), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | new_bsz, _, slen = xs_pad.shape | 
					
						
						|  |  | 
					
						
						|  | xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) | 
					
						
						|  |  | 
					
						
						|  | xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() | 
					
						
						|  |  | 
					
						
						|  | xs_pad = xs_pad.view(-1, max_seq_len, D) | 
					
						
						|  | return xs_pad | 
					
						
						|  |  | 
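# --- Example for unfold_tensor (illustrative) ---------------------------------
# A (N, T, D) tensor with T = 8 and max_seq_len = 4 is reshaped into
# (N * T // max_seq_len, max_seq_len, D) chunks that tile the time axis in
# order, which is easy to verify on a small arange tensor.
def _demo_unfold_tensor():
    N, T, D, max_seq_len = 1, 8, 2, 4
    x = torch.arange(N * T * D, dtype=torch.float).view(N, T, D)
    chunks = unfold_tensor(x, max_seq_len)
    assert chunks.shape == (N * T // max_seq_len, max_seq_len, D)
    assert torch.equal(chunks[0], x[0, :max_seq_len])
    assert torch.equal(chunks[1], x[0, max_seq_len:])
    return chunks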
					
						
						|  |  | 
					
						
						|  | class MultiSequential(torch.nn.Sequential): | 
					
						
						|  | """Multi-input multi-output torch.nn.Sequential""" | 
					
						
						|  |  | 
					
						
						|  | @torch.jit.ignore | 
					
						
						|  | def forward(self, *args): | 
					
						
						|  | """Forward method implementation.""" | 
					
						
						|  | for m in self: | 
					
						
						|  | args = m(*args) | 
					
						
						|  | return args | 
					
						
						|  |  | 
					
						
						|  | def repeat(repeat_num, module_gen_fn): | 
					
						
						|  | """repeat module N times | 
					
						
						|  |  | 
					
						
|  | :param int repeat_num: number of times to repeat the module | 
					
						
						|  | :param function module_gen_fn: function to generate module | 
					
						
						|  | :return: repeated modules | 
					
						
						|  | :rtype: MultiSequential | 
					
						
						|  | """ | 
					
						
						|  | return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) | 
					
						
						|  |  | 
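# --- Example for repeat / MultiSequential (illustrative) ----------------------
# MultiSequential threads a tuple of arguments through every repeated module,
# so each module must return the same tuple structure it receives.  The toy
# module below exists only for this demonstration and is not part of the
# encoder.
class _AddOne(nn.Module):
    """Toy module: increments a tensor and a step counter."""

    def forward(self, x, step):
        return x + 1.0, step + 1


def _demo_repeat():
    seq = repeat(3, lambda i: _AddOne())
    x, step = seq(torch.zeros(2), 0)
    assert torch.equal(x, torch.full((2,), 3.0))
    assert step == 3
    return x, step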
					
						
						|  | class ConformerEncoderLayer(nn.Module): | 
					
						
						|  | """ConformerEncoder Layer module. | 
					
						
						|  | for more details see conformer paper: | 
					
						
						|  | https://arxiv.org/abs/2005.08100 | 
					
						
|  | This module implements the Conformer block layer. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | d_model: int | 
					
						
						|  | attention dim. | 
					
						
						|  | ext_pw_out_channel: int | 
					
						
|  | if > 0, ext_pw_out_channel is the output channel size | 
|  | of the last pointwise conv after the swish activation. | 
					
						
						|  | depthwise_seperable_out_channel: int | 
					
						
|  | if set to a value different from 0, depthwise_seperable_out_channel | 
|  | is used as the out_channels of the second conv1d layer. | 
|  | otherwise, when it equals 0, the second conv1d layer is skipped. | 
					
						
						|  | depthwise_multiplier: int | 
					
						
|  | multiplier applied to the input_dim channels; this value | 
|  | is used to compute the hidden channels of the Conv1D. | 
					
						
						|  | n_head: int | 
					
						
						|  | the number of heads for multihead attention module. | 
					
						
						|  | d_ffn: int | 
					
						
						|  | output size of the feed_forward blocks. | 
					
						
						|  | ext_pw_kernel_size: int | 
					
						
						|  | kernel size of the conv pointwise of the conformer. | 
					
						
						|  | kernel_size: int | 
					
						
						|  | kernel size. | 
					
						
						|  | dropout_rate: float | 
					
						
						|  | dropout rate. | 
					
						
						|  | causal: bool, optional | 
					
						
|  | if set to True, the convolutions have no access | 
|  | to future frames. default False. | 
					
						
						|  | batch_norm: bool, optional | 
					
						
						|  | if set to True, apply batchnorm before activation | 
					
						
						|  | in ConvModule layer of the conformer. | 
					
						
						|  | default False | 
					
						
						|  | activation: str, optional | 
					
						
						|  | activation function name, | 
					
						
						|  | one of ["relu", "swish", "sigmoid"], | 
					
						
						|  | sigmoid activation is only used with "glu_in_fnn=True", | 
					
						
						|  | default "relu". | 
					
						
						|  | chunk_se: int, optional | 
					
						
						|  | 0 for offline SE. | 
					
						
						|  | 1 for streaming SE, where mean is computed | 
					
						
						|  | by accumulated history until current chunk_se. | 
					
						
						|  | 2 for streaming SE, where mean is computed | 
					
						
						|  | by only the current chunk. | 
					
						
						|  | default 0. | 
					
						
						|  | chunk_size: int, optional | 
					
						
						|  | chunk_size for cnn. default 18 | 
					
						
						|  | conv_activation: str, optional | 
					
						
						|  | activation function used in ConvModule part | 
					
						
						|  | of the conformer, default "relu". | 
					
						
						|  | conv_glu_type: str, optional | 
					
						
						|  | activation function used for the glu inside | 
					
						
						|  | the ConvModule part of the conformer. | 
					
						
						|  | default: "sigmoid". | 
					
						
						|  | bias_in_glu: bool, optional | 
					
						
						|  | if set to True, use additive bias in the weight module | 
					
						
						|  | before GLU. | 
					
						
						|  | linear_glu_in_convm: bool, optional | 
					
						
						|  | if set to True, use GLULinear module, | 
					
						
|  | otherwise, use the GLUPointWiseConv module. | 
|  | default False. | 
					
						
|  | attention_innner_dim: int, optional | 
					
						
						|  | if equal to -1, attention dim for linears k/q/v is | 
					
						
						|  | equal to d_model. otherwise attention_innner_dim is used. | 
					
						
						|  | default -1. | 
					
						
						|  | attention_glu_type: str, optional | 
					
						
						|  | activation function for glu used in the multihead attention, | 
					
						
						|  | default "swish". | 
					
						
						|  | activation_checkpointing: str, optional | 
					
						
|  | a dictionary of {"module", "interval", "offload"}, where | 
					
						
						|  | "module": str | 
					
						
						|  | accept ["transformer", "attention"] to select | 
					
						
						|  | which module should do activation checkpointing. | 
					
						
						|  | "interval": int, default 1, | 
					
						
						|  | interval of applying activation checkpointing, | 
					
						
|  | interval = 1 means that checkpointing is applied | 
|  | to every layer (when activated); otherwise, | 
|  | it is applied every `interval` layers. | 
|  | "offload": bool, default False, | 
|  | if set to True, activations are offloaded to cpu and | 
|  | reloaded during the backward pass; otherwise, | 
|  | activations are recomputed during backward. | 
					
						
						|  | default "". | 
					
						
						|  | export: bool, optional | 
					
						
|  | if set to True, padding is removed from the convolutional layers | 
|  | to allow ONNX conversion for inference. | 
					
						
						|  | default False. | 
					
						
						|  | use_pt_scaled_dot_product_attention: bool, optional | 
					
						
						|  | if set to True, use pytorch's scaled dot product attention implementation in training. | 
					
						
						|  | attn_group_sizes: int, optional | 
					
						
						|  | the number of groups to use for attention, default 1 (Multi-Head Attention), | 
					
						
						|  | 1 = typical Multi-Head Attention, | 
					
						
						|  | 1 < attn_group_sizes < attention_heads = Grouped-Query Attention | 
					
						
|  | attn_group_sizes = attention_heads = Multi-Query Attention | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | d_model=512, | 
					
						
						|  | ext_pw_out_channel=0, | 
					
						
						|  | depthwise_seperable_out_channel=256, | 
					
						
						|  | depthwise_multiplier=1, | 
					
						
						|  | n_head=4, | 
					
						
						|  | d_ffn=2048, | 
					
						
						|  | ext_pw_kernel_size=1, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | dropout_rate=0.1, | 
					
						
						|  | causal=False, | 
					
						
						|  | batch_norm=False, | 
					
						
						|  | activation="relu", | 
					
						
						|  | chunk_se=0, | 
					
						
						|  | chunk_size=18, | 
					
						
						|  | conv_activation="relu", | 
					
						
						|  | conv_glu_type="sigmoid", | 
					
						
						|  | bias_in_glu=True, | 
					
						
						|  | linear_glu_in_convm=False, | 
					
						
						|  | attention_innner_dim=-1, | 
					
						
						|  | attention_glu_type="swish", | 
					
						
						|  | activation_checkpointing="", | 
					
						
						|  | export=False, | 
					
						
						|  | use_pt_scaled_dot_product_attention=False, | 
					
						
						|  | attn_group_sizes: int = 1, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | self.feed_forward_in = FeedForward( | 
					
						
						|  | d_model=d_model, | 
					
						
						|  | d_inner=d_ffn, | 
					
						
						|  | dropout_rate=dropout_rate, | 
					
						
						|  | activation=activation, | 
					
						
						|  | bias_in_glu=bias_in_glu, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.self_attn = encoder_checkpoint_wrapper( | 
					
						
						|  | activation_checkpointing, | 
					
						
						|  | MultiHeadedAttention, | 
					
						
						|  | )( | 
					
						
						|  | MultiHeadedAttention( | 
					
						
						|  | n_head, | 
					
						
						|  | d_model, | 
					
						
						|  | dropout_rate, | 
					
						
						|  | attention_innner_dim, | 
					
						
						|  | attention_glu_type, | 
					
						
						|  | bias_in_glu, | 
					
						
						|  | use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, | 
					
						
						|  | group_size=attn_group_sizes, | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | self.conv = ConvModule( | 
					
						
						|  | d_model, | 
					
						
						|  | ext_pw_out_channel, | 
					
						
						|  | depthwise_seperable_out_channel, | 
					
						
						|  | ext_pw_kernel_size, | 
					
						
						|  | kernel_size, | 
					
						
						|  | depthwise_multiplier, | 
					
						
						|  | dropout_rate, | 
					
						
						|  | causal, | 
					
						
						|  | batch_norm, | 
					
						
						|  | chunk_se, | 
					
						
						|  | chunk_size, | 
					
						
						|  | conv_activation, | 
					
						
						|  | conv_glu_type, | 
					
						
						|  | bias_in_glu, | 
					
						
						|  | linear_glu_in_convm, | 
					
						
						|  | export=export, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.feed_forward_out = FeedForward( | 
					
						
						|  | d_model=d_model, | 
					
						
						|  | d_inner=d_ffn, | 
					
						
						|  | dropout_rate=dropout_rate, | 
					
						
						|  | activation=activation, | 
					
						
						|  | bias_in_glu=bias_in_glu, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.layer_norm_att = nn.LayerNorm(d_model) | 
					
						
						|  | self.layer_norm = nn.LayerNorm(d_model) | 
					
						
						|  |  | 
					
						
						|  | def forward( | 
					
						
						|  | self, | 
					
						
						|  | x, | 
					
						
						|  | pos_k, | 
					
						
						|  | pos_v, | 
					
						
						|  | mask, | 
					
						
						|  | relative_attention_bias: Optional[Tensor] = None, | 
					
						
						|  | ): | 
					
						
						|  | """ConformerEncoder forward. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | x: torch.Tensor | 
					
						
						|  | input feature of shape (batch, max_time_in, size) | 
					
						
						|  | pos_k: torch.Tensor | 
					
						
						|  | positional key embedding. | 
					
						
						|  | mask: torch.Tensor | 
					
						
						|  | mask for x (batch, max_time_in) | 
					
						
						|  | relative_attention_bias: Optional[torch.Tensor] | 
					
						
						|  | bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) | 
					
						
						|  | """ | 
					
						
						|  | x = x + 0.5 * self.feed_forward_in(x) | 
					
						
						|  | norm_x = self.layer_norm_att(x) | 
					
						
						|  |  | 
					
						
						|  | x = x + self.self_attn( | 
					
						
						|  | norm_x, | 
					
						
						|  | norm_x, | 
					
						
						|  | norm_x, | 
					
						
						|  | pos_k, | 
					
						
						|  | pos_v, | 
					
						
						|  | mask, | 
					
						
						|  | relative_attention_bias=relative_attention_bias, | 
					
						
						|  | ) | 
					
						
						|  | x = x + self.conv(x) | 
					
						
						|  | x = x + 0.5 * self.feed_forward_out(x) | 
					
						
						|  |  | 
					
						
						|  | out = self.layer_norm(x) | 
					
						
						|  |  | 
					
						
						|  | return out, pos_k, pos_v, mask | 
					
						
						|  |  | 
					
						
						|  | class TransformerEncoderBase(abc.ABC, nn.Module): | 
					
						
						|  | """The Base class for Transformer based encoders | 
					
						
						|  |  | 
					
						
|  | Please set causal = True in a streaming model. | 
					
						
						|  | Args: | 
					
						
						|  | input_size: int | 
					
						
						|  | input feature dimension. | 
					
						
						|  | chunk_size: int, list(int) | 
					
						
						|  | Number of frames for each chunk | 
					
						
						|  | This variable can take 2 forms: | 
					
						
						|  | int:  Used for inference, or single chunk size training | 
					
						
						|  | list(int) : Used only for variable chunk size training | 
					
						
						|  | Some examples for the 2 cases: | 
					
						
						|  | chunk_size = 12 | 
					
						
						|  | chunk_size = [6, 8, 12, 24] | 
					
						
						|  | left_chunk: int, list(int) | 
					
						
						|  | Number of chunks used for masking in streaming mode. | 
					
						
						|  | This variable can take 2 forms: | 
					
						
						|  | int:  Used for inference, or single chunk size training | 
					
						
						|  | list(int) : Used only for variable chunk size training. When | 
					
						
						|  | chunk_size is a list, left_chunk must be a list with same length. | 
					
						
						|  | Some examples for the 2 cases: | 
					
						
						|  | left_chunk = 6 | 
					
						
						|  | left_chunk = [12, 9, 6, 3] | 
					
						
						|  | attention_dim: int, optional | 
					
						
						|  | attention dimension. default 256. | 
					
						
						|  | attention_heads: int, optional | 
					
						
						|  | the number of heads. default 4 | 
					
						
						|  | input_layer: str, optional | 
					
						
						|  | input layer type before Conformer, | 
					
						
						|  | one of ["linear", "conv2d", "custom", "vgg2l", "embed"], | 
					
						
						|  | default "conv2d" | 
					
						
						|  | cnn_out: int, optional | 
					
						
						|  | the number of CNN channels before Conformer. | 
					
						
						|  | default -1. | 
					
						
						|  | cnn_layer_norm: bool, optional | 
					
						
						|  | layer norm between Conformer and the first CNN. | 
					
						
						|  | default False. | 
					
						
						|  | time_reduction: int, optional | 
					
						
						|  | time reduction factor | 
					
						
						|  | default 4 | 
					
						
						|  | dropout_rate: float, optional | 
					
						
|  | dropout rate. default 0.0 | 
					
						
						|  | padding_idx: int, optional | 
					
						
						|  | padding index for input_layer=embed | 
					
						
						|  | default -1 | 
					
						
						|  | relative_attention_bias_args: dict, optional | 
					
						
						|  | use more efficient scalar bias-based relative multihead attention (Q*K^T + B) | 
					
						
						|  | implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias | 
					
						
						|  | usage: relative_attention_bias_args={"type": t5/alibi} | 
					
						
						|  | additional method-specific arguments can be provided (see transformer_base.py) | 
					
						
						|  | positional_dropout_rate: float, optional | 
					
						
						|  | dropout rate after positional encoding. default 0.0 | 
					
						
						|  | nemo_conv_settings: dict, optional | 
					
						
						|  | A dictionary of settings for NeMo Subsampling. | 
					
						
						|  | default None | 
					
						
						|  | conv2d_extra_padding: str, optional | 
					
						
						|  | Add extra padding in conv2d subsampling layers. Choices are | 
					
						
						|  | (feat, feat_time, none, True). | 
					
						
|  | if True or feat_time, the extra padding is added to utterances in | 

|  | the batch that do not fill a full supraframe. | 
					
						
						|  | Default: none | 
					
						
						|  | attention_group_size: int, optional | 
					
						
						|  | the number of groups to use for attention, default 1 (Multi-Head Attention), | 
					
						
						|  | 1 = typical Multi-Head Attention, | 
					
						
						|  | 1 < attention_group_size < attention_heads = Grouped-Query Attention | 
					
						
|  | attention_group_size = attention_heads = Multi-Query Attention | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | input_size, | 
					
						
						|  | chunk_size, | 
					
						
						|  | left_chunk, | 
					
						
						|  | attention_dim=256, | 
					
						
						|  | attention_heads=4, | 
					
						
						|  | input_layer="nemo_conv", | 
					
						
						|  | cnn_out=-1, | 
					
						
						|  | cnn_layer_norm=False, | 
					
						
						|  | time_reduction=4, | 
					
						
						|  | dropout_rate=0.0, | 
					
						
						|  | padding_idx=-1, | 
					
						
						|  | relative_attention_bias_args=None, | 
					
						
						|  | positional_dropout_rate=0.0, | 
					
						
						|  | nemo_conv_settings=None, | 
					
						
						|  | conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", | 
					
						
						|  | attention_group_size=1, | 
					
						
						|  | encoder_embedding_config=None, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.input_size = input_size | 
					
						
						|  | self.input_layer = input_layer | 
					
						
						|  | self.chunk_size = chunk_size | 
					
						
						|  | self.left_chunk = left_chunk | 
					
						
						|  | self.attention_dim = attention_dim | 
					
						
						|  | self.num_heads = attention_heads | 
					
						
						|  | self.attention_group_size = attention_group_size | 
					
						
						|  | self.time_reduction = time_reduction | 
					
						
						|  | self.nemo_conv_settings = nemo_conv_settings | 
					
						
						|  | self.encoder_embedding_config = encoder_embedding_config | 
					
						
						|  |  | 
					
						
						|  | if self.input_layer == "nemo_conv": | 
					
						
						|  | default_nemo_conv_settings = { | 
					
						
						|  | "subsampling": "dw_striding", | 
					
						
						|  | "subsampling_factor": self.time_reduction, | 
					
						
						|  | "feat_in": input_size, | 
					
						
						|  | "feat_out": attention_dim, | 
					
						
						|  | "conv_channels": 256, | 
					
						
						|  | "subsampling_conv_chunking_factor": 1, | 
					
						
						|  | "activation": nn.ReLU(), | 
					
						
						|  | "is_causal": False, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | if nemo_conv_settings: | 
					
						
						|  | default_nemo_conv_settings.update(nemo_conv_settings) | 
					
						
						|  | for i in ["subsampling_factor", "feat_in", "feat_out"]: | 
					
						
						|  | assert ( | 
					
						
						|  | i not in nemo_conv_settings | 
					
						
						|  | ), "{i} should be specified outside of the NeMo dictionary" | 
					
						
						|  |  | 
					
						
						|  | self.embed = NemoConvSubsampling( | 
					
						
						|  | **default_nemo_conv_settings, | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError("unknown input_layer: " + input_layer) | 
					
						
						|  |  | 
					
						
						|  | self.pos_emb = AbsolutePositionalEncoding(attention_dim, positional_dropout_rate) | 
					
						
						|  |  | 
					
						
						|  | self.relative_attention_bias_type = ( | 
					
						
						|  | relative_attention_bias_args.get("type") if relative_attention_bias_args else None | 
					
						
						|  | ) | 
					
						
						|  | if self.relative_attention_bias_type == "t5": | 
					
						
						|  | assert ( | 
					
						
						|  | self.num_heads % self.attention_group_size == 0 | 
					
						
						|  | ), "attention_group_size must divide n_head" | 
					
						
						|  | self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( | 
					
						
						|  | self.num_heads // self.attention_group_size, | 
					
						
						|  | max_distance=relative_attention_bias_args.get("t5_bias_max_distance", 1000), | 
					
						
						|  | symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def post_init(self, init_model_config): | 
					
						
						|  |  | 
					
						
						|  | pretrained_speech_encoder_path = init_model_config.get('pretrained_speech_encoder_path', None) | 
					
						
						|  | if pretrained_speech_encoder_path: | 
					
						
						|  | model_state = torch.load(pretrained_speech_encoder_path, map_location="cpu") | 
					
						
						|  | encoder_state_dict = {} | 
					
						
						|  | for k, v in model_state.items(): | 
					
						
						|  | if "encoder." in k: | 
					
						
						|  | tmp_k = k.replace("encoder.", "") | 
					
						
						|  | encoder_state_dict[tmp_k] = v | 
					
						
						|  |  | 
					
						
						|  | if hasattr(self, "encoder_embedding"): | 
					
						
						|  | del self.encoder_embedding | 
					
						
						|  | self.load_state_dict(encoder_state_dict) | 
					
						
						|  |  | 
					
						
						|  | if not hasattr(self, "encoder_embedding"): | 
					
						
						|  | self.encoder_embedding = MeanVarianceNormLayer(self.encoder_embedding_config["input_size"]) | 
					
						
						|  |  | 
					
						
						|  | mean_file = init_model_config.get('mean_file', None) | 
					
						
						|  | invstd_file = init_model_config.get('invstd_file', None) | 
					
						
						|  | if mean_file is not None and invstd_file is not None: | 
					
						
						|  | self.encoder_embedding.load_mean_invstd(mean_file, invstd_file) | 
					
						
						|  |  | 
					
						
						|  | def compute_lens_change(self, feature_lens): | 
					
						
						|  | """feature_lens: int | 
					
						
						|  | return updated feature lens. | 
					
						
						|  |  | 
					
						
						|  | This used to return a different lambda function for each case that computed | 
					
						
						|  | the right thing.  That does not work within Torchscript.  If you really | 
					
						
						|  | need this to be faster, create nn.Module()-s for all the cases and return | 
					
						
						|  | one of them.  Torchscript does support that. | 
					
						
						|  | """ | 
					
						
						|  | if self.input_layer == "nemo_conv": | 
					
						
						|  |  | 
					
						
|  | subsampling_causal_cond = (self.nemo_conv_settings or {}).get("subsampling", "dw_striding") in [ | 
					
						
						|  | "dw_striding", | 
					
						
						|  | "striding", | 
					
						
						|  | "striding_conv1d", | 
					
						
						|  | ] | 
					
						
|  | is_causal = (self.nemo_conv_settings or {}).get("is_causal", False) | 
					
						
						|  | if is_causal and subsampling_causal_cond: | 
					
						
						|  | lens_change = ( | 
					
						
						|  | torch.ceil(feature_lens / self.time_reduction).long() | 
					
						
						|  | if isinstance(feature_lens, Tensor) | 
					
						
						|  | else math.ceil(feature_lens / self.time_reduction) | 
					
						
						|  | ) | 
					
						
						|  | feature_lens_remainder = feature_lens % self.time_reduction | 
					
						
						|  | if isinstance(feature_lens, Tensor): | 
					
						
						|  | lens_change[feature_lens_remainder != 1] += 1 | 
					
						
						|  | elif feature_lens_remainder != 1: | 
					
						
						|  | lens_change += 1 | 
					
						
						|  | return lens_change | 
					
						
						|  | ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil | 
					
						
						|  | return ceil_func(feature_lens / self.time_reduction) | 
					
						
						|  |  | 
					
						
						|  | @abc.abstractmethod | 
					
						
						|  | def forward(self): | 
					
						
						|  | """Abstract forward method implementation.""" | 
					
						
						|  |  | 
					
						
						|  | def _chunk_size_selection(self, chunk_size=None, left_chunk=None): | 
					
						
						|  | """If chunk size is a list, we will randomly select a chunk size.""" | 
					
						
						|  |  | 
					
						
						|  | if chunk_size is None: | 
					
						
						|  | chunk_size = self.chunk_size | 
					
						
						|  | if left_chunk is None: | 
					
						
						|  | left_chunk = self.left_chunk | 
					
						
						|  | if isinstance(chunk_size, list): | 
					
						
						|  |  | 
					
						
						|  | chunk_size_index = int(torch.randint(low=0, high=len(chunk_size), size=(1,))) | 
					
						
						|  | chunk_size_train_eff = chunk_size[chunk_size_index] | 
					
						
						|  | if not isinstance(left_chunk, list): | 
					
						
						|  | raise ValueError("Since chunk_size is a list, left_chunk must be a list") | 
					
						
						|  | if len(left_chunk) != len(chunk_size): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | "The length of left_chunk must be the same as length of chunk_size." | 
					
						
						|  | ) | 
					
						
						|  | left_chunk_train_eff = left_chunk[chunk_size_index] | 
					
						
						|  | else: | 
					
						
						|  | chunk_size_train_eff = chunk_size | 
					
						
						|  | left_chunk_train_eff = left_chunk | 
					
						
						|  |  | 
					
						
						|  | return chunk_size_train_eff, left_chunk_train_eff | 
					
						
						|  |  | 
					
						
						|  | def _get_embed_class(self, embed): | 
					
						
						|  |  | 
					
						
						|  | is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) | 
					
						
						|  | is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) | 
					
						
						|  | embed_class = embed | 
					
						
						|  | if is_embed_using_act_chkpt: | 
					
						
						|  | embed_class = embed._checkpoint_wrapped_module | 
					
						
						|  | if is_embed_fsdp_wrapped: | 
					
						
						|  | embed_class = embed.module | 
					
						
						|  | return embed_class | 
					
						
						|  |  | 
					
						
						|  | def _forward_embeddings_core(self, input_tensor, masks): | 
					
						
						|  | embed_class = self._get_embed_class(self.embed) | 
					
						
						|  | assert isinstance(embed_class, NemoConvSubsampling) | 
					
						
						|  | input_tensor, masks = self.embed(input_tensor, masks) | 
					
						
						|  | return input_tensor, masks | 
					
						
						|  |  | 
					
						
						|  | def _position_embedding(self, input_tensor): | 
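|  | # When a relative attention bias layer is configured (the only case wired up | 
|  | # in __init__ above), absolute positional encoding is skipped and pos_k/pos_v | 
|  | # stay None; the bias is added to the attention logits inside the layers instead. | 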
					
						
						|  | pos_k = None | 
					
						
						|  | pos_v = None | 
					
						
						|  | if self.relative_attention_bias_layer is None: | 
					
						
						|  | input_tensor = self.pos_emb(input_tensor) | 
					
						
						|  | return pos_k, pos_v | 
					
						
						|  |  | 
					
						
						|  | def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): | 
					
						
						|  | chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( | 
					
						
						|  | chunk_size, left_chunk | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) | 
					
						
						|  |  | 
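|  | # During training, with probability 0.5 the chunk boundaries are mirrored so | 
|  | # that the partial (remainder) chunk lands at the start of the sequence rather | 
|  | # than at the end, e.g. [0, 4, 8] over 10 frames becomes [0, 2, 6]. | 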
					
						
						|  | if self.training and np.random.rand() > 0.5: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | chunk_start_idx = seq_len - chunk_start_idx | 
					
						
						|  | chunk_start_idx = chunk_start_idx[::-1] | 
					
						
						|  | chunk_start_idx = chunk_start_idx[:-1] | 
					
						
						|  | chunk_start_idx = np.insert(chunk_start_idx, 0, 0) | 
					
						
						|  |  | 
					
						
						|  | enc_streaming_mask = ( | 
					
						
						|  | adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk_train_eff) | 
					
						
						|  | .unsqueeze(0) | 
					
						
						|  | .expand([batch_size, -1, -1]) | 
					
						
						|  | ) | 
					
						
						|  | return enc_streaming_mask | 
					
						
						|  |  | 
					
						
						|  | def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None): | 
					
						
						|  | """Forwarding the inputs through the top embedding layers | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | xs_pad: torch.Tensor | 
					
						
						|  | input tensor | 
					
						
						|  | masks: torch.Tensor | 
					
						
						|  | input mask | 
					
						
						|  | chunk_size_nc: (optional, default is None) chunk size for non-causal layers | 
					
						
						|  | left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | seq_len = int(self.compute_lens_change(xs_pad.shape[1])) | 
					
						
						|  | if seq_len <= 0: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"""The squence length after time reduction is invalid: {seq_len}. | 
					
						
						|  | Your input feature is too short. Consider filtering out the very | 
					
						
						|  | short sentence from data loader""", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | batch_size = xs_pad.shape[0] | 
					
						
						|  |  | 
					
						
						|  | enc_streaming_mask = self._streaming_mask( | 
					
						
						|  | seq_len, batch_size, self.chunk_size, self.left_chunk | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
|  | if xs_pad.device.type != "cpu": | 
					
						
						|  | enc_streaming_mask = enc_streaming_mask.to(xs_pad.device) | 
					
						
						|  |  | 
					
						
						|  | input_tensor = xs_pad | 
					
						
						|  | input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) | 
					
						
						|  |  | 
					
						
						|  | streaming_mask = enc_streaming_mask | 
					
						
						|  | if streaming_mask is not None and masks is not None: | 
					
						
						|  | hs_mask = masks & streaming_mask | 
					
						
						|  | elif masks is not None: | 
					
						
						|  | hs_mask = masks | 
					
						
						|  | else: | 
					
						
						|  | hs_mask = streaming_mask | 
					
						
						|  |  | 
					
						
						|  | if chunk_size_nc is not None: | 
					
						
						|  | enc_streaming_mask_nc = self._streaming_mask( | 
					
						
						|  | seq_len, batch_size, chunk_size_nc, left_chunk_nc | 
					
						
						|  | ) | 
					
						
|  | if xs_pad.device.type != "cpu": | 
					
						
						|  | enc_streaming_mask_nc = enc_streaming_mask_nc.to(xs_pad.device) | 
					
						
						|  | if masks is not None: | 
					
						
						|  | hs_mask_nc = masks & enc_streaming_mask_nc | 
					
						
						|  | else: | 
					
						
						|  | hs_mask_nc = enc_streaming_mask_nc | 
					
						
						|  | else: | 
					
						
						|  | hs_mask_nc = None | 
					
						
						|  |  | 
					
						
						|  | pos_k, pos_v = self._position_embedding(input_tensor) | 
					
						
						|  |  | 
					
						
						|  | if chunk_size_nc is None: | 
					
						
						|  | return input_tensor, pos_k, pos_v, hs_mask, masks | 
					
						
						|  | return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc | 
					
						
						|  |  | 
					
						
						|  | def get_offset(self): | 
					
						
						|  | """Returns offset used when retaining inputs for decoding. | 
					
						
						|  |  | 
					
						
|  | This is essentially how many additional frames have to be added to | 
					
						
						|  | the front-end CNN input to ensure it can produce a single output. | 
					
						
						|  | So if the "padding" parameter is 0, typically offset will be > 0. | 
					
						
						|  | """ | 
					
						
						|  | return get_offset(self.input_layer, self.time_reduction) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_offset(input_layer: str, time_reduction: int): | 
					
						
						|  | """Get an offset. We will use the offset for determining #frames of a subsampled feature. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | input_layer (str): Type of an input layer | 
					
						
						|  | time_reduction (int): time reduction factor for downsampling a feature | 
					
						
						|  | Returns: | 
					
						
						|  | int: offset | 
					
						
						|  | """ | 
					
						
						|  | if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: | 
					
						
						|  | return 3 | 
					
						
						|  | if input_layer in ("conv2d",) and time_reduction == 6: | 
					
						
						|  | return 1 | 
					
						
						|  | if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: | 
					
						
						|  | return 7 | 
					
						
						|  | return 0 | 
					
						
						|  |  | 
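|  |  | 
|  | # A minimal, hypothetical sketch (not part of the original module) showing how | 
|  | # get_offset pairs with the ceil-based length rule used in compute_lens_change. | 
|  | # The helper name and the example configurations below are illustrative | 
|  | # assumptions, and the function is never called by the library code. | 
|  | def _example_subsampled_length_sketch(): | 
|  |     for input_layer, time_reduction, feature_lens in [ | 
|  |         ("nemo_conv", 4, 103), | 
|  |         ("conv2d", 8, 200), | 
|  |     ]: | 
|  |         # Offset needed by the front-end CNN, plus the non-causal length rule. | 
|  |         offset = get_offset(input_layer, time_reduction) | 
|  |         approx_frames = math.ceil(feature_lens / time_reduction) | 
|  |         print(input_layer, time_reduction, offset, approx_frames) | 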
					
						
						|  |  | 
					
						
						|  | class ConformerEncoder(TransformerEncoderBase): | 
					
						
						|  | """ConformerEncoder module. | 
					
						
						|  | see original paper for more details: | 
					
						
						|  | https://arxiv.org/abs/2005.08100 | 
					
						
						|  |  | 
					
						
|  | Please set causal = True in a streaming model. | 
					
						
						|  | Args: | 
					
						
						|  | input_size: int | 
					
						
						|  | input feature dimension. | 
					
						
						|  | chunk_size: int, list(int) | 
					
						
						|  | Number of frames for each chunk | 
					
						
						|  | This variable can take 2 forms: | 
					
						
						|  | int:  Used for inference, or single chunk size training | 
					
						
						|  | list(int) : Used only for variable chunk size training | 
					
						
						|  | Some examples for the 2 cases: | 
					
						
						|  | chunk_size = 12 | 
					
						
						|  | chunk_size = [6, 8, 12, 24] | 
					
						
						|  | left_chunk: int, list(int) | 
					
						
						|  | Number of chunks used for masking in streaming mode. | 
					
						
						|  | This variable can take 2 forms: | 
					
						
						|  | int:  Used for inference, or single chunk size training | 
					
						
						|  | list(int) : Used only for variable chunk size training. When | 
					
						
						|  | chunk_size is a list, left_chunk must be a list with same length. | 
					
						
						|  | Some examples for the 2 cases: | 
					
						
						|  | left_chunk = 6 | 
					
						
						|  | left_chunk = [12, 9, 6, 3] | 
					
						
					
						
						|  | num_lang: int | 
					
						
						|  | This parameter is used to store the number of languages in the lang_dict, | 
					
						
						|  | only used for multiseed/multilingual models. default None. | 
					
						
						|  | attention_dim: int, optional | 
					
						
						|  | attention dimension. default 256. | 
					
						
						|  | attention_heads: int, optional | 
					
						
						|  | the number of heads. default 4 | 
					
						
						|  | linear_units: | 
					
						
						|  | the number of units of position-wise feed forward. | 
					
						
						|  | default 2048 | 
					
						
						|  | num_block: | 
					
						
						|  | number of Transformer layer. default 6 | 
					
						
						|  | dropout_rate: float, optional | 
					
						
						|  | dropout rate. default 0.1 | 
					
						
						|  | input_layer: str, optional | 
					
						
						|  | input layer type before Conformer, | 
					
						
						|  | one of ["linear", "conv2d", "custom", "vgg2l", "embed"], | 
					
						
						|  | default "conv2d" | 
					
						
						|  | causal: bool, optional | 
					
						
						|  | if set to True, convolution have no access | 
					
						
						|  | to future frames. default False. | 
					
						
						|  | batch_norm: bool, optional | 
					
						
						|  | if set to True, apply batchnorm before activation | 
					
						
						|  | in ConvModule layer of the conformer. | 
					
						
						|  | default False | 
					
						
						|  | cnn_out: int, optional | 
					
						
						|  | the number of CNN channels before Conformer. | 
					
						
						|  | default -1. | 
					
						
						|  | cnn_layer_norm: bool, optional | 
					
						
						|  | layer norm between Conformer and the first CNN. | 
					
						
						|  | default False. | 
					
						
						|  | ext_pw_out_channel: int, optional | 
					
						
						|  | the number of channel for CNN | 
					
						
						|  | before depthwise_seperable_CNN. | 
					
						
						|  | If 0 then use linear. default 0. | 
					
						
						|  | ext_pw_kernel_size: int, optional | 
					
						
|  | kernel size of the CNN before depthwise_seperable_CNN. | 

|  | only works when ext_pw_out_channel > 0. | 
					
						
						|  | default 1 | 
					
						
						|  | depthwise_seperable_out_channel: int, optional | 
					
						
						|  | the number of channel for | 
					
						
						|  | depthwise_seperable_CNN. | 
					
						
						|  | default 256. | 
					
						
						|  | depthwise_multiplier: int, optional | 
					
						
						|  | the number of multiplier for | 
					
						
						|  | depthwise_seperable_CNN. | 
					
						
						|  | default 1. | 
					
						
						|  | chunk_se: int, optional | 
					
						
						|  | 0 for offline SE. | 
					
						
						|  | 1 for streaming SE, where mean is computed | 
					
						
|  | by the accumulated history up to the current chunk. | 
					
						
						|  | 2 for streaming SE, where mean is computed | 
					
						
						|  | by only the current chunk. | 
					
						
						|  | default 0. | 
					
						
						|  | kernel_size: int, optional | 
					
						
|  | the kernel size of the depthwise_seperable_CNN. | 
					
						
						|  | default 3. | 
					
						
						|  | activation: str, optional | 
					
						
						|  | FeedForward block activation. | 
					
						
						|  | one of ["relu", "swish", "sigmoid"] | 
					
						
						|  | default "relu". | 
					
						
						|  | conv_activation: str, optional | 
					
						
						|  | activation function used in ConvModule part | 
					
						
						|  | of the conformer, default "relu". | 
					
						
|  | conv_glu_type: str, optional | 

|  | activation function used for the GLU in depthwise_seperable_CNN, | 
					
						
						|  | default "sigmoid" | 
					
						
						|  | bias_in_glu: bool, optional | 
					
						
						|  | if set to True, use additive bias in the weight module | 
					
						
						|  | before GLU. default True | 
					
						
						|  | linear_glu_in_convm: bool, optional | 
					
						
						|  | if set to True, use GLULinear module, | 
					
						
|  | otherwise, use GLUPointWiseConv module. | 
					
						
						|  | default to False. | 
					
						
						|  | attention_glu_type: str | 
					
						
|  | only works when glu_in_attention != 0 | 
					
						
						|  | default "swish". | 
					
						
						|  | export: bool, optional | 
					
						
|  | if set to True, it removes the padding from convolutional layers | 

|  | and allows the ONNX conversion for inference. | 
					
						
						|  | default False. | 
					
						
|  | activation_checkpointing: str or dict, optional | 

|  | a dictionary of {"module", "interval", "offload"}, where | 
					
						
						|  | "module": str | 
					
						
						|  | accept ["transformer", "attention"] to select | 
					
						
						|  | which module should do activation checkpointing. | 
					
						
						|  | "interval": int, default 1, | 
					
						
						|  | interval of applying activation checkpointing, | 
					
						
						|  | interval = 1 means that we apply checkpointing | 
					
						
|  | on every layer (if activated); otherwise, | 

|  | we apply it every `interval` layers. | 
					
						
						|  | "offload": bool, default False, | 
					
						
						|  | if set to True, we offload activation to cpu and | 
					
						
						|  | reload it during backward, otherwise, | 
					
						
						|  | we recalculate activation in backward. | 
					
						
						|  | default "". | 
					
						
						|  | extra_layer_output_idx: int | 
					
						
						|  | the layer index to be exposed. | 
					
						
						|  | relative_attention_bias_args: dict, optional | 
					
						
						|  | use more efficient scalar bias-based relative multihead attention (Q*K^T + B) | 
					
						
						|  | implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias | 
					
						
						|  | usage: relative_attention_bias_args={"type": t5/alibi} | 
					
						
						|  | additional method-specific arguments can be provided (see transformer_base.py) | 
					
						
						|  | time_reduction: int optional | 
					
						
						|  | time reduction factor | 
					
						
						|  | default 4 | 
					
						
						|  | use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention | 
					
						
						|  | in training. | 
					
						
						|  | Default: False | 
					
						
						|  | nemo_conv_settings: dict, optional | 
					
						
						|  | A dictionary of settings for NeMo Subsampling. | 
					
						
						|  | default: None | 
					
						
						|  | usage: nemo_conv_settings= | 
					
						
						|  | { | 
					
						
						|  | "subsampling": | 
					
						
						|  | dw_striding/striding/dw_striding_conv1d/striding_conv1d, | 
					
						
						|  | "conv_channels": int, | 
					
						
						|  | "subsampling_conv_chunking_factor": int, | 
					
						
						|  | "is_causal": True/False | 
					
						
						|  | } | 
					
						
						|  | conv2d_extra_padding: str, optional | 
					
						
						|  | Add extra padding in conv2d subsampling layers. Choices are | 
					
						
						|  | (feat, feat_time, none, True) | 
					
						
						|  | Default: none | 
					
						
						|  | replication_pad_for_subsample_embedding:  For batched-streaming decoding, use | 
					
						
						|  | "replication" padding for the cache at start of utterance. | 
					
						
						|  | Default: False | 
					
						
						|  | attention_group_size: int, optional | 
					
						
						|  | the number of groups to use for attention, default 1 (Multi-Head Attention), | 
					
						
						|  | 1 = typical Multi-Head Attention, | 
					
						
						|  | 1 < attention_group_size < attention_heads = Grouped-Query Attention | 
					
						
|  | attention_group_size = attention_heads = Multi-Query Attention | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | extra_multi_layer_output_idxs: List[int] | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | input_size, | 
					
						
						|  | chunk_size, | 
					
						
						|  | left_chunk, | 
					
						
						|  | num_lang=None, | 
					
						
						|  | attention_dim=256, | 
					
						
						|  | attention_heads=4, | 
					
						
						|  | linear_units=2048, | 
					
						
						|  | num_blocks=6, | 
					
						
						|  | dropout_rate=0.1, | 
					
						
						|  | input_layer="nemo_conv", | 
					
						
						|  | causal=True, | 
					
						
						|  | batch_norm=False, | 
					
						
						|  | cnn_out=-1, | 
					
						
						|  | cnn_layer_norm=False, | 
					
						
						|  | ext_pw_out_channel=0, | 
					
						
						|  | ext_pw_kernel_size=1, | 
					
						
						|  | depthwise_seperable_out_channel=256, | 
					
						
						|  | depthwise_multiplier=1, | 
					
						
						|  | chunk_se=0, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | activation="relu", | 
					
						
						|  | conv_activation="relu", | 
					
						
						|  | conv_glu_type="sigmoid", | 
					
						
						|  | bias_in_glu=True, | 
					
						
						|  | linear_glu_in_convm=False, | 
					
						
						|  | attention_glu_type="swish", | 
					
						
						|  | export=False, | 
					
						
						|  | extra_layer_output_idx=-1, | 
					
						
						|  | extra_multi_layer_output_idxs=[], | 
					
						
						|  | activation_checkpointing="", | 
					
						
						|  | relative_attention_bias_args=None, | 
					
						
						|  | time_reduction=4, | 
					
						
						|  | use_pt_scaled_dot_product_attention=False, | 
					
						
						|  | nemo_conv_settings=None, | 
					
						
						|  | conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", | 
					
						
						|  | replication_pad_for_subsample_embedding=False, | 
					
						
						|  | attention_group_size=1, | 
					
						
						|  | encoder_embedding_config=None, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__( | 
					
						
						|  | input_size, | 
					
						
						|  | chunk_size, | 
					
						
						|  | left_chunk, | 
					
						
						|  | attention_dim, | 
					
						
						|  | attention_heads, | 
					
						
						|  | input_layer, | 
					
						
						|  | cnn_out, | 
					
						
						|  | cnn_layer_norm, | 
					
						
						|  | time_reduction, | 
					
						
						|  | dropout_rate=dropout_rate, | 
					
						
						|  | relative_attention_bias_args=relative_attention_bias_args, | 
					
						
						|  | positional_dropout_rate=0.0, | 
					
						
						|  | nemo_conv_settings=nemo_conv_settings, | 
					
						
						|  | conv2d_extra_padding=conv2d_extra_padding, | 
					
						
						|  | attention_group_size=attention_group_size, | 
					
						
						|  | encoder_embedding_config=encoder_embedding_config, | 
					
						
						|  | ) | 
					
						
						|  | self.num_blocks = num_blocks | 
					
						
						|  | self.num_lang = num_lang | 
					
						
						|  | self.kernel_size = kernel_size | 
					
						
						|  | self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(self.embed) | 
					
						
						|  | self.replication_pad_for_subsample_embedding: bool = replication_pad_for_subsample_embedding | 
					
						
						|  | assert self.num_heads % attention_group_size == 0, "attention_group_size must divide n_head" | 
					
						
						|  | self.num_heads_k = self.num_heads // attention_group_size | 
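|  | # Grouped-query attention bookkeeping: num_heads_k is the number of key/value | 
|  | # heads. attention_group_size=1 keeps standard multi-head attention, while | 
|  | # attention_group_size=attention_heads collapses to a single shared key/value | 
|  | # head (multi-query attention), as documented in the class docstring above. | 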
					
						
						|  |  | 
					
						
						|  | self.encoders = repeat( | 
					
						
						|  | num_blocks, | 
					
						
						|  | lambda i: encoder_checkpoint_wrapper( | 
					
						
						|  | activation_checkpointing, ConformerEncoderLayer, i | 
					
						
						|  | )( | 
					
						
						|  | ConformerEncoderLayer( | 
					
						
						|  | d_model=attention_dim, | 
					
						
						|  | ext_pw_out_channel=ext_pw_out_channel, | 
					
						
						|  | depthwise_seperable_out_channel=depthwise_seperable_out_channel, | 
					
						
						|  | depthwise_multiplier=depthwise_multiplier, | 
					
						
						|  | n_head=attention_heads, | 
					
						
						|  | d_ffn=linear_units, | 
					
						
						|  | ext_pw_kernel_size=ext_pw_kernel_size, | 
					
						
						|  | kernel_size=kernel_size, | 
					
						
						|  | dropout_rate=dropout_rate, | 
					
						
						|  | causal=causal, | 
					
						
						|  | batch_norm=batch_norm, | 
					
						
						|  | activation=activation, | 
					
						
						|  | chunk_se=chunk_se, | 
					
						
						|  | chunk_size=chunk_size, | 
					
						
						|  | conv_activation=conv_activation, | 
					
						
						|  | conv_glu_type=conv_glu_type, | 
					
						
						|  | bias_in_glu=bias_in_glu, | 
					
						
						|  | linear_glu_in_convm=linear_glu_in_convm, | 
					
						
						|  | attention_glu_type=attention_glu_type, | 
					
						
						|  | activation_checkpointing=attn_checkpointing(activation_checkpointing, i), | 
					
						
						|  | export=export, | 
					
						
						|  | use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, | 
					
						
						|  | attn_group_sizes=attention_group_size, | 
					
						
						|  | ) | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  | self.extra_layer_output_idx = extra_layer_output_idx | 
					
						
						|  | self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.register_buffer("dev_type", torch.zeros(()), persistent=False) | 
					
						
						|  |  | 
					
						
						|  | def init_relative_attention_bias(self, input_tensor): | 
					
						
						|  | if self.relative_attention_bias_layer: | 
					
						
						|  | return self.relative_attention_bias_layer(input_tensor) | 
					
						
						|  |  | 
					
						
						|  | def calculate_hs_mask(self, xs_pad, device, mask): | 
					
						
						|  | max_audio_length = xs_pad.shape[1] | 
					
						
						|  | batch_size = xs_pad.shape[0] | 
					
						
						|  | enc_streaming_mask = self._streaming_mask( | 
					
						
						|  | max_audio_length, batch_size, self.chunk_size, self.left_chunk | 
					
						
						|  | ) | 
					
						
						|  | enc_streaming_mask = enc_streaming_mask.to(device) | 
					
						
						|  | if mask is None: | 
					
						
						|  | return enc_streaming_mask | 
					
						
						|  |  | 
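|  | # Derive per-utterance valid lengths from the binary padding mask, build a | 
|  | # (batch, 1, time) keep-mask over valid frames, and intersect it with the | 
|  | # chunk-based streaming mask so padded frames can never be attended to. | 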
					
						
						|  | feature_lens = mask.sum(1) | 
					
						
						|  | padding_length = feature_lens | 
					
						
						|  | pad_mask = ( | 
					
						
						|  | torch.arange(0, max_audio_length, device=device).expand(padding_length.size(0), -1) | 
					
						
						|  | < padding_length.unsqueeze(1) | 
					
						
						|  | ) | 
					
						
						|  | pad_mask = pad_mask.unsqueeze(1) | 
					
						
						|  | pad_mask = pad_mask & enc_streaming_mask | 
					
						
						|  | return pad_mask | 
					
						
						|  |  | 
					
						
						|  | @torch.jit.ignore | 
					
						
						|  | def forward(self, xs_pad, masks): | 
					
						
						|  | """Conformer Forward function | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | xs_pad: torch.Tensor | 
					
						
						|  | input tensor | 
					
						
						|  | masks: torch.Tensor | 
					
						
|  | input padding mask | 
					
						
						|  | """ | 
					
						
						|  | xs_pad = self.encoder_embedding(xs_pad) | 
					
						
						|  | input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(xs_pad, masks) | 
					
						
						|  |  | 
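|  | # For long utterances, attention cost grows quadratically with sequence length, | 
|  | # so sequences longer than max_seq_len frames are padded to a multiple of | 
|  | # max_seq_len, unfolded into fixed-size windows along the batch dimension, | 
|  | # processed, and finally folded back (and un-padded) further below. | 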
					
						
						|  | unfolded = False | 
					
						
						|  | ori_bz, seq_len, D = input_tensor.shape | 
					
						
						|  | max_seq_len = 500 | 
					
						
						|  | if seq_len > max_seq_len: | 
					
						
						|  |  | 
					
						
						|  | unfolded = True | 
					
						
						|  |  | 
					
						
						|  | if seq_len % max_seq_len > 0: | 
					
						
						|  | chunk_pad_size = max_seq_len - (seq_len % max_seq_len) | 
					
						
						|  | else: | 
					
						
						|  | chunk_pad_size = 0 | 
					
						
						|  | if chunk_pad_size > 0: | 
					
						
						|  | input_tensor_pad = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0) | 
					
						
						|  | input_tensor = input_tensor_pad.to(input_tensor.device) | 
					
						
						|  |  | 
					
						
						|  | input_tensor = unfold_tensor(input_tensor, max_seq_len) | 
					
						
						|  | if masks is not None: | 
					
						
						|  |  | 
					
						
						|  | subsampled_pad_mask = masks.squeeze(1) | 
					
						
|  | extra_padded_subsampled_pad_mask = F.pad(subsampled_pad_mask, (0, chunk_pad_size), "constant", False) | 

|  | extra_padded_subsampled_pad_mask = extra_padded_subsampled_pad_mask.unsqueeze(-1).float() | 

|  | masks_unfold = unfold_tensor(extra_padded_subsampled_pad_mask, max_seq_len) | 
					
						
						|  | masks_unfold = masks_unfold.squeeze(-1).bool() | 
					
						
						|  | else: | 
					
						
						|  | masks_unfold = None | 
					
						
						|  | hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold) | 
					
						
						|  | layer_emb = None | 
					
						
						|  |  | 
					
						
						|  | relative_attention_bias = self.init_relative_attention_bias(input_tensor) | 
					
						
						|  |  | 
					
						
						|  | _simplified_path = ( | 
					
						
						|  | self.extra_layer_output_idx == -1 | 
					
						
						|  | and relative_attention_bias is None | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if _simplified_path: | 
					
						
						|  | input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) | 
					
						
						|  | else: | 
					
						
						|  | for i, layer in enumerate(self.encoders): | 
					
						
						|  | input_tensor, _, _, _ = layer( | 
					
						
						|  | input_tensor, | 
					
						
						|  | pos_k, | 
					
						
						|  | pos_v, | 
					
						
						|  | hs_mask, | 
					
						
						|  | relative_attention_bias=relative_attention_bias, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if i == self.extra_layer_output_idx: | 
					
						
						|  | layer_emb = input_tensor | 
					
						
						|  | if unfolded: | 
					
						
						|  | embed_dim = input_tensor.shape[-1] | 
					
						
						|  | input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) | 
					
						
						|  |  | 
					
						
						|  | if chunk_pad_size > 0: | 
					
						
						|  | input_tensor = input_tensor[:, :-chunk_pad_size, :] | 
					
						
						|  | return input_tensor, masks | 
					
						
						|  |  | 
					
						
						|  | def gradient_checkpointing_enable(self): | 
					
						
						|  | pass | 
					
						
						|  |  |
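|  |  | 
|  |  | 
|  | # The sketch below is a hypothetical usage example, not part of the original | 
|  | # module: it builds a small ConformerEncoder with the T5 relative attention | 
|  | # bias (the only bias type wired up above) and runs post_init so the | 
|  | # MeanVarianceNormLayer front-end gets attached. All sizes are illustrative, | 
|  | # and it assumes the module-level dependencies (NemoConvSubsampling, | 
|  | # T5RelativeAttentionLogitBias, MeanVarianceNormLayer) resolve as imported above. | 
|  | def _example_build_conformer_encoder_sketch(): | 
|  |     encoder = ConformerEncoder( | 
|  |         input_size=80, | 
|  |         chunk_size=12, | 
|  |         left_chunk=6, | 
|  |         attention_dim=256, | 
|  |         attention_heads=4, | 
|  |         linear_units=1024, | 
|  |         num_blocks=2, | 
|  |         relative_attention_bias_args={"type": "t5"}, | 
|  |         encoder_embedding_config={"input_size": 80}, | 
|  |     ) | 
|  |     # No pretrained checkpoint or mean/invstd files given: post_init only | 
|  |     # creates the MeanVarianceNormLayer over the 80-dim input features. | 
|  |     encoder.post_init({}) | 
|  |     return encoder | 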