from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from cube3d.model.transformers.cache import Cache
from cube3d.model.transformers.dual_stream_attention import (
    DualStreamDecoderLayerWithRotaryEmbedding,
)
from cube3d.model.transformers.norm import LayerNorm
from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
from cube3d.model.transformers.rope import precompute_freqs_cis

class DualStreamRoformer(nn.Module):
    @dataclass
    class Config:
        checkpoint_path: str = ""
        n_layer: int = 12
        n_single_layer: int = 0
        rope_theta: float = 1000
        n_head: int = 16
        n_embd: int = 2048
        bias: bool = False  # bias in Linears and LayerNorms
        eps: float = 1e-6  # Norm eps
        shape_model_vocab_size: int = 4096
        shape_model_embed_dim: int = 16
        text_model_embed_dim: int = 512
        use_pooled_text_embed: bool = False
        encoder_with_cls_token: bool = True
    def __init__(self, cfg: Config) -> None:
        """
        Initializes the DualStreamRoformer model.

        Args:
            cfg (Config): Configuration object containing model parameters.

        Attributes:
            cfg (Config): Stores the configuration object.
            text_proj (nn.Linear): Linear layer to project text model embeddings to the desired embedding dimension.
            shape_proj (nn.Linear): Linear layer to project shape model embeddings to the desired embedding dimension.
            vocab_size (int): Vocabulary size for the shape model, including special tokens.
            shape_bos_id (int): Token ID for the beginning-of-sequence (BOS) token of the shape stream.
            shape_eos_id (int): Token ID for the end-of-sequence (EOS) token of the shape stream.
            padding_id (int): Token ID for the padding token.
            transformer (nn.ModuleDict): Dictionary containing the following components:
                - wte (nn.Embedding): Embedding layer for the vocabulary.
                - dual_blocks (nn.ModuleList): Dual-stream decoder layers with rotary embeddings.
                - single_blocks (nn.ModuleList): Single-stream decoder layers with rotary embeddings.
                - ln_f (LayerNorm): Layer normalization applied to the final output.
            lm_head (nn.Linear): Linear layer mapping the final embeddings to the vocabulary size for language modeling.
        """
        super().__init__()
        self.cfg = cfg

        self.text_proj = nn.Linear(
            in_features=self.cfg.text_model_embed_dim,
            out_features=self.cfg.n_embd,
            bias=self.cfg.bias,
        )
        self.shape_proj = nn.Linear(self.cfg.shape_model_embed_dim, self.cfg.n_embd)

        self.vocab_size = self.cfg.shape_model_vocab_size
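        # Special tokens for the shape stream (BOS, EOS, padding) are appended
        # after the shape codebook, so the effective vocabulary grows by three.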
        def add_special_token():
            token_id = self.vocab_size
            self.vocab_size += 1
            return token_id

        self.shape_bos_id = add_special_token()
        self.shape_eos_id = add_special_token()
        self.padding_id = add_special_token()

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    self.vocab_size,
                    self.cfg.n_embd,
                    padding_idx=self.padding_id,
                ),
                dual_blocks=nn.ModuleList(
                    [
                        DualStreamDecoderLayerWithRotaryEmbedding.from_config(
                            self.cfg, cond_pre_only=(i == self.cfg.n_layer - 1)
                        )
                        for i in range(self.cfg.n_layer)
                    ]
                ),
                single_blocks=nn.ModuleList(
                    [
                        DecoderLayerWithRotaryEmbedding.from_config(self.cfg)
                        for _ in range(self.cfg.n_single_layer)
                    ]
                ),
                ln_f=LayerNorm(
                    self.cfg.n_embd, elementwise_affine=False, eps=self.cfg.eps
                ),
            )
        )

        self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)
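    # A minimal construction sketch (hypothetical, small config values chosen
    # only for illustration; real checkpoints ship their own Config):
    #
    #   cfg = DualStreamRoformer.Config(n_layer=2, n_single_layer=1, n_embd=256, n_head=4)
    #   model = DualStreamRoformer(cfg)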
    def encode_text(self, text_embed):
        """
        Encodes the given text embeddings by projecting them through a linear transformation.

        Args:
            text_embed (torch.Tensor): A tensor representing the text embeddings to be encoded.

        Returns:
            torch.Tensor: The projected text embeddings after applying the linear transformation.
        """
        return self.text_proj(text_embed)

    def encode_token(self, tokens):
        """
        Encodes the input tokens using the word token embedding layer of the transformer.

        Args:
            tokens (torch.Tensor): A tensor containing the input token IDs to be encoded.

        Returns:
            torch.Tensor: A tensor containing the encoded token embeddings.
        """
        return self.transformer.wte(tokens)
    def init_kv_cache(
        self,
        batch_size: int,
        cond_len: int,
        max_shape_tokens: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> list[Cache]:
        """
        Initializes the key-value cache for the transformer model.

        This method creates a list of `Cache` objects to store the key and value
        states for both the dual-stream and single-stream transformer blocks. The
        cache is pre-allocated with zeros and is used to optimize the computation
        of attention during model inference.

        Args:
            batch_size (int): The batch size for the input data.
            cond_len (int): The length of the conditioning sequence.
            max_shape_tokens (int): The maximum number of tokens in the shape sequence.
            dtype (torch.dtype): The data type for the tensors (e.g., torch.float32).
            device (torch.device): The device on which the tensors will be allocated
                (e.g., torch.device("cuda") or torch.device("cpu")).

        Returns:
            list[Cache]: A list of `Cache` objects containing pre-allocated key and
                value states for each transformer block.
        """
        num_heads = self.cfg.n_head
        max_all_tokens = cond_len + max_shape_tokens
        per_head_dim = self.cfg.n_embd // num_heads
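        # Dual-stream blocks attend over the conditioning tokens plus the shape
        # tokens, so their caches cover cond_len + max_shape_tokens positions;
        # single-stream blocks only ever see the shape tokens.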
        kv_cache = [
            Cache(
                key_states=torch.zeros(
                    (batch_size, num_heads, max_all_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
                value_states=torch.zeros(
                    (batch_size, num_heads, max_all_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(len(self.transformer.dual_blocks))
        ]
        kv_cache += [
            Cache(
                key_states=torch.zeros(
                    (batch_size, num_heads, max_shape_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
                value_states=torch.zeros(
                    (batch_size, num_heads, max_shape_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(len(self.transformer.single_blocks))
        ]
        return kv_cache
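    # Usage sketch (hypothetical sizes; cond_len depends on the text encoder and
    # max_shape_tokens on the shape tokenizer, neither is fixed by this module):
    #
    #   kv_cache = model.init_kv_cache(
    #       batch_size=1,
    #       cond_len=77,
    #       max_shape_tokens=512,
    #       dtype=torch.bfloat16,
    #       device=torch.device("cuda"),
    #   )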
    def forward(
        self,
        embed: torch.Tensor,
        cond: torch.Tensor,
        kv_cache: Optional[list[Cache]] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Forward pass for the dual-stream RoFormer model.

        Args:
            embed (torch.Tensor): The input embedding tensor.
            cond (torch.Tensor): The conditioning tensor.
            kv_cache (Optional[list[Cache]]): A list of key-value caches, one per layer, used for decoding. Default is None.
            curr_pos_id (Optional[torch.Tensor]): The current position ID tensor of shape (batch_size,). Required if `decode` is True. Default is None.
            decode (bool): Whether the model is in decoding mode. Default is False.

        Returns:
            torch.Tensor: The output logits tensor.
        """
        b, l = embed.shape[:2]
        s = cond.shape[1]
        device = embed.device

        attn_mask = torch.tril(
            torch.ones(s + l, s + l, dtype=torch.bool, device=device)
        )

        position_ids = torch.arange(l, dtype=torch.long, device=device)  # shape (l,)
        position_ids = position_ids.unsqueeze_(0).expand(b, -1)
        s_freqs_cis = precompute_freqs_cis(
            dim=self.cfg.n_embd // self.cfg.n_head,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )
        position_ids = torch.cat(
            [
                torch.zeros([b, s], dtype=torch.long, device=position_ids.device),
                position_ids,
            ],
            dim=1,
        )
        d_freqs_cis = precompute_freqs_cis(
            dim=self.cfg.n_embd // self.cfg.n_head,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )
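        # Two RoPE tables are precomputed: s_freqs_cis covers only the shape
        # tokens (positions 0..l-1) and feeds the single-stream blocks, while
        # d_freqs_cis prepends position 0 for every conditioning token and
        # feeds the dual-stream blocks, whose sequence is cond + shape tokens.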
        if kv_cache is not None and decode:
            assert curr_pos_id is not None
            embed = embed[:, curr_pos_id, :]

        h = embed
        c = cond

        layer_idx = 0
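        # During cached decoding, the dual-stream caches are indexed with an
        # offset of s because they also hold the conditioning positions, while
        # the single-stream caches store shape-token positions only.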
        for block in self.transformer.dual_blocks:
            h, c = block(
                h,
                c=c,
                freqs_cis=d_freqs_cis,
                attn_mask=attn_mask,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
                decode=decode,
            )
            layer_idx += 1

        for block in self.transformer.single_blocks:
            h = block(
                h,
                freqs_cis=s_freqs_cis,
                attn_mask=None,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                curr_pos_id=curr_pos_id,
                decode=decode,
            )
            layer_idx += 1

        # Normalization
        h = self.transformer.ln_f(h)
        logits = self.lm_head(h)
        return logits
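# A minimal calling sketch (hypothetical tensors; in cube3d the generation
# engine is what actually prepares `text_embed` and the shape token ids):
#
#   cond = model.encode_text(text_embed)          # (B, S, n_embd)
#   embed = model.encode_token(shape_token_ids)   # (B, L, n_embd)
#   logits = model(embed, cond)                   # full-sequence (prefill) pass
#
# For cached decoding, pass a cache from `init_kv_cache` together with
# `curr_pos_id` (shape (batch_size,)) and `decode=True`.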