from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint as ckpt

from wenet.transformer.attention import T_CACHE
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
                                     WENET_ATTENTION_CLASSES,
                                     WENET_EMB_CLASSES, WENET_MLP_CLASSES,
                                     WENET_NORM_CLASSES)
from wenet.utils.common import mask_to_bias


class DecoderOnly(torch.nn.Module):
    """Decoder-only transformer backbone.

    A stack of ``num_blocks`` pre-norm TransformerEncoderLayer blocks using
    RoPE-based self-attention, followed by an optional final norm when
    ``normalize_before`` is set.
    """

    def __init__(
        self,
        n_kv_head: int,
        head_dim: int,
        hidden_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        normalize_before: bool = True,
        query_bias: bool = False,
        key_bias: bool = False,
        value_bias: bool = False,
        mlp_bias: bool = False,
        activation_type: str = "gelu",
        gelu_approximate: Union[str, None] = None,
        max_position_embeding: int = 8192,
        mlp_type: str = 'gated',
        layer_norm_type: str = 'rms_norm',
        norm_eps: float = 1e-5,
        rms_norm_offset: bool = True,
        selfattention_layer_type: str = "rope_abs_selfattn",
        use_sdpa: bool = False,
        gradient_checkpointing: bool = False,
        rope_theta: float = 10000.0,
        rope_style: str = 'google',
        scale_embed: bool = True,
    ) -> None:
        super().__init__()
        assert selfattention_layer_type in ['rope_abs_selfattn']
        # RoPE positional encoding shared by all self-attention layers
        self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
            hidden_size,
            head_dim,
            max_len=max_position_embeding,
            dropout_rate=positional_dropout_rate,
            rope_theta=rope_theta,
            scale=scale_embed)
        if activation_type == "gelu" and gelu_approximate is not None:
            activation = WENET_ACTIVATION_CLASSES['gelu'](
                approximate=gelu_approximate)
        else:
            activation = WENET_ACTIVATION_CLASSES[activation_type]()

        mlp_class = WENET_MLP_CLASSES[mlp_type]
        self.num_blocks = num_blocks
        # TODO: support lora & refactor lora
        self.decoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                hidden_size,
                WENET_ATTENTION_CLASSES[selfattention_layer_type](
                    attention_heads,
                    hidden_size,
                    attention_dropout_rate,
                    query_bias,
                    key_bias,
                    value_bias,
                    use_sdpa,
                    n_kv_head,
                    head_dim,
                    style=rope_style),
                mlp_class(hidden_size, linear_units, dropout_rate, activation,
                          mlp_bias),
                dropout_rate,
                normalize_before,
                layer_norm_type=layer_norm_type,
                norm_eps=norm_eps,
                rms_norm_offset=rms_norm_offset,
            ) for _ in range(self.num_blocks)
        ])
        self.pre_norm = normalize_before
        # In a pre-norm stack the last block's output is un-normalized,
        # so a final norm is applied after the loop.
        self.final_norm: Optional[torch.nn.Module] = None
        if self.pre_norm:
            norm_class = WENET_NORM_CLASSES[layer_norm_type]
            if layer_norm_type == "rms_norm":
                norm_class = partial(
                    norm_class,
                    add_unit_offset=rms_norm_offset,
                )
            self.final_norm = norm_class(hidden_size, eps=norm_eps)
        self.n_kv_head = n_kv_head
        self.head_dim = head_dim
        self._hidden_size = hidden_size
        self.use_sdpa = use_sdpa
        self.gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        input: torch.Tensor,
        att_mask: torch.Tensor,
        input_position: Union[int, torch.Tensor] = 0,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        """Run the stack of decoder blocks.

        Args:
            input: input features of shape (batch, time, hidden_size).
            att_mask: attention mask; converted to an additive bias when
                SDPA is enabled.
            input_position: RoPE offset of the first position in ``input``,
                used for incremental decoding.
            kv_caches: per-layer (key, value) caches, required in eval mode.
        """
        xs, pos_emb = self.pos_enc(input, offset=input_position)
        if self.use_sdpa:
            # scaled_dot_product_attention expects an additive bias,
            # not a boolean mask
            att_mask = mask_to_bias(att_mask, xs.dtype)

        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
        else:
            xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
                                                kv_caches)
        if self.pre_norm and self.final_norm is not None:
            xs = self.final_norm(xs)
        return xs, kv_caches
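
    # Incremental-decoding note (illustrative; not from the original file):
    # in eval mode the caller keeps one (key, value) cache per layer and
    # advances `input_position` by the number of tokens already processed,
    # roughly:
    #
    #     out, caches = model(prompt_emb, prompt_mask, 0, caches)       # prefill
    #     out, caches = model(tok_emb, step_mask, prompt_len, caches)   # one step
    #
    # `prompt_emb`, `tok_emb` and the mask variables are placeholders; the
    # per-layer cache layout is whatever T_CACHE defines upstream.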

    def forward_layers(
        self,
        xs: torch.Tensor,
        att_mask: torch.Tensor,
        pos_emb: torch.Tensor,
        kv_caches: Optional[List[T_CACHE]] = None,
    ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
        if self.training:
            # Training: no incremental decoding, so the caches are passed
            # through unchanged.
            for (i, layer) in enumerate(self.decoders):
                xs, _, _, _ = layer(xs, att_mask, pos_emb)
            new_kv_caches = kv_caches
        else:
            # Inference: feed each layer its (key, value) cache and collect
            # the updated caches for the next decoding step.
            assert kv_caches is not None
            new_kv_caches = []
            for (i, layer) in enumerate(self.decoders):
                xs, _, new_kv_cache, _ = layer(xs,
                                               att_mask,
                                               pos_emb,
                                               att_cache=(kv_caches[i][0],
                                                          kv_caches[i][1]))
                new_kv_caches.append(new_kv_cache)

        return xs, new_kv_caches

    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    att_mask: torch.Tensor,
                                    pos_emb: torch.Tensor) -> torch.Tensor:
        # Gradient checkpointing is only used during training, where no KV
        # cache is needed, so this path returns the hidden states only.
        for layer in self.decoders:
            xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask,
                                          pos_emb)
        return xs

    def hidden_size(self):
        return self._hidden_size
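

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module). It
# assumes the installed `wenet` version provides the registry entries used
# above ("rope_pos", "rope_abs_selfattn", the 'gated' MLP and 'rms_norm');
# all shapes and hyper-parameters below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DecoderOnly(
        n_kv_head=2,
        head_dim=64,
        hidden_size=256,
        attention_heads=4,
        linear_units=1024,
        num_blocks=2,
    )

    xs = torch.rand(1, 8, 256)  # already-embedded inputs: (batch, time, size)
    # Causal (lower-triangular) boolean mask: True = position may be attended.
    att_mask = torch.tril(torch.ones(1, 8, 8, dtype=torch.bool))

    # In training mode no KV caches are needed; they are passed through as-is.
    out, _ = model(xs, att_mask)
    print(out.shape)  # torch.Size([1, 8, 256])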