chengs18 committed on
Commit
9b9a3a3
·
verified ·
1 Parent(s): 40c412d

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,3 +1,40 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: transformers
4
+ ---
5
+ # Introduction
6
+
7
+ **SDAR** (**S**ynergy of **D**iffusion and **A**uto**R**egression) is a new large language model that integrates autoregressive (AR) and discrete diffusion modeling strategies. It combines the efficient training paradigm of AR models with the highly parallel inference capability of diffusion models, while delivering performance fully on par with SOTA open-source AR models. At the same time, SDAR sets a new benchmark as the most powerful diffusion language model to date.
8
+
9
+ ---
10
+
11
+ # Performance of SDAR-1.7B-Chat on various benchmarks
12
+
13
+ Evaluation settings:
14
+ - MMLU: 5-shot
15
+ - Math500: 0-shot
16
+ - GSM8K: 0-shot
17
+ - HumanEval: 0-shot
18
+ - Sanitized_MBPP: 0-shot
19
+ - IFEval: 0-shot
20
+ - MathBench: 0-shot
21
+
22
+
23
+ | Model | MMLU | Math500 | GSM8K | HumanEval | Sanitized_MBPP | IFEval | MathBench |
24
+ |-------------------|------|---------|-------|-----------|----------------|--------|-----------|
25
+ | SDAR-1.7B-Chat | 62.9 | 63.2 | 80.06 | 61.59 | 61.09 | 43.44 | 63.55 |
26
+ | SDAR-4B-Chat | | | | | | | |
27
+ | SDAR-8B-Chat | | | | | | | |
28
+ | SDAR-30B-A3B-Chat | | | | | | | |
29
+
30
+
31
+
32
+ **Note**: The 4B, 8B, and 30B models are coming soon. Performance results for these models will be released in the near future.
33
+
34
+
35
+ ## Inference
36
+ The inference code will be released soon.
37
+
38
+ ## Highlights
39
+ - **Performance**: SDAR-1.7B-Chat achieves state-of-the-art performance.
40
+ - **Efficiency**: SDAR provides over 2× faster inference speed compared to the same-size AR models, while maintaining comparable performance.
added_tokens.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<MASK>": 151669,
6
+ "<think>": 151667,
7
+ "<tool_call>": 151657,
8
+ "<tool_response>": 151665,
9
+ "<|box_end|>": 151649,
10
+ "<|box_start|>": 151648,
11
+ "<|endoftext|>": 151643,
12
+ "<|file_sep|>": 151664,
13
+ "<|fim_middle|>": 151660,
14
+ "<|fim_pad|>": 151662,
15
+ "<|fim_prefix|>": 151659,
16
+ "<|fim_suffix|>": 151661,
17
+ "<|im_end|>": 151645,
18
+ "<|im_start|>": 151644,
19
+ "<|image_pad|>": 151655,
20
+ "<|object_ref_end|>": 151647,
21
+ "<|object_ref_start|>": 151646,
22
+ "<|quad_end|>": 151651,
23
+ "<|quad_start|>": 151650,
24
+ "<|repo_name|>": 151663,
25
+ "<|video_pad|>": 151656,
26
+ "<|vision_end|>": 151653,
27
+ "<|vision_pad|>": 151654,
28
+ "<|vision_start|>": 151652
29
+ }
chat_template.jinja ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
27
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
28
+ {%- elif message.role == "assistant" %}
29
+ {%- set content = message.content %}
30
+ {%- set reasoning_content = '' %}
31
+ {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
32
+ {%- set reasoning_content = message.reasoning_content %}
33
+ {%- else %}
34
+ {%- if '</think>' in message.content %}
35
+ {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
36
+ {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- endif %}
38
+ {%- endif %}
39
+ {%- if loop.index0 > ns.last_query_index %}
40
+ {%- if loop.last or (not loop.last and reasoning_content) %}
41
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
42
+ {%- else %}
43
+ {{- '<|im_start|>' + message.role + '\n' + content }}
44
+ {%- endif %}
45
+ {%- else %}
46
+ {{- '<|im_start|>' + message.role + '\n' + content }}
47
+ {%- endif %}
48
+ {%- if message.tool_calls %}
49
+ {%- for tool_call in message.tool_calls %}
50
+ {%- if (loop.first and content) or (not loop.first) %}
51
+ {{- '\n' }}
52
+ {%- endif %}
53
+ {%- if tool_call.function %}
54
+ {%- set tool_call = tool_call.function %}
55
+ {%- endif %}
56
+ {{- '<tool_call>\n{"name": "' }}
57
+ {{- tool_call.name }}
58
+ {{- '", "arguments": ' }}
59
+ {%- if tool_call.arguments is string %}
60
+ {{- tool_call.arguments }}
61
+ {%- else %}
62
+ {{- tool_call.arguments | tojson }}
63
+ {%- endif %}
64
+ {{- '}\n</tool_call>' }}
65
+ {%- endfor %}
66
+ {%- endif %}
67
+ {{- '<|im_end|>\n' }}
68
+ {%- elif message.role == "tool" %}
69
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
70
+ {{- '<|im_start|>user' }}
71
+ {%- endif %}
72
+ {{- '\n<tool_response>\n' }}
73
+ {{- message.content }}
74
+ {{- '\n</tool_response>' }}
75
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
76
+ {{- '<|im_end|>\n' }}
77
+ {%- endif %}
78
+ {%- endif %}
79
+ {%- endfor %}
80
+ {%- if add_generation_prompt %}
81
+ {{- '<|im_start|>assistant\n' }}
82
+ {%- if enable_thinking is defined and enable_thinking is false %}
83
+ {{- '<think>\n\n</think>\n\n' }}
84
+ {%- endif %}
85
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_qwen3.Qwen3Config",
7
+ "AutoModel": "modeling_qwen3.Qwen3Model",
8
+ "AutoModelForCausalLM": "modeling_qwen3.Qwen3ForCausalLM"
9
+ },
10
+ "attention_bias": false,
11
+ "attention_dropout": 0.0,
12
+ "bos_token_id": 151643,
13
+ "eos_token_id": 151643,
14
+ "fuse_cross_entropy": true,
15
+ "head_dim": 128,
16
+ "hidden_act": "silu",
17
+ "hidden_size": 2048,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 6144,
20
+ "max_position_embeddings": 32768,
21
+ "max_window_layers": 28,
22
+ "model_type": "qwen3",
23
+ "num_attention_heads": 16,
24
+ "num_hidden_layers": 28,
25
+ "num_key_value_heads": 8,
26
+ "rms_norm_eps": 1e-06,
27
+ "rope_scaling": null,
28
+ "rope_theta": 1000000,
29
+ "sliding_window": null,
30
+ "tie_word_embeddings": false,
31
+ "torch_dtype": "bfloat16",
32
+ "transformers_version": "4.52.3",
33
+ "use_cache": true,
34
+ "use_sliding_window": false,
35
+ "vocab_size": 151936
36
+ }
configuration_qwen3.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen3 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Qwen3Config(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
28
+ Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of
30
+ Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 151936):
38
+ Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Qwen3Model`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 22016):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer encoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer encoder.
48
+ num_key_value_heads (`int`, *optional*, defaults to 32):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
55
+ head_dim (`int`, *optional*, defaults to 128):
56
+ The attention head dimension.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
58
+ The non-linear activation function (function or string) in the decoder.
59
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
60
+ The maximum sequence length that this model might ever be used with.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
64
+ The epsilon used by the rms normalization layers.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
69
+ Whether the model's input and output word embeddings should be tied.
70
+ rope_theta (`float`, *optional*, defaults to 10000.0):
71
+ The base period of the RoPE embeddings.
72
+ rope_scaling (`Dict`, *optional*):
73
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
74
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
75
+ accordingly.
76
+ Expected contents:
77
+ `rope_type` (`str`):
78
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
79
+ 'llama3'], with 'default' being the original RoPE implementation.
80
+ `factor` (`float`, *optional*):
81
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
82
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
83
+ original maximum pre-trained length.
84
+ `original_max_position_embeddings` (`int`, *optional*):
85
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
86
+ pretraining.
87
+ `attention_factor` (`float`, *optional*):
88
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
89
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
90
+ `factor` field to infer the suggested value.
91
+ `beta_fast` (`float`, *optional*):
92
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
93
+ ramp function. If unspecified, it defaults to 32.
94
+ `beta_slow` (`float`, *optional*):
95
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
96
+ ramp function. If unspecified, it defaults to 1.
97
+ `short_factor` (`List[float]`, *optional*):
98
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
99
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
100
+ size divided by the number of attention heads divided by 2
101
+ `long_factor` (`List[float]`, *optional*):
102
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
103
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
104
+ size divided by the number of attention heads divided by 2
105
+ `low_freq_factor` (`float`, *optional*):
106
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
107
+ `high_freq_factor` (`float`, *optional*):
108
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
109
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
110
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
111
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
112
+ Whether to use sliding window attention.
113
+ sliding_window (`int`, *optional*, defaults to 4096):
114
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
115
+ max_window_layers (`int`, *optional*, defaults to 28):
116
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
117
+ attention_dropout (`float`, *optional*, defaults to 0.0):
118
+ The dropout ratio for the attention probabilities.
119
+
120
+ ```python
121
+ >>> from transformers import Qwen3Model, Qwen3Config
122
+
123
+ >>> # Initializing a Qwen3 style configuration
124
+ >>> configuration = Qwen3Config()
125
+
126
+ >>> # Initializing a model from the Qwen3-8B style configuration
127
+ >>> model = Qwen3Model(configuration)
128
+
129
+ >>> # Accessing the model configuration
130
+ >>> configuration = model.config
131
+ ```"""
132
+
133
+ model_type = "qwen3"
134
+ keys_to_ignore_at_inference = ["past_key_values"]
135
+
136
+ # Default tensor parallel plan for base model `Qwen3`
137
+ base_model_tp_plan = {
138
+ "layers.*.self_attn.q_proj": "colwise",
139
+ "layers.*.self_attn.k_proj": "colwise",
140
+ "layers.*.self_attn.v_proj": "colwise",
141
+ "layers.*.self_attn.o_proj": "rowwise",
142
+ "layers.*.mlp.gate_proj": "colwise",
143
+ "layers.*.mlp.up_proj": "colwise",
144
+ "layers.*.mlp.down_proj": "rowwise",
145
+ }
146
+ base_model_pp_plan = {
147
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
148
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
149
+ "norm": (["hidden_states"], ["hidden_states"]),
150
+ }
151
+
152
+ def __init__(
153
+ self,
154
+ vocab_size=151936,
155
+ hidden_size=4096,
156
+ intermediate_size=22016,
157
+ num_hidden_layers=32,
158
+ num_attention_heads=32,
159
+ num_key_value_heads=32,
160
+ head_dim=128,
161
+ hidden_act="silu",
162
+ max_position_embeddings=32768,
163
+ initializer_range=0.02,
164
+ rms_norm_eps=1e-6,
165
+ use_cache=True,
166
+ tie_word_embeddings=False,
167
+ rope_theta=10000.0,
168
+ rope_scaling=None,
169
+ attention_bias=False,
170
+ use_sliding_window=False,
171
+ sliding_window=4096,
172
+ max_window_layers=28,
173
+ attention_dropout=0.0,
174
+ **kwargs,
175
+ ):
176
+ self.vocab_size = vocab_size
177
+ self.max_position_embeddings = max_position_embeddings
178
+ self.hidden_size = hidden_size
179
+ self.intermediate_size = intermediate_size
180
+ self.num_hidden_layers = num_hidden_layers
181
+ self.num_attention_heads = num_attention_heads
182
+ self.use_sliding_window = use_sliding_window
183
+ self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code
184
+ self.max_window_layers = max_window_layers
185
+
186
+ # for backward compatibility
187
+ if num_key_value_heads is None:
188
+ num_key_value_heads = num_attention_heads
189
+
190
+ self.num_key_value_heads = num_key_value_heads
191
+ self.head_dim = head_dim
192
+ self.hidden_act = hidden_act
193
+ self.initializer_range = initializer_range
194
+ self.rms_norm_eps = rms_norm_eps
195
+ self.use_cache = use_cache
196
+ self.rope_theta = rope_theta
197
+ self.rope_scaling = rope_scaling
198
+ self.attention_bias = attention_bias
199
+ self.attention_dropout = attention_dropout
200
+ # Validate the correctness of rotary position embeddings parameters
201
+ # BC: if there is a 'type' field, move it to 'rope_type'.
202
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
203
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
204
+ rope_config_validation(self)
205
+
206
+ super().__init__(
207
+ tie_word_embeddings=tie_word_embeddings,
208
+ **kwargs,
209
+ )
210
+
211
+
212
+ __all__ = ["Qwen3Config"]
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.51.0"
13
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1737775176591d7c7f39b884b98d620d87646f8220b9b6b39431b6f6467e3e0f
3
+ size 4063515640
modeling_qwen3.py ADDED
@@ -0,0 +1,1208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
29
+ from transformers.generation import GenerationMixin
30
+ from transformers.integrations import use_kernel_forward_from_hub
31
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
32
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
33
+ from transformers.modeling_layers import GradientCheckpointingLayer
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ QuestionAnsweringModelOutput,
38
+ SequenceClassifierOutputWithPast,
39
+ TokenClassifierOutput,
40
+ )
41
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
+ from transformers.processing_utils import Unpack
44
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
45
+ from .configuration_qwen3 import Qwen3Config
46
+
47
+ from fla.modules.activations import swiglu_linear
48
+ from fla.modules import (
49
+ FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
50
+ FusedLinearUnreducedCrossEntropyLoss,
51
+ FusedLinearDiffusionCrossEntropyLoss)
52
+ from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
53
+ from torch.distributed.tensor import DTensor
54
+
55
+ import torch.nn.functional as F
56
+ try:
57
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
58
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
59
+ except:
60
+ pass
61
+
62
+
63
+ def dtensor2local(dtensor):
64
+ if isinstance(dtensor, DTensor):
65
+ return dtensor.to_local()
66
+ else:
67
+ return dtensor
68
+
69
+
70
+ if is_torch_flex_attn_available():
71
+ from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
72
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
73
+
74
+
75
+ logger = logging.get_logger(__name__)
76
+
77
+
78
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
79
+ def fused_flex_attention(query, key, value, attention_mask=None, **kwargs):
80
+ return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
81
+
82
+
83
+ @use_kernel_forward_from_hub("RMSNorm")
84
+ class Qwen3RMSNorm(nn.Module):
85
+ def __init__(self, hidden_size, eps=1e-6):
86
+ """
87
+ Qwen3RMSNorm is equivalent to T5LayerNorm
88
+ """
89
+ super().__init__()
90
+ self.weight = nn.Parameter(torch.ones(hidden_size))
91
+ self.variance_epsilon = eps
92
+
93
+ def forward(self, hidden_states):
94
+ weight = dtensor2local(self.weight)
95
+ '''
96
+ return flash_rms_norm(hidden_states, weight=weight, bias=None, eps=self.variance_epsilon)
97
+ '''
98
+ input_dtype = hidden_states.dtype
99
+ hidden_states = hidden_states.to(torch.float32)
100
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
101
+ hidden_states = hidden_states * \
102
+ torch.rsqrt(variance + self.variance_epsilon)
103
+ return weight * hidden_states.to(input_dtype)
104
+
105
+
106
+ def extra_repr(self):
107
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
108
+
109
+
110
+ class Qwen3MLP(nn.Module):
111
+ def __init__(self, config):
112
+ super().__init__()
113
+ self.config = config
114
+ self.hidden_size = config.hidden_size
115
+ self.intermediate_size = config.intermediate_size
116
+ self.gate_proj = nn.Linear(
117
+ self.hidden_size, self.intermediate_size, bias=False)
118
+ self.up_proj = nn.Linear(
119
+ self.hidden_size, self.intermediate_size, bias=False)
120
+ self.down_proj = nn.Linear(
121
+ self.intermediate_size, self.hidden_size, bias=False)
122
+ self.act_fn = ACT2FN[config.hidden_act]
123
+
124
+ def forward(self, x):
125
+ down_proj_weight = dtensor2local(self.down_proj.weight)
126
+ down_proj_bias = dtensor2local(self.down_proj.bias)
127
+ # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
128
+ down_proj = swiglu_linear(self.gate_proj(x), self.up_proj(x),
129
+ down_proj_weight, down_proj_bias)
130
+ return down_proj
131
+
132
+
133
+ def rotate_half(x):
134
+ """Rotates half the hidden dims of the input."""
135
+ x1 = x[..., : x.shape[-1] // 2]
136
+ x2 = x[..., x.shape[-1] // 2:]
137
+ return torch.cat((-x2, x1), dim=-1)
138
+
139
+
140
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
141
+ """Applies Rotary Position Embedding to the query and key tensors.
142
+
143
+ Args:
144
+ q (`torch.Tensor`): The query tensor.
145
+ k (`torch.Tensor`): The key tensor.
146
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
147
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
148
+ position_ids (`torch.Tensor`, *optional*):
149
+ Deprecated and unused.
150
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
151
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
152
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
153
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
154
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
155
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
156
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
157
+ Returns:
158
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
159
+ """
160
+ cos = cos.unsqueeze(unsqueeze_dim)
161
+ sin = sin.unsqueeze(unsqueeze_dim)
162
+ q_embed = (q * cos) + (rotate_half(q) * sin)
163
+ k_embed = (k * cos) + (rotate_half(k) * sin)
164
+ return q_embed, k_embed
165
+
166
+
167
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
168
+ """
169
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
170
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
171
+ """
172
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
173
+ if n_rep == 1:
174
+ return hidden_states
175
+ hidden_states = hidden_states[:, :, None, :, :].expand(
176
+ batch, num_key_value_heads, n_rep, slen, head_dim)
177
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
178
+
179
+
180
+ def eager_attention_forward(
181
+ module: nn.Module,
182
+ query: torch.Tensor,
183
+ key: torch.Tensor,
184
+ value: torch.Tensor,
185
+ attention_mask: Optional[torch.Tensor],
186
+ scaling: float,
187
+ dropout: float = 0.0,
188
+ **kwargs,
189
+ ):
190
+ key_states = repeat_kv(key, module.num_key_value_groups)
191
+ value_states = repeat_kv(value, module.num_key_value_groups)
192
+
193
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
194
+ if attention_mask is not None:
195
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
196
+ attn_weights = attn_weights + causal_mask
197
+
198
+ attn_weights = nn.functional.softmax(
199
+ attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
200
+ attn_weights = nn.functional.dropout(
201
+ attn_weights, p=dropout, training=module.training)
202
+ attn_output = torch.matmul(attn_weights, value_states)
203
+ attn_output = attn_output.transpose(1, 2).contiguous()
204
+
205
+ return attn_output, attn_weights
206
+
207
+
208
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Per-head dimension; falls back to hidden_size // heads when the
        # config does not define `head_dim` explicitly.
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # unlike olmo, only on the head dim!
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # thus post q_norm does not need reshape
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # Sliding-window attention only applies when enabled in the config and
        # only for layers at or above `max_window_layers`.
        self.sliding_window = config.sliding_window
        if not (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention for one layer.

        KV-cache semantics, controlled by the `store_kv` kwarg:
          * store_kv=True: append this step's K/V to `past_key_value`.
          * store_kv=False and the cache already holds an entry for this layer:
            read the cached K/V and concatenate them in front, WITHOUT writing
            back ("fetch only" — used by the diffusion-style decoding loop).

        Returns (attn_output, None); attention weights are never materialized
        on the fused paths used here.
        """
        input_shape = hidden_states.shape[:-1]
        bsz, q_len = input_shape
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, apply per-head RMSNorm to Q/K, and move heads to dim 1.
        query_states = self.q_norm(self.q_proj(
            hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(
            hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(
            hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin)

        if past_key_value is not None and kwargs.get("store_kv", False):
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx)
        elif past_key_value is not None and len(past_key_value) > self.layer_idx:
            # Fetch-only path (store_kv is False here because the branch above
            # did not trigger): reuse cached K/V without updating the cache.
            past_key_states, past_value_states = past_key_value[self.layer_idx]
            key_states = torch.cat([past_key_states, key_states], dim=-2)
            value_states = torch.cat([past_value_states, value_states], dim=-2)

        attention_mask = attention_mask.bool() if attention_mask is not None else None
        # Bug fix: the previous code called `torch.all(attention_mask)` even
        # when attention_mask was None (explicitly allowed one line above),
        # which raised a TypeError. A missing mask means full visibility, so
        # it takes the same unmasked flash-attention path as an all-True mask.
        if attention_mask is None or bool(torch.all(attention_mask)):
            # Fully-visible attention (decoding phase): flash attention with no
            # mask. flash_attn_func expects (batch, seq, heads, head_dim).
            attn_output = flash_attn_func(
                query_states.transpose(1, 2),
                key_states.transpose(1, 2),
                value_states.transpose(1, 2),
                causal=False,
                softmax_scale=self.scaling)
        else:
            attn_output = F.scaled_dot_product_attention(
                query=query_states,
                key=key_states,
                value=value_states,
                attn_mask=attention_mask,
                is_causal=False,
                scale=self.scaling,
                enable_gqa=True)
            attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None
379
+
380
+
381
class Qwen3DecoderLayer(GradientCheckpointingLayer):
    """One Qwen3 transformer block: pre-norm self-attention and a pre-norm
    MLP, each wrapped in a residual connection."""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)
        if (
            config.sliding_window and config._attn_implementation != "flash_attention_2"
        ):  # diff with Llama is this warning
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        store_kv: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        # necessary, but kept here for BC
        position_embeddings: Optional[Tuple[torch.Tensor,
                                            torch.Tensor]] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder layer.

        `store_kv` is forwarded to the attention module to control whether
        this step's K/V are written to the cache. Returns `(hidden_states,)`,
        extended with the attention weights when `output_attentions=True`.
        NOTE(review): Qwen3Attention currently always returns None for the
        weights, so that second element is None in practice.
        """
        # Pre-norm residual branch around self-attention.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            store_kv=store_kv,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected (pre-norm residual branch around the MLP).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
443
+
444
+
445
@auto_docstring
class Qwen3PreTrainedModel(PreTrainedModel):
    """Base class wiring Qwen3 modules into the Hugging Face
    `PreTrainedModel` machinery: weight initialization plus the capability
    flags consumed by the attention-backend and cache dispatch."""

    config_class = Qwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices during sharded loading.
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Supported attention backends.
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    # Supported cache implementations.
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize weights: Normal(0, initializer_range) for linear and
        embedding weights, zeros for biases and the padding embedding row,
        ones for RMSNorm scales."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Qwen3RMSNorm):
            module.weight.data.fill_(1.0)
472
+
473
+
474
class Qwen3RotaryEmbedding(nn.Module):
    """Holds the RoPE inverse frequencies and produces per-position cos/sin
    tables consumed by `apply_rotary_pos_emb`."""

    def __init__(self, config: Qwen3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Initialization function selected by rope type (default, dynamic, ...).
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device)
        # Buffer, not a parameter; persistent=False keeps it out of state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic rope variants can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    # power user: used with advanced RoPE types (e.g. dynamic rope)
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return `(cos, sin)` tables for the given positions, each of shape
        (batch, seq_len, 2 * len(inv_freq)), cast to `x`'s dtype. `x` is used
        only for its device and dtype."""
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
            position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # autocast is disabled on "mps" by computing on the "cpu" device type.
        device_type = x.device.type if isinstance(
            x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @
                     position_ids_expanded.float()).transpose(1, 2)
            # Duplicate along the last dim: [f, f] matches the rotate-half layout.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
512
+
513
+
514
@auto_docstring
class Qwen3Model(Qwen3PreTrainedModel):
    """The bare Qwen3 decoder stack: token embeddings, `num_hidden_layers`
    decoder layers, and a final RMSNorm. No LM head."""

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config, layer_idx)
             for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        store_kv: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        `store_kv` is forwarded to every layer: when falsy, cached K/V are
        read but this step's K/V are not written back (used by the
        diffusion-style decoding loop).

        NOTE(review): `attention_mask` is passed to the layers as-is; the
        `_update_causal_mask` preprocessing below is currently commented out,
        so callers are expected to supply a mask the attention module can
        consume directly.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError(
                "The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # cache_position indexes the current tokens within the full sequence
        # (past tokens included).
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length(
            ) if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # causal_mask = self._update_causal_mask(
        #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        # )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                store_kv=store_kv,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Convert `attention_mask` into the form expected by the configured
        attention backend (2D pass-through for FA2, BlockMask for flex
        attention, inverted 4D additive mask for SDPA/eager), or None when the
        backend can handle causality itself.

        NOTE(review): currently not called from `forward` (the call site is
        commented out); kept for when mask preprocessing is re-enabled.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                # FA2 cannot express right padding during batched generation.
                is_padding_right = attention_mask[:, -
                                                  1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                seq_len_q, seq_len_kv = attention_mask.shape
                assert seq_len_q == seq_len_kv, f"got {attention_mask.shape=}"
                attention_mask = create_block_mask(
                    # 2d bool tensor, shape: [2*seqlen, 2*seqlen]
                    lambda b, h, q_idx, kv_idx: attention_mask[q_idx, kv_idx],
                    B=None, H=None, Q_LEN=seq_len_q, KV_LEN=seq_len_kv,
                )
            else:
                # Here we pass in flex mask computed externally
                assert isinstance(attention_mask, BlockMask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length(
        ) if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(
            past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Qwen3Config,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
            config (`Qwen3Config`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            # Start fully masked (min_dtype everywhere), then carve out the
            # positions each query may attend to.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            # True where kv index is strictly in the future of the query.
            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                -1, 1
            )
            text_config = config.get_text_config()
            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                        cache_position.reshape(-1, 1) -
                        text_config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            # Zero out (i.e. allow) the attendable positions.
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None,
                                      :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                # Combine with the 2D padding mask: positions where both are 0
                # (allowed by causality but padded) get re-masked to min_dtype.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
810
+
811
+
812
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    # Typed-kwargs bundle: merges flash-attention kwargs with loss kwargs so
    # `Qwen3ForCausalLM.forward` can accept both through a single **kwargs.
    ...
814
+
815
+
816
@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
    """Qwen3 decoder stack with a language-modeling head. Training uses a
    fused linear + diffusion cross-entropy kernel; other loss paths are not
    implemented yet and raise."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM

        >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep,
                              None) if isinstance(logits_to_keep, int) else logits_to_keep
        hidden_states = hidden_states[:, slice_indices, :].contiguous()

        # The fused kernel combines the lm_head matmul with the loss; in that
        # case the full logits tensor is never materialized on HBM.
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
        if fuse_linear_and_cross_entropy:
            logits = None
        else:
            logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Bug fix: the previous version selected an (unused) loss function
            # first and kept an unreachable statement after the raise below.
            # Only the fused linear + cross-entropy path is implemented.
            if not fuse_linear_and_cross_entropy:
                raise RuntimeError("Do not support yet!")

            # Note: We use reduction='sum'
            # For 'mean' reduction, gradients are normalized by number of *non-ignored* elements
            # mean_loss = sum_loss / num_non_ignored_tokens, instead of all tokens (labels != -100)
            loss_fct = FusedLinearDiffusionCrossEntropyLoss(reduction='sum')

            p_mask = kwargs.get('p_mask', None)
            lm_head_weight = dtensor2local(self.lm_head.weight)
            lm_head_bias = dtensor2local(self.lm_head.bias)
            # loss: tuple of (sum_loss, unreduced_loss)
            loss = loss_fct(
                x=hidden_states,  # `view(-1, V)` inside the kernel
                target=labels,
                weight=lm_head_weight,
                bias=lm_head_bias,
                p_mask=p_mask,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
959
+
960
+
961
+ @auto_docstring(
962
+ custom_intro="""
963
+ The Qwen3 Model transformer with a sequence classification head on top (linear layer).
964
+
965
+ [`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
966
+ (e.g. GPT-2) do.
967
+
968
+ Since it does classification on the last token, it requires to know the position of the last token. If a
969
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
970
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
971
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
972
+ each row of the batch).
973
+ """
974
+ )
975
class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen3Model(config)
        # Unbiased linear head applied to every position; pooling happens in forward().
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = self.score(transformer_outputs.last_hidden_state)

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        pad_id = self.config.pad_token_id
        if pad_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined.")
        if pad_id is None:
            # Single sequence and no pad token: classify on the literal last position.
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != pad_id).to(
                logits.device, torch.int32)
            token_indices = torch.arange(
                input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            # Padding cannot be located inside raw embeddings; fall back to last position.
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        row_indices = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_indices, last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1064
+
1065
+
1066
@auto_docstring
class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen3Model(config)
        # Dropout rate: explicit classifier_dropout wins, then hidden_dropout, then 0.1.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = getattr(config, "hidden_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # Per-token logits: dropout then linear head over the last hidden state.
        logits = self.score(self.dropout(outputs.last_hidden_state))

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1134
+ )
1135
+
1136
+
1137
@auto_docstring
class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = Qwen3Model(config)
        # Two outputs per token: start-span and end-span logits.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        outputs: BaseModelOutputWithPast = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Split the 2-channel head output into start/end logits of shape (batch, seq_len).
        start_logits, end_logits = self.qa_outputs(outputs.last_hidden_state).split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(
                start_logits, end_logits, start_positions, end_positions, **kwargs)

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1199
+
1200
+
1201
# Public names exported by this module (consumed by `from ... import *` and the
# transformers lazy-import machinery).
__all__ = [
    "Qwen3ForCausalLM",
    "Qwen3ForQuestionAnswering",
    "Qwen3Model",
    "Qwen3PreTrainedModel",
    "Qwen3ForSequenceClassification",
    "Qwen3ForTokenClassification",
]
modeling_qwen3_origin.py ADDED
@@ -0,0 +1,1065 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ from einops import rearrange
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.integrations import use_kernel_forward_from_hub
32
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
33
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
34
+ from transformers.modeling_layers import GradientCheckpointingLayer
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutputWithPast,
37
+ CausalLMOutputWithPast,
38
+ QuestionAnsweringModelOutput,
39
+ SequenceClassifierOutputWithPast,
40
+ TokenClassifierOutput,
41
+ )
42
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
43
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
44
+ from transformers.processing_utils import Unpack
45
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
46
+ from .configuration_qwen3 import Qwen3Config
47
+
48
+ from torch.nn import CrossEntropyLoss
49
+ from fla.modules.activations import swiglu_linear
50
+ from fla.modules import FusedLinearDiffusionCrossEntropyLoss
51
+ from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
52
+
53
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask, flex_attention

    from transformers.integrations.flex_attention import make_flex_block_causal_mask

    # flex attn needs torch.compile to accelerate
    @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
    def fused_flex_attention(query, key, value, attention_mask, **kwargs):
        # Thin compiled wrapper around `flex_attention`: `attention_mask` is a
        # flex-attention `BlockMask` forwarded as `block_mask`; any extra kwargs
        # (e.g. `scale`, `enable_gqa`, `return_lse` — see the call site in
        # Qwen3Attention.forward) pass straight through.
        return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
62
+
63
+ logger = logging.get_logger(__name__)
64
+
65
+
66
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # Learnable per-feature scale, initialized to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Eager reference implementation, replaced here by the fused Triton kernel:
        # hidden_states = hidden_states.to(torch.float32)
        # variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # NOTE(review): `flash_rms_norm` is flash-attn's Triton `rms_norm_fn`;
        # presumed numerically equivalent to the commented-out reference — confirm.
        return flash_rms_norm(
            x=hidden_states, weight=self.weight, bias=None, eps=self.variance_epsilon).to(input_dtype)

    def extra_repr(self):
        # e.g. "(4096,), eps=1e-06" in the module's printed representation.
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
86
+
87
+
88
class Qwen3MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block using a fused gate/up/down kernel."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Kept for checkpoint/interface compatibility; the fused forward path
        # below does not call it. NOTE(review): swiglu_linear presumably bakes in
        # SiLU regardless of config.hidden_act — confirm against fla docs.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Fused equivalent of: self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return swiglu_linear(gate, up, self.down_proj.weight, self.down_proj.bias)
104
+
105
+
106
def rotate_half(x):
    """Rotates half the hidden dims of the input: (a, b) -> (-b, a) along the last axis."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
111
+
112
+
113
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so they broadcast against
            `q` and `k`. Use 1 for [batch, heads, seq_len, head_dim] layouts and 2 for
            [batch, seq_len, heads, head_dim] layouts.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
138
+
139
+
140
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        return hidden_states
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    # Insert a repeat axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)
151
+
152
+
153
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention with GQA key/value repetition.

    Returns `(attn_output, attn_weights)` where `attn_output` is laid out as
    [batch, seq_len, heads, head_dim] and `attn_weights` are the post-dropout
    softmax probabilities.
    """
    # Broadcast KV heads up to the number of query heads (grouped-query attention).
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Mask is additive and may be longer than the key sequence; trim to fit.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for numerical stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values)
    context = context.transpose(1, 2).contiguous()

    return context, probs
177
+
178
+
179
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        # Non-causal attention — presumably for diffusion-style (bidirectional)
        # decoding of this SDAR variant; the stock Qwen3 is causal. TODO confirm.
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        # Sliding window only applies from max_window_layers onward when enabled.
        self.sliding_window = config.sliding_window
        if not (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        bsz, q_len = input_shape
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, reshape to (batch, heads, seq, head_dim); q/k get per-head RMSNorm
        # on head_dim before RoPE.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend; SDPA cannot return weights,
        # so it falls back to eager when output_attentions is requested.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if self.config._attn_implementation == 'flex_attention':
            # Although there exists `flex_attention_forward` in `AttentionInterface`,
            # we still use our customized `fused_flex_attention` for debugging.
            pad_length = kwargs.get("pad_length", 0)
            # Used for SFT (packing + varlen), seq_len changes at each step
            # seq_len must be divisible by BLOCK_SIZE in flex attn
            pad_q = torch.zeros(
                bsz, self.num_attention_heads, pad_length, self.head_dim, device=query_states.device, dtype=query_states.dtype)
            pad_kv = torch.zeros(
                bsz, self.num_key_value_heads, pad_length, self.head_dim, device=query_states.device, dtype=query_states.dtype)
            attn_output, attn_weights = fused_flex_attention(
                query=torch.cat([query_states, pad_q], dim=2),
                key=torch.cat([key_states, pad_kv], dim=2),
                value=torch.cat([value_states, pad_kv], dim=2),
                attention_mask=attention_mask,
                enable_gqa=True,
                scale=self.scaling,
                return_lse=True
            )

            # Drop the padded tail; `attn_weights` here is the log-sum-exp from
            # `return_lse=True`, cast back to the value dtype.
            attn_output = attn_output[..., :q_len, :].contiguous()
            attn_weights = attn_weights.to(value_states.dtype)
            attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                sliding_window=self.sliding_window,  # diff with Llama
                **kwargs,
            )  # output: [b, l, h, d]
            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
289
+
290
+
291
class Qwen3DecoderLayer(GradientCheckpointingLayer):
    """One transformer block: pre-norm self-attention and pre-norm MLP, each with a residual add."""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # diff with Llama is this warning
        if (
            config.sliding_window and config._attn_implementation != "flash_attention_2"
        ):
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Attention sub-block: norm -> attention -> residual add.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block: norm -> MLP -> residual add.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
347
+
348
+
349
@auto_docstring
class Qwen3PreTrainedModel(PreTrainedModel):
    # Config class and checkpoint key prefix used by from_pretrained/save_pretrained.
    config_class = Qwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # A decoder layer is never sharded across devices.
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Supported attention backends and cache implementations.
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        # Normal(0, initializer_range) for linear/embedding weights; zero biases
        # and the embedding's padding row; RMSNorm scales start at 1.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Qwen3RMSNorm):
            module.weight.data.fill_(1.0)
376
+
377
+
378
class Qwen3RotaryEmbedding(nn.Module):
    """Precomputes RoPE inverse frequencies and emits per-position cos/sin tables."""

    def __init__(self, config: Qwen3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # rope_init_fn returns (inv_freq, attention_scaling) for the chosen rope type.
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: inv_freq is recomputed, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # x is only used for its device/dtype; angles come from position_ids.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate along the last dim so cos/sin cover the full head_dim.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
410
+
411
+
412
+ @auto_docstring
413
+ class Qwen3Model(Qwen3PreTrainedModel):
414
    def __init__(self, config: Qwen3Config):
        """Builds the decoder stack: token embeddings, decoder layers, final norm, shared RoPE."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # A single rotary-embedding module shared by all layers.
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
429
+
430
    def get_input_embeddings(self):
        # Standard transformers accessor for the token embedding table.
        return self.embed_tokens
432
+
433
    def set_input_embeddings(self, value):
        # Standard transformers mutator for the token embedding table.
        self.embed_tokens = value
435
+
436
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        # Resolve per-call flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # cache_position indexes the new tokens relative to what is already cached.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Build the backend-specific attention mask (additive mask, BlockMask, or None).
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
531
+
532
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Select or build the attention mask handed to the decoder layers.

        Depending on `config._attn_implementation` this returns:
        - flash_attention_2: the original 2D mask (or None when nothing is padded),
          after rejecting right-padded batched generation which FA2 cannot handle;
        - flex_attention: the provided BlockMask unchanged;
        - sdpa: None when SDPA's own `is_causal` handling suffices, otherwise a 4D
          float mask;
        - eager/fallback: a 4D float mask with `min_dtype` at disallowed positions.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                # FA2 with a cache requires left padding: the last column of the 2D
                # mask must be all ones, i.e. its sum equals the batch size.
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                # At least one padded position: FA2 needs the 2D mask to unpad.
                return attention_mask
            # Fully-attended input: FA2 applies causal masking internally.
            return None
        if self.config._attn_implementation == "flex_attention":
            # Use flex block mask directly
            assert isinstance(attention_mask, BlockMask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                # SDPA can run with `is_causal=True` and no explicit mask.
                return None

        dtype = input_tensor.dtype
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache: the mask must span the full
        # (possibly pre-allocated) cache length.
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
617
+
618
    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Qwen3Config,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
            config (`Qwen3Config`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            # Start fully masked (min_dtype everywhere); allowed positions are
            # zeroed out below.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            # True where the key index is strictly in the future of the query's
            # cache position (those entries must stay masked).
            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
                -1, 1
            )
            text_config = config.get_text_config()
            if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    # True where the key falls out of the sliding window behind
                    # the query; OR'd into the disallowed set in place.
                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
                        cache_position.reshape(-1, 1) - text_config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            # Keep min_dtype where disallowed (True), zero where attendable (False).
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                # Sum is 0 exactly where the causal entry allows attention (0) but
                # the 2D mask marks padding (0): those positions get re-masked.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
686
+
687
+
688
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
689
+
690
+
691
@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
    """Qwen3 decoder with a language-modeling head, extended with an optional
    fused linear+diffusion cross-entropy loss (SDAR training path)."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        # Projection from hidden states to vocabulary logits (tied with the
        # input embeddings via `_tied_weights_keys`).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM

        >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits_to_keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        hidden_states = hidden_states[:, logits_to_keep, :].contiguous()
        # `fuse_cross_entropy` is an optional (SDAR-specific) config flag; default
        # to False so configs without the attribute do not raise AttributeError.
        fuse_linear_and_cross_entropy = getattr(self.config, "fuse_cross_entropy", False) and self.training
        if fuse_linear_and_cross_entropy:
            # The fused loss consumes hidden states and lm_head weights directly,
            # so materializing logits would waste memory.
            logits = None
        else:
            logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if fuse_linear_and_cross_entropy:
                if "p_mask" not in kwargs:
                    # Fail with an actionable message instead of a bare KeyError below.
                    raise ValueError(
                        "`p_mask` must be passed to forward() when `config.fuse_cross_entropy` "
                        "is enabled during training."
                    )
                loss_fct = FusedLinearDiffusionCrossEntropyLoss(
                    reduction='sum')
            else:
                loss_fct = CrossEntropyLoss()  # nn.CE

            # you don't have to shift labels
            # labels = labels.to(hidden_states.device)
            # labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                # NOTE(review): the original comment says this returns
                # (sum_loss, unreduced_loss); the raw return value is stored in
                # `loss` unchanged — confirm downstream consumers handle it.
                loss = loss_fct(
                    x=hidden_states,  # conduct `view(-1, V)` inside the function
                    target=labels,
                    weight=self.lm_head.weight,
                    bias=self.lm_head.bias,
                    p_mask=kwargs["p_mask"],
                )
            else:
                loss = loss_fct(
                    logits.view(-1, self.config.vocab_size), labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
821
+ )
822
+
823
+
824
@auto_docstring(
    custom_intro="""
    The Qwen3 Model transformer with a sequence classification head on top (linear layer).

    [`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen3Model(config)
        # Bias-free classification head mapping hidden states to label scores.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs.last_hidden_state
        # Per-token label scores; the relevant token is selected below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Without a pad token, only unbatched input is unambiguous.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            # Padding cannot be detected from embeddings; fall back to the last
            # position and warn.
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        # Gather one score vector per batch row at the selected token position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
922
+
923
+
924
@auto_docstring
class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen3Model(config)
        # Dropout probability: explicit classifier dropout wins, then the
        # model-wide hidden dropout, then a 0.1 fallback.
        dropout_prob = getattr(config, "classifier_dropout", None)
        if dropout_prob is None:
            dropout_prob = getattr(config, "hidden_dropout", None)
        if dropout_prob is None:
            dropout_prob = 0.1
        self.dropout = nn.Dropout(dropout_prob)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # Regularize per-token features, then project to label scores.
        token_states = self.dropout(backbone_out.last_hidden_state)
        logits = self.score(token_states)

        loss = self.loss_function(logits, labels, self.config) if labels is not None else None

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
993
+
994
+
995
@auto_docstring
class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = Qwen3Model(config)
        # Two scores per token: span-start and span-end.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        backbone_out: BaseModelOutputWithPast = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Project each token to (start, end) scores, then separate the two heads.
        span_logits = self.qa_outputs(backbone_out.last_hidden_state)
        start_logits, end_logits = (
            head.squeeze(-1).contiguous() for head in span_logits.split(1, dim=-1)
        )

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
1056
+
1057
+
1058
# Public symbols exported by this module (used by `from ... import *` and the
# transformers auto-class machinery).
__all__ = [
    "Qwen3ForCausalLM",
    "Qwen3ForQuestionAnswering",
    "Qwen3Model",
    "Qwen3PreTrainedModel",
    "Qwen3ForSequenceClassification",
    "Qwen3ForTokenClassification",
]
special_tokens_map.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>",
16
+ "<MASK>"
17
+ ],
18
+ "eos_token": {
19
+ "content": "<|endoftext|>",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "pad_token": {
26
+ "content": "<|endoftext|>",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ }
32
+ }
tokenization_qwen2.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ import json
18
+ import os
19
+ import unicodedata
20
+ from functools import lru_cache
21
+ from typing import Optional, Tuple
22
+
23
+ import regex as re
24
+
25
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
# Vocabulary/merges file names the tokenizer expects next to the checkpoint.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


# Maximum model input size keyed by the canonical tokenizer name.
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}

# GPT-2-style pre-tokenization pattern: English contractions, letter runs,
# single digits, punctuation runs (optionally with a leading space), and
# whitespace handling.
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
40
+
41
+
42
@lru_cache()
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
    """
    Build a reversible byte -> unicode-character table used by byte-level BPE.

    Printable bytes map to themselves; every other byte is assigned a character
    starting at code point 256, so no byte maps to whitespace/control characters
    that the BPE code cannot handle. The result covers all 256 byte values.
    """
    byte_values = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    char_codes = list(byte_values)
    next_code = 2**8
    for candidate in range(2**8):
        if candidate not in byte_values:
            byte_values.append(candidate)
            char_codes.append(next_code)
            next_code += 1
    return {b: chr(c) for b, c in zip(byte_values, char_codes)}
66
+
67
+
68
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    left = word[0]
    adjacent = set()
    for right in word[1:]:
        adjacent.add((left, right))
        left = right
    return adjacent
81
+
82
+
83
+ class Qwen2Tokenizer(PreTrainedTokenizer):
84
+ """
85
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
86
+
87
+ Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
88
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
89
+
90
+ ```python
91
+ >>> from transformers import Qwen2Tokenizer
92
+
93
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
94
+ >>> tokenizer("Hello world")["input_ids"]
95
+ [9707, 1879]
96
+
97
+ >>> tokenizer(" Hello world")["input_ids"]
98
+ [21927, 1879]
99
+ ```
100
+ This is expected.
101
+
102
+ You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
103
+
104
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
105
+ this superclass for more information regarding those methods.
106
+
107
+ Args:
108
+ vocab_file (`str`):
109
+ Path to the vocabulary file.
110
+ merges_file (`str`):
111
+ Path to the merges file.
112
+ errors (`str`, *optional*, defaults to `"replace"`):
113
+ Paradigm to follow when decoding bytes to UTF-8. See
114
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
115
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
116
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
117
+ token instead.
118
+ bos_token (`str`, *optional*):
119
+ The beginning of sequence token. Not applicable for this tokenizer.
120
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
121
+ The end of sequence token.
122
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
123
+ The token used for padding, for example when batching sequences of different lengths.
124
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
125
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
126
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
127
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
128
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
129
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
130
+ ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<',
131
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
132
+ """
133
+
134
+ vocab_files_names = VOCAB_FILES_NAMES
135
+ model_input_names = ["input_ids", "attention_mask"]
136
+
137
def __init__(
    self,
    vocab_file,
    merges_file,
    errors="replace",
    unk_token="<|endoftext|>",
    bos_token=None,
    eos_token="<|endoftext|>",
    pad_token="<|endoftext|>",
    clean_up_tokenization_spaces=False,
    split_special_tokens=False,
    **kwargs,
):
    """Build the slow (pure-Python) Qwen2 byte-level BPE tokenizer.

    Args:
        vocab_file: Path to a JSON file mapping token string -> id.
        merges_file: Path to the BPE merges file, one merge pair per line;
            an optional ``#version:`` header on the first line is skipped.
        errors: Error scheme passed to ``bytes.decode`` when turning tokens
            back into text.
        unk_token / bos_token / eos_token / pad_token: Special tokens. Plain
            strings are wrapped as special ``AddedToken`` objects because the
            Qwen vocab does not contain control tokens of its own.
        clean_up_tokenization_spaces: Forwarded to the base class.
        split_special_tokens: Forwarded to the base class.
        **kwargs: Any further base-class options.
    """

    def _as_special_token(token):
        # Qwen vocab does not contain control tokens; added tokens need to be
        # special so they are never split or normalized.
        if isinstance(token, str):
            return AddedToken(token, lstrip=False, rstrip=False, special=True, normalized=False)
        return token

    bos_token = _as_special_token(bos_token)
    eos_token = _as_special_token(eos_token)
    unk_token = _as_special_token(unk_token)
    pad_token = _as_special_token(pad_token)

    with open(vocab_file, encoding="utf-8") as vocab_handle:
        self.encoder = json.load(vocab_handle)
    self.decoder = {v: k for k, v in self.encoder.items()}
    self.errors = errors  # how to handle errors in decoding
    self.byte_encoder = bytes_to_unicode()
    self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
    bpe_merges = []
    with open(merges_file, encoding="utf-8") as merges_handle:
        for i, line in enumerate(merges_handle):
            line = line.strip()
            if (i == 0 and line.startswith("#version:")) or not line:
                continue
            bpe_merges.append(tuple(line.split()))
    self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
    # NOTE: the cache can grow without bound and will get really large for long running processes
    # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
    # not a memory leak but appears as one.
    # GPT2Tokenizer has the same problem, so let's be consistent.
    self.cache = {}

    self.pat = re.compile(PRETOKENIZE_REGEX)

    if kwargs.get("add_prefix_space", False):
        # BUGFIX: original read `self.__class__.__name`, which raises
        # AttributeError — the class-name dunder is `__name__`.
        logger.warning_once(
            f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
        )

    super().__init__(
        errors=errors,
        bos_token=bos_token,
        eos_token=eos_token,
        pad_token=pad_token,
        unk_token=unk_token,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        split_special_tokens=split_special_tokens,
        **kwargs,
    )
209
+
210
@property
def vocab_size(self) -> int:
    """Size of the base vocabulary (entries in vocab.json; added tokens excluded)."""
    return len(self.encoder)
213
+
214
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
def get_vocab(self):
    """Return the full vocabulary: base encoder merged with added tokens (added entries win on key collisions)."""
    merged = {**self.encoder, **self.added_tokens_encoder}
    return merged
217
+
218
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
def bpe(self, token):
    """Apply byte-pair merges to `token` and return the space-joined symbols.

    Repeatedly merges the adjacent symbol pair with the lowest merge rank in
    `self.bpe_ranks` until no ranked pair remains, memoizing the result in
    `self.cache` keyed by the raw token.
    """
    if token in self.cache:
        return self.cache[token]
    word = tuple(token)
    pairs = get_pairs(word)

    # A single-symbol token has no pairs and cannot be merged further.
    if not pairs:
        return token

    while True:
        # Pick the adjacent pair with the lowest (i.e. earliest-learned) rank;
        # unranked pairs compare as +inf and are never chosen over ranked ones.
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
        if bigram not in self.bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        # Rebuild `word`, fusing every occurrence of (first, second).
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                # No further `first` symbol: copy the remainder unchanged.
                new_word.extend(word[i:])
                break
            else:
                new_word.extend(word[i:j])
                i = j

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = " ".join(word)
    self.cache[token] = word
    return word
260
+
261
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
262
+ def _tokenize(self, text):
263
+ """Tokenize a string."""
264
+ bpe_tokens = []
265
+ for token in re.findall(self.pat, text):
266
+ token = "".join(
267
+ self.byte_encoder[b] for b in token.encode("utf-8")
268
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
269
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
270
+ return bpe_tokens
271
+
272
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
273
+ def _convert_token_to_id(self, token):
274
+ """Converts a token (str) in an id using the vocab."""
275
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
276
+
277
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
def _convert_id_to_token(self, index):
    """Converts an index (integer) in a token (str) using the vocab.

    Returns None for ids absent from the decoder mapping.
    """
    return self.decoder.get(index)
281
+
282
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
def convert_tokens_to_string(self, tokens):
    """Converts a sequence of tokens (string) in a single string."""
    # Undo the byte-to-printable-unicode mapping, then decode as UTF-8 with
    # the configured error policy.
    raw = bytes(self.byte_decoder[ch] for ch in "".join(tokens))
    return raw.decode("utf-8", errors=self.errors)
288
+
289
def decode(
    self,
    token_ids,
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: Optional[bool] = False,
    spaces_between_special_tokens: bool = False,
    **kwargs,
) -> str:
    """Decode ids to text with Qwen2-appropriate defaults.

    `spaces_between_special_tokens` defaults to True for _decode in slow
    tokenizers and cannot be configured elsewhere, but it should default to
    False for Qwen2Tokenizer; this override only changes that default and
    forwards everything to the base class.
    """
    return super().decode(
        token_ids,
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        spaces_between_special_tokens=spaces_between_special_tokens,
        **kwargs,
    )
306
+
307
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """Write vocab.json and merges.txt into `save_directory`.

    Returns the (vocab_file, merge_file) paths on success; logs an error and
    returns None when `save_directory` is not an existing directory.
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return
    vocab_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
    )
    merge_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
    )

    with open(vocab_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

    index = 0
    with open(merge_file, "w", encoding="utf-8") as writer:
        writer.write("#version: 0.2\n")
        # Merges are written in rank order; ranks are expected to be
        # consecutive integers starting at 0, so a gap indicates corruption.
        for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
            if index != token_index:
                logger.warning(
                    f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                    " Please check that the tokenizer is not corrupted!"
                )
                index = token_index
            writer.write(" ".join(bpe_tokens) + "\n")
            index += 1

    return vocab_file, merge_file
336
+
337
def prepare_for_tokenization(self, text, **kwargs):
    """Normalize `text` to Unicode NFC before tokenization; kwargs pass through untouched."""
    return (unicodedata.normalize("NFC", text), kwargs)
340
+
341
+
342
# Public API of this module.
__all__ = ["Qwen2Tokenizer"]
tokenization_qwen2_fast.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ from typing import Optional, Tuple
18
+
19
+ from transformers.tokenization_utils import AddedToken
20
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
+ from transformers.utils import logging
22
+ from .tokenization_qwen2 import Qwen2Tokenizer
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
# Canonical on-disk names for the tokenizer's vocabulary assets.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
    "tokenizer_file": "tokenizer.json",
}


# Maximum model input length advertised for the published Qwen tokenizer.
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
35
+
36
+
37
class Qwen2TokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers*
    library), using byte-level Byte-Pair-Encoding.

    As with GPT2Tokenizer, spaces are treated as part of the tokens, so the
    same word encodes differently at the start of a sentence (no leading
    space) than elsewhere:

    ```python
    >>> from transformers import Qwen2TokenizerFast

    >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
    >>> tokenizer("Hello world")["input_ids"]
    [9707, 1879]

    >>> tokenizer(" Hello world")["input_ids"]
    [21927, 1879]
    ```
    This is expected.

    This tokenizer inherits from [`PreTrainedTokenizerFast`], which contains
    most of the main methods; users should refer to that superclass for more
    information.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file.
        merges_file (`str`, *optional*):
            Path to the merges file.
        tokenizer_file (`str`, *optional*):
            Path to a [tokenizers](https://github.com/huggingface/tokenizers)
            file (generally `.json`) containing everything needed to load the
            tokenizer.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. Not applicable to this tokenizer.
        bos_token (`str`, *optional*):
            The beginning of sequence token. Not applicable for this tokenizer.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding, for example when batching sequences of
            different lengths.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = Qwen2Tokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        **kwargs,
    ):
        # vocab_file / merges_file are forwarded so a slow tokenizer can be
        # built from them if needed; everything else is configured through
        # tokenizer_file. Following GPT2TokenizerFast, unk/bos/eos/pad are
        # also forwarded explicitly.

        def _wrap_special(token):
            # Qwen's vocab has no control tokens of its own, so bare strings
            # are registered as special AddedTokens.
            if isinstance(token, str):
                return AddedToken(token, lstrip=False, rstrip=False, special=True, normalized=False)
            return token

        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=_wrap_special(unk_token),
            bos_token=_wrap_special(bos_token),
            eos_token=_wrap_special(eos_token),
            pad_token=_wrap_special(pad_token),
            **kwargs,
        )

    # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Serialize the backend tokenizer's model files into `save_directory` and return their paths."""
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)
135
+
136
+
137
# Public API of this module.
__all__ = ["Qwen2TokenizerFast"]
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+
5
+ "added_tokens_decoder": {
6
+ "151643": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "151644": {
15
+ "content": "<|im_start|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "151645": {
23
+ "content": "<|im_end|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "151646": {
31
+ "content": "<|object_ref_start|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "151647": {
39
+ "content": "<|object_ref_end|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "151648": {
47
+ "content": "<|box_start|>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "151649": {
55
+ "content": "<|box_end|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "151650": {
63
+ "content": "<|quad_start|>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "151651": {
71
+ "content": "<|quad_end|>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "151652": {
79
+ "content": "<|vision_start|>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "151653": {
87
+ "content": "<|vision_end|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ },
94
+ "151654": {
95
+ "content": "<|vision_pad|>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": true
101
+ },
102
+ "151655": {
103
+ "content": "<|image_pad|>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ },
110
+ "151656": {
111
+ "content": "<|video_pad|>",
112
+ "lstrip": false,
113
+ "normalized": false,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": true
117
+ },
118
+ "151657": {
119
+ "content": "<tool_call>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "151658": {
127
+ "content": "</tool_call>",
128
+ "lstrip": false,
129
+ "normalized": false,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "151659": {
135
+ "content": "<|fim_prefix|>",
136
+ "lstrip": false,
137
+ "normalized": false,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "151660": {
143
+ "content": "<|fim_middle|>",
144
+ "lstrip": false,
145
+ "normalized": false,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "151661": {
151
+ "content": "<|fim_suffix|>",
152
+ "lstrip": false,
153
+ "normalized": false,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "151662": {
159
+ "content": "<|fim_pad|>",
160
+ "lstrip": false,
161
+ "normalized": false,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "151663": {
167
+ "content": "<|repo_name|>",
168
+ "lstrip": false,
169
+ "normalized": false,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "151664": {
175
+ "content": "<|file_sep|>",
176
+ "lstrip": false,
177
+ "normalized": false,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "151665": {
183
+ "content": "<tool_response>",
184
+ "lstrip": false,
185
+ "normalized": false,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "151666": {
191
+ "content": "</tool_response>",
192
+ "lstrip": false,
193
+ "normalized": false,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "151667": {
199
+ "content": "<think>",
200
+ "lstrip": false,
201
+ "normalized": false,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "151668": {
207
+ "content": "</think>",
208
+ "lstrip": false,
209
+ "normalized": false,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "151669": {
215
+ "content": "<|MASK|>",
216
+ "lstrip": false,
217
+ "normalized": false,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": true
221
+ }
222
+ },
223
+ "additional_special_tokens": [
224
+ "<|im_start|>",
225
+ "<|im_end|>",
226
+ "<|object_ref_start|>",
227
+ "<|object_ref_end|>",
228
+ "<|box_start|>",
229
+ "<|box_end|>",
230
+ "<|quad_start|>",
231
+ "<|quad_end|>",
232
+ "<|vision_start|>",
233
+ "<|vision_end|>",
234
+ "<|vision_pad|>",
235
+ "<|image_pad|>",
236
+ "<|video_pad|>",
237
+ "<|MASK|>"
238
+ ],
239
+ "auto_map": {
240
+ "AutoTokenizer": [
241
+ "tokenization_qwen2.Qwen2Tokenizer",
242
+ null
243
+ ]
244
+ },
245
+ "bos_token": null,
246
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = 
message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
247
+ "clean_up_tokenization_spaces": false,
248
+ "eos_token": "<|endoftext|>",
249
+ "mask_token": "<|MASK|>",
250
+ "errors": "replace",
251
+ "model_max_length": 131072,
252
+ "pad_token": "<|endoftext|>",
253
+ "split_special_tokens": false,
254
+ "tokenizer_class": "Qwen2Tokenizer",
255
+ "unk_token": null
256
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff