Spaces:
Paused
Paused
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from cube3d.model.transformers.cache import Cache | |
| from cube3d.model.transformers.norm import LayerNorm, RMSNorm | |
| from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb | |
class SwiGLUMLP(nn.Module):
    """Feed-forward block with a SiLU-gated hidden layer (SwiGLU)."""

    def __init__(self, embed_dim, hidden_dim, bias=True, **kwargs):
        """
        Build the three linear projections used by the SwiGLU MLP.

        The block computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``: the input is
        projected twice into the hidden dimension, one projection is passed through
        SiLU and used to gate the other, and the product is projected back to the
        embedding dimension.

        Args:
            embed_dim (int): Size of the last dimension of the input and of the output.
            hidden_dim (int): Size of the gated hidden representation.
            bias (bool, optional): Whether all three linear layers carry a bias term.
                Defaults to True.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # Submodule names and registration order (gate, up, down) are preserved so that
        # state_dict keys and seeded parameter initialisation stay unchanged.
        self.gate_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
        self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)

    # Ignore copy
    def forward(self, x):
        """
        Run the gated MLP on ``x``.

        Args:
            x (torch.Tensor): Input whose last dimension equals ``embed_dim``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        gate = F.silu(self.gate_proj(x))
        hidden = gate * self.up_proj(x)
        return self.down_proj(hidden)
class SelfAttentionWithRotaryEmbedding(nn.Module):
    """Multi-head self-attention with per-head query/key normalisation and rotary embeddings."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        eps: float = 1e-6,
    ):
        """
        Create the projections and per-head norms for the attention module.

        Args:
            embed_dim (int): Size of the input/output embedding; must be divisible by
                ``num_heads``.
            num_heads (int): Number of attention heads.
            bias (bool, optional): Whether the value and output projections carry a bias.
                The fused query/key projection is always created without bias.
                Defaults to True.
            eps (float, optional): Accepted for interface compatibility. It is not used by
                this module as written: the query/key RMSNorm layers are constructed
                without it. Defaults to 1e-6.
        """
        super().__init__()
        head_dim, leftover = divmod(embed_dim, num_heads)
        assert leftover == 0
        self.num_heads = num_heads

        # Query and key share one fused projection; it is split in half in forward().
        self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Final projection applied to the concatenated head outputs.
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Queries and keys are normalised independently over the per-head dimension.
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)

    def forward(
        self,
        x,
        freqs_cis: torch.Tensor,
        attn_mask=None,
        is_causal: bool = False,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Apply self-attention to ``x``.

        Args:
            x (torch.Tensor): Input of shape (batch, sequence length, embed_dim).
            freqs_cis (torch.Tensor): Precomputed rotary positional embeddings, forwarded
                to the attention helper.
            attn_mask (Optional[torch.Tensor], optional): Attention mask forwarded to the
                attention helper. Defaults to None.
            is_causal (bool, optional): Causal-masking flag forwarded to the attention
                helper. Defaults to False.
            kv_cache (Optional[Cache], optional): When given, keys and values are written
                into ``kv_cache.key_states`` / ``kv_cache.value_states`` and attention is
                computed against the full cache tensors. Defaults to None.
            curr_pos_id (Optional[torch.Tensor], optional): Indices along the cache's
                sequence dimension to overwrite while decoding. Required when both
                ``kv_cache`` is given and ``decode`` is True. Defaults to None.
            decode (bool, optional): Selects the cache write mode (indexed write when
                True, prefix write when False). Defaults to False.

        Returns:
            torch.Tensor: Tensor of shape (batch, sequence length, embed_dim).
        """
        batch, seq_len, dim = x.shape

        # One matmul produces queries and keys; values come from their own projection.
        query, key = self.c_qk(x).chunk(2, dim=-1)
        value = self.c_v(x)

        # (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
        per_head = (batch, seq_len, self.num_heads, -1)
        query = query.view(per_head).transpose(1, 2)
        key = key.view(per_head).transpose(1, 2)
        value = value.view(per_head).transpose(1, 2)

        query = self.q_norm(query)
        key = self.k_norm(key)

        if kv_cache is not None:
            if decode:
                # Decoding: overwrite only the cache slots named by curr_pos_id.
                assert curr_pos_id is not None
                kv_cache.key_states.index_copy_(2, curr_pos_id, key)
                kv_cache.value_states.index_copy_(2, curr_pos_id, value)
            else:
                # Prefill: write the whole sequence into the leading cache positions.
                filled = key.shape[2]
                kv_cache.key_states[:, :, :filled, :].copy_(key)
                kv_cache.value_states[:, :, :filled, :].copy_(value)
            # Attend over the entire cache buffers rather than only the fresh entries.
            key = kv_cache.key_states
            value = kv_cache.value_states

        # The position ids are only handed to the helper while decoding.
        rope_pos_id = curr_pos_id if decode else None
        attended = scaled_dot_product_attention_with_rotary_emb(
            query,
            key,
            value,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=rope_pos_id,
            is_causal=is_causal,
        )

        # (B, nh, T, hs) -> (B, T, nh * hs): place the head outputs side by side again.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, dim)
        return self.c_proj(merged)
class DecoderLayerWithRotaryEmbedding(nn.Module):
    """Pre-norm decoder layer: rotary self-attention followed by a SwiGLU MLP."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        eps: float = 1e-6,
    ) -> None:
        """
        Initializes a decoder layer with rotary-embedding self-attention.

        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            bias (bool, optional): Forwarded to the attention and MLP sub-modules.
                Defaults to True.
            eps (float, optional): Epsilon passed to both LayerNorm modules and to the
                attention sub-module. Defaults to 1e-6.
        """
        super().__init__()
        # Both layer norms are constructed with elementwise_affine=False.
        self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.attn = SelfAttentionWithRotaryEmbedding(
            embed_dim, num_heads=num_heads, bias=bias, eps=eps
        )
        self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        # The MLP hidden width is fixed at four times the embedding width.
        self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias)

    # Alternate constructor: it must be a classmethod so that it can be invoked on the
    # class itself (``DecoderLayerWithRotaryEmbedding.from_config(cfg)``) with ``cls``
    # bound automatically.
    @classmethod
    def from_config(cls, cfg):
        """
        Create an instance of the class using the provided configuration.

        Args:
            cfg: A configuration object containing the following attributes:
                - n_embd (int): The size of the embedding dimension.
                - n_head (int): The number of attention heads.
                - bias (bool): Whether to include a bias term.
                - eps (float): A small value added for numerical stability.

        Returns:
            An instance of the class initialized with the specified configuration.
        """
        return cls(
            cfg.n_embd,
            num_heads=cfg.n_head,
            bias=cfg.bias,
            eps=cfg.eps,
        )

    def forward(
        self,
        x,
        freqs_cis: torch.Tensor,
        attn_mask=None,
        is_causal: bool = True,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Forward pass for the decoder layer.

        Computes ``x = x + attn(ln_1(x))`` followed by ``x = x + mlp(ln_2(x))``.

        Args:
            x (torch.Tensor): Input tensor.
            freqs_cis (torch.Tensor): Precomputed rotary positional embeddings.
            attn_mask (Optional[torch.Tensor], optional): Attention mask to apply during
                self-attention. Defaults to None.
            is_causal (bool, optional): Whether to apply causal masking for autoregressive
                decoding. Defaults to True.
            kv_cache (Optional[Cache], optional): Key-value cache forwarded to the
                attention sub-module. Defaults to None.
            curr_pos_id (Optional[torch.Tensor], optional): Current position IDs for
                decoding. Defaults to None.
            decode (bool, optional): Whether the model is in decoding mode.
                Defaults to False.

        Returns:
            torch.Tensor: Output tensor.
        """
        # Attention sub-layer: normalise first, then add the result back onto the input.
        out = self.attn(
            self.ln_1(x),
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = x + out
        # Feed-forward sub-layer with the same pre-norm residual pattern.
        x = x + self.mlp(self.ln_2(x))
        return x