| """ NOTE(Mddct): This file is experimental and is used to export paraformer | |
| """ | |
| import math | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.utils.checkpoint as ckpt | |
| from wenet.paraformer.attention import (DummyMultiHeadSANM, | |
| MultiHeadAttentionCross, | |
| MultiHeadedAttentionSANM) | |
| from wenet.paraformer.embedding import ParaformerPositinoalEncoding | |
| from wenet.paraformer.subsampling import IdentitySubsampling | |
| from wenet.transformer.encoder import BaseEncoder | |
| from wenet.transformer.decoder import TransformerDecoder | |
| from wenet.transformer.decoder_layer import DecoderLayer | |
| from wenet.transformer.encoder_layer import TransformerEncoderLayer | |
| from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward | |
| from wenet.utils.mask import make_non_pad_mask | |


class LFR(torch.nn.Module):

    def __init__(self, m: int = 7, n: int = 6) -> None:
        """
        Actually, this implements stacking frames and skipping frames.
        If m = 1 and n = 1, it just returns the original features.
        If m = 1 and n > 1, it works like skipping.
        If m > 1 and n = 1, it works like stacking, but only supports right frames.
        If m > 1 and n > 1, it works like LFR.
        """
        super().__init__()
        self.m = m
        self.n = n

        self.left_padding_nums = math.ceil((self.m - 1) // 2)

    def forward(self, input: torch.Tensor,
                input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        orign_type = input_lens.dtype
        input_lens = input_lens.to(torch.int64)
        B, _, D = input.size()
        n_lfr = torch.ceil(input_lens / self.n).to(input_lens.dtype)
        # right_padding_nums >= 0
        prepad_nums = input_lens + self.left_padding_nums

        right_padding_nums = torch.where(
            self.m >= (prepad_nums - self.n * (n_lfr - 1)),
            self.m - (prepad_nums - self.n * (n_lfr - 1)),
            0,
        )
        T_all = self.left_padding_nums + input_lens + right_padding_nums

        new_len = T_all // self.n

        T_all_max = T_all.max().int()

        tail_frames_index = (input_lens - 1).view(B, 1, 1).repeat(
            1, 1, D)  # [B,1,D]

        tail_frames = torch.gather(input, 1, tail_frames_index)
        tail_frames = tail_frames.repeat(1, right_padding_nums.max().int(), 1)
        head_frames = input[:, 0:1, :].repeat(1, self.left_padding_nums, 1)

        # stack
        input = torch.cat([head_frames, input, tail_frames], dim=1)

        index = torch.arange(T_all_max,
                             device=input.device,
                             dtype=input_lens.dtype).unsqueeze(0).repeat(
                                 B, 1)  # [B, T_all_max]
        # [B, T_all_max]
        index_mask = index < (self.left_padding_nums + input_lens).unsqueeze(1)

        tail_index_mask = torch.logical_not(
            index >= (T_all.unsqueeze(1))) & index_mask
        tail = torch.ones(T_all_max,
                          dtype=input_lens.dtype,
                          device=input.device).unsqueeze(0).repeat(B, 1) * (
                              T_all_max - 1)  # [B, T_all_max]
        indices = torch.where(torch.logical_or(index_mask, tail_index_mask),
                              index, tail)
        input = torch.gather(input, 1, indices.unsqueeze(2).repeat(1, 1, D))

        input = input.unfold(1, self.m, step=self.n).transpose(2, 3)
        # new len
        new_len = new_len.to(orign_type)
        return input.reshape(B, -1, D * self.m), new_len
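
# NOTE(editor): a minimal usage sketch of LFR, not part of the original file.
# With the default m=7, n=6 the module stacks 7 neighbouring frames and hops
# by 6, so an 80-dim fbank sequence becomes a 560-dim sequence of roughly
# ceil(T / 6) frames. The feature dimension (80) is an assumption.
#
#   lfr = LFR(m=7, n=6)
#   feats = torch.randn(2, 100, 80)              # [B, T, D]
#   feats_lens = torch.tensor([100, 73])
#   stacked, stacked_lens = lfr(feats, feats_lens)
#   # stacked: [2, 17, 80 * 7], stacked_lens: tensor([17, 13])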


class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 idim,
                 hidden_units,
                 dropout_rate,
                 adim=None,
                 activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForwardDecoderSANM object."""
        super(PositionwiseFeedForwardDecoderSANM, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units,
                                   idim if adim is None else adim,
                                   bias=False)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation
        self.norm = torch.nn.LayerNorm(hidden_units)

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
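
# NOTE(editor): a minimal sketch of how this FFN differs from the standard
# PositionwiseFeedForward, not part of the original file: the output
# projection has no bias and can project to a different width `adim`.
# The sizes below are assumptions for illustration only.
#
#   ffn = PositionwiseFeedForwardDecoderSANM(idim=512, hidden_units=2048,
#                                            dropout_rate=0.1, adim=320)
#   y = ffn(torch.randn(2, 10, 512))             # -> [2, 10, 320]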


class AliParaformerEncoderLayer(TransformerEncoderLayer):

    def __init__(self,
                 size: int,
                 self_attn: torch.nn.Module,
                 feed_forward: torch.nn.Module,
                 dropout_rate: float,
                 normalize_before: bool = True,
                 in_size: int = 256):
        """ Resize input in_size to size
        """
        super().__init__(size, self_attn, feed_forward, dropout_rate,
                         normalize_before)
        self.in_size = in_size
        self.size = size
        del self.norm1
        self.norm1 = torch.nn.LayerNorm(in_size)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: Optional[torch.Tensor] = None,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        x_att, new_att_cache = self.self_attn(
            x,
            x,
            x,
            mask,
            cache=att_cache,
            mask_pad=mask_pad,
        )
        if self.in_size == self.size:
            x = residual + self.dropout(x_att)
        else:
            x = self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        return x, mask, new_att_cache, fake_cnn_cache
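
# NOTE(editor): a minimal sketch, not part of the original file, of the
# in_size -> size resize this layer performs: when in_size != size the
# residual around self-attention is dropped, so the attention output sets the
# new width. Dimensions below are assumptions; the attention argument order
# follows encoder_selfattn_layer_args0 in SanmEncoder below.
#
#   layer = AliParaformerEncoderLayer(
#       size=256,
#       self_attn=MultiHeadedAttentionSANM(4, 560, 256, 0.0, 11, 0),
#       feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
#       dropout_rate=0.1,
#       in_size=560)   # e.g. LFR-stacked features: 80 * 7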


class SanmEncoder(BaseEncoder):

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        gradient_checkpointing: bool = False,
    ):
        super().__init__(input_size,
                         output_size,
                         attention_heads,
                         linear_units,
                         num_blocks,
                         dropout_rate,
                         positional_dropout_rate,
                         attention_dropout_rate,
                         input_layer,
                         pos_enc_layer_type,
                         normalize_before,
                         static_chunk_size,
                         use_dynamic_chunk,
                         global_cmvn,
                         use_dynamic_left_chunk,
                         gradient_checkpointing=gradient_checkpointing)
        del self.embed
        self.embed = IdentitySubsampling(
            input_size,
            output_size,
            dropout_rate,
            ParaformerPositinoalEncoding(input_size,
                                         output_size,
                                         positional_dropout_rate,
                                         max_len=5000),
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = torch.nn.ModuleList([
            AliParaformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                PositionwiseFeedForward(output_size, linear_units,
                                        dropout_rate),
                dropout_rate,
                normalize_before,
                in_size=input_size,
            )
        ])
        self.encoders = torch.nn.ModuleList([
            AliParaformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                PositionwiseFeedForward(
                    output_size,
                    linear_units,
                    dropout_rate,
                ),
                dropout_rate,
                normalize_before,
                in_size=output_size) for _ in range(num_blocks - 1)
        ])
        if self.normalize_before:
            self.after_norm = torch.nn.LayerNorm(output_size)

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders0:
            xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        for layer in self.encoders:
            xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    chunk_masks: torch.Tensor,
                                    pos_emb: torch.Tensor,
                                    mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders0:
            xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        for layer in self.encoders:
            xs, _, _, _ = ckpt.checkpoint(layer.__call__,
                                          xs,
                                          chunk_masks,
                                          pos_emb,
                                          mask_pad,
                                          use_reentrant=False)
        return xs
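
# NOTE(editor): a minimal usage sketch of SanmEncoder, not part of the
# original file. The input size is assumed to be the LFR-stacked feature
# width (80 * 7 = 560); all other arguments keep the defaults above.
#
#   encoder = SanmEncoder(input_size=560, output_size=256)
#   xs = torch.randn(2, 17, 560)                 # LFR output, [B, T', D * m]
#   xs_lens = torch.tensor([17, 13])
#   encoder_out, encoder_mask = encoder(xs, xs_lens)
#   # encoder_out: [2, 17, 256], encoder_mask: [2, 1, 17]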


class _Decoders3(torch.nn.Module):
    """Paraformer has a decoder3"""

    def __init__(self, hidden: int, pos_clss: torch.nn.Module) -> None:
        super().__init__()
        self.feed_forward = pos_clss
        self.norm1 = torch.nn.LayerNorm(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.feed_forward(self.norm1(x))


class SanmDecoderLayer(DecoderLayer):

    def __init__(self,
                 size: int,
                 self_attn: Optional[torch.nn.Module],
                 src_attn: Optional[torch.nn.Module],
                 feed_forward: torch.nn.Module,
                 dropout_rate: float,
                 normalize_before: bool = True):
        super().__init__(size, self_attn, src_attn, feed_forward, dropout_rate,
                         normalize_before)
        # NOTE(Mddct): ali-Paraformer needs eps=1e-12
        self.norm1 = torch.nn.LayerNorm(size, eps=1e-12)
        self.norm2 = torch.nn.LayerNorm(size, eps=1e-12)
        self.norm3 = torch.nn.LayerNorm(size, eps=1e-12)

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = tgt_mask[:, -1:, :]

        x = tgt
        if self.self_attn is not None:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            tgt_q = tgt
            x = self.self_attn(tgt_q,
                               tgt,
                               tgt,
                               tgt_q_mask,
                               mask_pad=tgt_q_mask)[0]
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(
                self.src_attn(
                    x, memory, memory, memory_mask, mask_pad=memory_mask)[0])

        return x, tgt_mask, memory, memory_mask


class SanmDecoder(TransformerDecoder):

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0,
        src_attention_dropout_rate: float = 0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        src_attention: bool = True,
        att_layer_num: int = 16,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        gradient_checkpointing: bool = False,
    ):
        super().__init__(vocab_size,
                         encoder_output_size,
                         attention_heads,
                         linear_units,
                         num_blocks,
                         dropout_rate,
                         positional_dropout_rate,
                         self_attention_dropout_rate,
                         src_attention_dropout_rate,
                         input_layer,
                         use_output_layer,
                         normalize_before,
                         src_attention,
                         gradient_checkpointing=gradient_checkpointing)
        del self.embed, self.decoders
        self.decoders = torch.nn.ModuleList([
            SanmDecoderLayer(
                encoder_output_size,
                DummyMultiHeadSANM(attention_heads, encoder_output_size,
                                   encoder_output_size, dropout_rate,
                                   kernel_size, sanm_shfit),
                MultiHeadAttentionCross(attention_heads, encoder_output_size,
                                        encoder_output_size, dropout_rate,
                                        kernel_size, sanm_shfit,
                                        encoder_output_size),
                PositionwiseFeedForwardDecoderSANM(encoder_output_size,
                                                   linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
            ) for _ in range(att_layer_num)
        ])
        # NOTE(Mddct): att_layer_num == num_blocks in the released paraformer model
        assert att_layer_num == num_blocks

        # NOTE(Mddct): Paraformer has a decoder3
        self.decoders3 = torch.nn.ModuleList([
            _Decoders3(
                encoder_output_size,
                PositionwiseFeedForwardDecoderSANM(encoder_output_size,
                                                   linear_units, dropout_rate))
        ])

    def forward(
        self,
        encoder_out: torch.Tensor,
        encoder_out_mask: torch.Tensor,
        sematic_embeds: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor = torch.empty(0),
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        ys_pad_mask = make_non_pad_mask(ys_pad_lens).unsqueeze(1)
        x = sematic_embeds
        if self.gradient_checkpointing and self.training:
            x = self.forward_layers_checkpointed(x, ys_pad_mask, encoder_out,
                                                 encoder_out_mask)
        else:
            x = self.forward_layers(x, ys_pad_mask, encoder_out,
                                    encoder_out_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        return x, torch.tensor(0.0), ys_pad_lens

    def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
                       memory: torch.Tensor,
                       memory_mask: torch.Tensor) -> torch.Tensor:
        for layer in self.decoders:
            x, _, _, _ = layer(x, tgt_mask, memory, memory_mask)
        for layer in self.decoders3:
            x = layer(x)
        return x

    def forward_layers_checkpointed(self, x: torch.Tensor,
                                    tgt_mask: torch.Tensor,
                                    memory: torch.Tensor,
                                    memory_mask: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.decoders):
            if i == 0:
                x, _, _, _ = layer(x, tgt_mask, memory, memory_mask)
            else:
                x, _, _, _ = ckpt.checkpoint(layer.__call__,
                                             x,
                                             tgt_mask,
                                             memory,
                                             memory_mask,
                                             use_reentrant=False)
        for layer in self.decoders3:
            x = layer(x)
        return x
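
# NOTE(editor): a minimal usage sketch of SanmDecoder, not part of the
# original file. In Paraformer the decoder consumes the acoustic token
# embeddings produced by the predictor (`sematic_embeds`) rather than
# autoregressive token embeddings; all sizes below are assumptions.
#
#   decoder = SanmDecoder(vocab_size=5000, encoder_output_size=256)
#   encoder_out = torch.randn(2, 17, 256)
#   encoder_mask = torch.ones(2, 1, 17, dtype=torch.bool)
#   sematic_embeds = torch.randn(2, 8, 256)      # one embedding per token
#   ys_lens = torch.tensor([8, 6])
#   logits, _, _ = decoder(encoder_out, encoder_mask, sematic_embeds, ys_lens)
#   # logits: [2, 8, 5000]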