# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.

from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.functional import scaled_dot_product_attention

from transformers.utils import ModelOutput
from flash_attn import flash_attn_varlen_func

from modeling.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    Qwen2MLP,
    Qwen2PreTrainedModel,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
    apply_rotary_pos_emb,
)
from modeling.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config

torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.accumulated_cache_size_limit = 4096

# flex_attention = torch.compile(flex_attention)  # , dynamic=True, mode='max-autotune'
flex_attention = torch.compile(flex_attention)
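
# NOTE: flex_attention is compiled once at import time; the raised torch._dynamo cache limits keep the
# many graph variants produced by differently shaped packed sequences and block masks from evicting one
# another (intent inferred from the settings above).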


class Qwen2Config(_Qwen2Config):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`Qwen2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
            type and you expect the model to work on longer `max_position_embeddings`, we recommend you update this
            value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use
            full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import Qwen2Model, Qwen2Config

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        qk_norm=True,
        layer_module="Qwen2DecoderLayer",
        freeze_und=False,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            attention_dropout=attention_dropout,
            is_causal=is_causal,
            _attn_implementation=_attn_implementation,
            **kwargs,
        )
        self.qk_norm = qk_norm
        self.layer_module = layer_module
        self.freeze_und = freeze_und
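
# The three fields added on top of the upstream Qwen2 config are illustrated below (values are only an
# example, not a recommended setting):
#
#   config = Qwen2Config(
#       layer_module="Qwen2MoTDecoderLayer",  # which decoder-layer class to build (see Decoder_layer_dict)
#       qk_norm=True,                         # apply RMSNorm to query/key states before RoPE
#       freeze_und=False,                     # detach understanding-branch activations during training
#   )
#   model = Qwen2Model(config)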


class NaiveCache:
    def __init__(self, num_layers):
        self.key_cache = {k: None for k in range(num_layers)}
        self.value_cache = {k: None for k in range(num_layers)}

    @property
    def num_layers(self):
        return len(self.key_cache)

    @property
    def seq_lens(self):
        if self.key_cache[0] is not None:
            return self.key_cache[0].shape[0]
        else:
            return 0
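
# NaiveCache keeps per-layer packed key/value tensors of shape (total_kv_tokens, num_kv_heads, head_dim);
# there is no batch dimension, so sequence boundaries are tracked externally via key_values_lens and the
# packed_*_indexes tensors passed to the attention modules.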


@dataclass
class BaseNavitOutputWithPast(ModelOutput):
    packed_query_sequence: torch.FloatTensor = None
    past_key_values: Optional[NaiveCache] = None


def pad_sequence(tensor, pad_size):
    H, L, D = tensor.shape
    pad_tensor = tensor.new_zeros((H, pad_size, D))
    return torch.cat([tensor, pad_tensor], dim=1)
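
# pad_sequence appends `pad_size` zero tokens along the sequence dimension of an (H, L, D) tensor; it is
# used below to pad the packed q/k/v up to sum(sample_lens) so their length matches the flex_attention
# block mask built for the full packed batch.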


class PackedAttention(Qwen2Attention):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask: List[torch.Tensor],
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ):
        packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        if isinstance(attention_mask, List):
            packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                    unpacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(unpacked_attn_output, dim=1)
        else:
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
            packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states.unsqueeze(0),
                packed_key_states.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)

        return packed_attn_output
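
    # Training path: projections are applied to the whole packed sequence at once. When `attention_mask`
    # is a list of per-sample dense masks, each sample is attended separately with SDPA; otherwise the
    # mask is treated as a flex_attention block mask over the (zero-padded) packed batch.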

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ):
        packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=max(query_lens).item(),
            max_seqlen_k=max(key_values_lens).item(),
            causal=is_causal,
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)

        if update_past_key_values:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values
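
    # Inference path: new key/value states are scattered into a merged buffer alongside the cached ones
    # (packed_query_indexes / packed_key_value_indexes give their positions), then attention runs with
    # flash_attn_varlen_func over cumulative sequence lengths. For example, query_lens = [3, 5] yields
    # cu_seqlens_q = [0, 3, 8].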


class PackedAttentionMoT(Qwen2Attention):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
            self.q_norm_moe_gen = nn.Identity()
            self.k_norm_moe_gen = nn.Identity()

        self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
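
    # The *_moe_gen projections and norms mirror the inherited understanding-branch weights; tokens are
    # routed between the two parameter sets purely by the index tensors passed to forward_train /
    # forward_inference (there is no learned gate).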

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ):
        packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim))
        packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
        packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))

        packed_sequence_und = packed_sequence[packed_und_token_indexes]
        packed_sequence_gen = packed_sequence[packed_gen_token_indexes]

        packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
        packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)
        packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
        packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)
        packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
        packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)

        packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
        packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
        if self.config.freeze_und:
            packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()

        packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
        packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)

        packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
        packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes])

        packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
        packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes])

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
            packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1
        )

        if isinstance(attention_mask, List):
            packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                    unpacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(unpacked_attn_output, dim=1)
        else:
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
            packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states_.unsqueeze(0),  # 1, num_head, L, head_dim
                packed_key_states_.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
        packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
        packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
        packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes])

        return packed_attn_output_
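
    # Understanding ("und") tokens use the inherited q/k/v/o projections while generation ("gen") tokens
    # use the *_moe_gen copies; with config.freeze_und the und-branch activations are detached so gradients
    # only flow through the generation branch even though both share one attention call.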

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ):
        if mode == 'und':
            packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
            packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_query_states = self.q_norm(packed_query_states)
            packed_key_states = self.k_norm(packed_key_states)
        elif mode == 'gen':
            packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
            packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim))
            packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
            packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))

            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]

            packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
            packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)
            packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
            packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)
            packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
            packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)

            packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
            packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)

            packed_query_states = packed_query_states.to(torch.float32)
            packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
            packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes])

            packed_key_states = packed_key_states.to(torch.float32)
            packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
            packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes])

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=max(query_lens).item(),
            max_seqlen_k=max(key_values_lens).item(),
            causal=is_causal,
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        if mode == 'und':
            packed_attn_output = self.o_proj(packed_attn_output)
        elif mode == 'gen':
            packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
            packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes])

        if update_past_key_values:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values
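
    # At inference the routing key changes: in "und" mode every token goes through the understanding
    # projections, while in "gen" mode packed_text_indexes / packed_vae_token_indexes split the packed
    # query sequence between the text (understanding) and VAE-token (generation) parameter sets.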


class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = PackedAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        residual = packed_sequence
        packed_sequence = self.input_layernorm(packed_sequence)

        # Self Attention
        packed_sequence = self.self_attn(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
        )
        packed_sequence = residual + packed_sequence

        # Fully Connected
        residual = packed_sequence
        packed_sequence = self.post_attention_layernorm(packed_sequence)
        packed_sequence = self.mlp(packed_sequence)
        packed_sequence = residual + packed_sequence

        return packed_sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ) -> BaseNavitOutputWithPast:
        residual = packed_query_sequence
        packed_query_sequence = self.input_layernorm(packed_query_sequence)

        # Self Attention
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        packed_query_sequence = residual + packed_query_sequence

        # Fully Connected
        residual = packed_query_sequence
        packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
        packed_query_sequence = self.mlp(packed_query_sequence)
        packed_query_sequence = residual + packed_query_sequence

        return packed_query_sequence, past_key_values


class Qwen2MoTDecoderLayer(nn.Module):
    def __init__(
        self,
        config,
        layer_idx: Optional[int] = None,
        attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.freeze_und = config.freeze_und

        self.self_attn = attn_module(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
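
    # Each branch owns its own RMSNorms and MLP (the *_moe_gen copies serve generation tokens); token-wise
    # routing between branches happens here for the norms/MLPs and inside attn_module for the attention
    # projections.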

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        residual = packed_sequence
        packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes])
        packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])

        # Self Attention
        packed_sequence_ = self.self_attn(
            packed_sequence=packed_sequence_,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        if self.freeze_und:
            packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
        packed_sequence = residual + packed_sequence_

        # Fully Connected
        residual = packed_sequence
        packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_[packed_und_token_indexes] = self.mlp(
            self.post_attention_layernorm(packed_sequence[packed_und_token_indexes])
        )
        if self.freeze_und:
            packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
        packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen(
            self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
        )
        packed_sequence = residual + packed_sequence_

        return packed_sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        residual = packed_query_sequence
        if mode == "und":
            packed_query_sequence = self.input_layernorm(packed_query_sequence)
        elif mode == "gen":
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
            packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes])
            packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
            packed_query_sequence = packed_query_sequence_

        # Self Attention
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        packed_query_sequence = residual + packed_query_sequence

        # Fully Connected
        residual = packed_query_sequence
        if mode == "und":
            packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
            packed_query_sequence = self.mlp(packed_query_sequence)
        elif mode == "gen":
            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
            packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16)
            packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16)

            packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
            packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence)
            packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence)
            packed_query_sequence = packed_query_sequence_
        packed_query_sequence = residual + packed_query_sequence

        return packed_query_sequence, past_key_values


class Qwen2MoEDecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = PackedAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        residual = packed_sequence
        packed_sequence = self.input_layernorm(packed_sequence)

        # Self Attention
        packed_sequence = self.self_attn(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
        )
        packed_sequence = residual + packed_sequence

        # Fully Connected
        residual = packed_sequence
        packed_sequence = self.post_attention_layernorm(packed_sequence)
        packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes])
        packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes])
        packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
        packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen
        packed_sequence = residual + packed_sequence_new

        return packed_sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        residual = packed_query_sequence
        packed_query_sequence = self.input_layernorm(packed_query_sequence)

        # Self Attention
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        packed_query_sequence = residual + packed_query_sequence

        # Fully Connected
        residual = packed_query_sequence
        packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
        if mode == "und":
            packed_query_sequence = self.mlp(packed_query_sequence)
        elif mode == "gen":
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
            packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes])
            packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes])
            packed_query_sequence = packed_query_sequence_
        packed_query_sequence = residual + packed_query_sequence

        return packed_query_sequence, past_key_values


Decoder_layer_dict = {
    "Qwen2DecoderLayer": Qwen2DecoderLayer,
    "Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer,
    "Qwen2MoTDecoderLayer": partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT),
}
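
# config.layer_module selects one of the entries above when Qwen2Model builds its layer stack: the dense
# layer, the MoE variant (separate MLPs only), or the MoT variant (separate attention projections, norms,
# and MLPs).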


class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.use_moe = 'Mo' in config.layer_module

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        layer_module = Decoder_layer_dict[config.layer_module]
        self.layers = nn.ModuleList(
            [layer_module(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if self.use_moe:
            self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        # Initialize weights and apply final processing
        self.post_init()
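
    # use_moe is inferred from the layer-class name: the substring 'Mo' matches both "Qwen2MoEDecoderLayer"
    # and "Qwen2MoTDecoderLayer", so either variant gets the extra generation-branch final norm and the
    # index-based routing inputs below.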

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if self.config.freeze_und:
            packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach()

        # create position embeddings to be shared across the decoder layers
        cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
        cos = cos.squeeze(0)
        sin = sin.squeeze(0)
        packed_position_embeddings = (cos, sin)

        extra_inputs = {}
        if self.use_moe:
            assert packed_und_token_indexes is not None
            if packed_gen_token_indexes is None:
                packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
            extra_inputs.update(
                packed_und_token_indexes=packed_und_token_indexes,
                packed_gen_token_indexes=packed_gen_token_indexes,
            )

        for decoder_layer in self.layers:
            packed_sequence = decoder_layer(
                packed_sequence=packed_sequence,
                sample_lens=sample_lens,
                attention_mask=attention_mask,
                packed_position_embeddings=packed_position_embeddings,
                **extra_inputs
            )

        if self.use_moe:
            packed_sequence_ = torch.zeros_like(packed_sequence)
            packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes])
            if self.config.freeze_und:
                packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
            packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(packed_sequence[packed_gen_token_indexes])
            return packed_sequence_
        else:
            return self.norm(packed_sequence)

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        # create position embeddings to be shared across the decoder layers
        cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0))
        cos = cos.squeeze(0)
        sin = sin.squeeze(0)
        packed_query_position_embeddings = (cos, sin)

        extra_inputs = {}
        if self.use_moe:
            extra_inputs.update(mode=mode)
            if mode == 'gen':
                assert packed_vae_token_indexes is not None
                assert packed_text_indexes is not None
                extra_inputs.update(
                    packed_vae_token_indexes=packed_vae_token_indexes,
                    packed_text_indexes=packed_text_indexes,
                )

        for decoder_layer in self.layers:
            packed_query_sequence, past_key_values = decoder_layer(
                packed_query_sequence=packed_query_sequence,
                query_lens=query_lens,
                packed_query_position_embeddings=packed_query_position_embeddings,
                packed_query_indexes=packed_query_indexes,
                past_key_values=past_key_values,
                key_values_lens=key_values_lens,
                packed_key_value_indexes=packed_key_value_indexes,
                update_past_key_values=update_past_key_values,
                is_causal=is_causal,
                **extra_inputs,
            )

        if self.use_moe:
            if mode == "und":
                packed_query_sequence = self.norm(packed_query_sequence)
            elif mode == "gen":
                packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
                packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes])
                packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
                packed_query_sequence = packed_query_sequence_
        else:
            packed_query_sequence = self.norm(packed_query_sequence)

        return BaseNavitOutputWithPast(
            packed_query_sequence=packed_query_sequence,
            past_key_values=past_key_values,
        )
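
    # A typical incremental-decoding call (illustrative, not a fixed API): pass the new tokens' hidden
    # states as packed_query_sequence, their per-sample counts in query_lens, their target slots in
    # packed_query_indexes, and the NaiveCache plus key_values_lens / packed_key_value_indexes describing
    # what is already cached; the returned BaseNavitOutputWithPast carries the updated cache.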


class Qwen2ForCausalLM(Qwen2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def init_moe(self):
        for name, param in self.named_parameters():
            if "moe_gen" in name:
                original_name = name.replace("_moe_gen", "")
                param.data.copy_(self.state_dict()[original_name].data)
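
    # init_moe copies each understanding-branch weight into its freshly added *_moe_gen counterpart
    # (e.g. "q_proj_moe_gen.weight" is initialized from "q_proj.weight"), so the generation branch starts
    # training from the pretrained Qwen2 weights rather than from random init.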

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(self, *args, **kwargs):
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        outputs = self.model(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            packed_position_ids=packed_position_ids,
            attention_mask=attention_mask,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        return outputs

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        outputs = self.model(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        return outputs