Safetensors
FLMAudio
custom_code
nathanyu committed
Commit 0e28a9a · 0 Parent(s)

initial commit
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,90 @@
+ # FLM-Audio
+ FLM-Audio is an audio-language variant of [RoboEgo/FLM-Ego](https://arxiv.org/abs/2506.01934v1) -- an omnimodal model with native full duplexity. It simultaneously listens, speaks, and composes an internal monologue, delivering low-latency, duplex conversational responses in both English and Chinese. FLM-Audio is robust to noise and user interruptions, prioritizing responsiveness and naturalness.
+
+ ## 📄 Model Card
+
+ - **Language(s):** Chinese; English
+
+ ## 📚 Technical Report
+ Motivation & Survey: [Toward Embodied AGI: A Review of Embodied AI and the Road Ahead](https://arxiv.org/abs/2505.14235)
+
+ System Card: [RoboEgo System Card: An Omnimodal Model with Native Full Duplexity](https://arxiv.org/abs/2506.01934v1)
+
+
+ ## ⚠️ Bias, Risks, and Limitations
+
+ Despite extensive data cleaning, FLM-Audio may still produce undesired content (e.g., biased or offensive language). Users should not disseminate unsafe outputs. Project authors are not responsible for misuse or harmful consequences.
+
+
+ ## 🚀 Quick Start
+ Please refer to the repository of the [FLM-Audio server](https://github.com/cofe-ai/flm-audio) to interact with FLM-Audio via a WebUI.
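For offline inspection of the weights and configuration (outside the duplex WebUI), the checkpoint can also be loaded through `transformers` using the `auto_map` entries in `config.json`. The snippet below is a minimal sketch, not the official entry point: the local path is a placeholder, and it does not implement the streaming audio pipeline provided by the server.

```python
# Minimal sketch: load the checkpoint for offline inspection with transformers.
# Real-time duplex interaction should go through the FLM-Audio server linked above.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "./FLM-Audio"  # placeholder: local path to this repository's files

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,   # matches "torch_dtype": "bfloat16" in config.json
    trust_remote_code=True,       # required: the modeling code ships with the repo
)
print(config.aud_channel, config.aud_vocab_size)  # 8 audio channels, 2050 audio tokens
```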
+
+ ## ℹ️ Usage Notice
+ This project is intended for research use only, in compliance with applicable laws. For commercial use, please contact us.
+
+
+ ## 🏗️ Training Details
+
+ ### Overview
+ We initialize the FLM-Audio backbone
+ with a pre-trained language model. This initialization strategy significantly reduces computational cost while remaining effective for validating the core concepts of omnimodality and full duplexity. The training process of FLM-Audio consists of two stages: post-training and fine-tuning.
+
+ #### 1. Post-training
+ In post-training, we introduce audio-oriented capabilities to the backbone model using a large-scale corpus of audio data, while preserving the language modeling abilities of the pre-trained foundation model. This stage encompasses a broad spectrum of speech-related tasks, including automatic speech recognition (ASR) and text-to-speech (TTS) synthesis.
+
+ #### 2. Supervised Fine-tuning (SFT)
+ In this stage, we fine-tune FLM-Audio to function as a general-purpose, full-duplex audio-language chatbot. To this
+ end, we primarily utilize synthesized multi-turn speech dialogues. This dataset is further augmented to support full-duplex
+ interruption handling and to enhance robustness against environmental noise.
+
+ ### Model Architecture
+ To handle real-time language and audio, FLM-Audio features an LLM-based backbone with 7B parameters, enhanced by an audio encoder that embeds incoming speech into semantic and acoustic tokens, and a decoder that generates audio tokens. Listening, speaking, and internal monologue are interleaved in synchronized timesteps, with improved stream organization compared to related work (e.g., Moshi).
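As a rough mental model of the synchronized streams described above, the sketch below shows one timestep as a data structure. It is illustrative only and inferred from `config.json` (`aud_channel: 8`, `mm_token_info`) and the `aud_listen_*`/`aud_speak_*` embedding tables in the weight map; the actual frame layout and the exact roles of the special audio ids are defined by the modeling and server code.

```python
# Illustrative sketch only (not from the repository): one synchronized timestep of the
# three duplex streams the README describes.
from dataclasses import dataclass, field
from typing import List

AUD_PAD_ID = 2048      # config.json: mm_token_info.aud_pad_token_id
AUD_EMP_ID = 2049      # config.json: mm_token_info.aud_emp_token_id
TEXT_WAIT_ID = 151665  # added_tokens.json: "<|text_wait|>"

@dataclass
class DuplexFrame:
    """What the model hears, says, and 'thinks' during one timestep, in lockstep."""
    listen_tokens: List[int] = field(default_factory=lambda: [AUD_PAD_ID] * 8)  # incoming audio, 8 codec channels
    speak_tokens: List[int] = field(default_factory=lambda: [AUD_PAD_ID] * 8)   # outgoing audio, 8 codec channels
    monologue_token: int = TEXT_WAIT_ID  # internal text monologue; <|text_wait|> when silent

silent_step = DuplexFrame()  # a timestep with no speech on either side and no monologue text
```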
+
+
+ ## 🧪 Evaluation
+
+ ### Audio Understanding and Generation
+ FLM-Audio performs comparably to strong audio-language models, most of which lack native duplexity.
+
+ | Model | ASR-zh (Fleurs-zh) ↓ | ASR-en (LibriSpeech-clean) ↓ | TTS-zh (Seed-tts-zh) ↓ | TTS-en (Seed-tts-en) ↓ |
+ |------------|:-------:|:----------:|:---------:|:---------:|
+ | GPT-4o | 5.4 | - | - | - |
+ | MinMo | 3.0 | 1.7 | 2.48 | 2.90 |
+ | GLM-4-Voice | - | 2.8 | 2.10 | 2.91 |
+ | Moshi | - | 5.7 | - | - |
+ | Qwen-2.5-omni | 3.0 | 1.8 | 1.70 | 2.72 |
+ | FLM-Audio | 5.4 | 3.2 | 2.10 | 2.95 |
+
+
+ ### Chat
+ In terms of chat experience, FLM-Audio demonstrates advantages in speech naturalness and responsiveness. The table below reports LLM scores for audio chat scenarios such as Alpaca-eval, as well as human evaluation of video-grounded omnimodal chat. The human scores for Naturalness and Responsiveness reflect the contribution of the same audio-oriented training used in FLM-Audio.
+
+ | Model | LLM score ↑ | Helpfulness ↑ | Naturalness ↑ | Responsiveness ↑ | Robustness ↑ |
+ |--------------|:-------:|:------:|:-----:|:-----:|:-----:|
+ | Qwen-2.5-omni | 6.36 | 7.4 | 7.9 | 8.1 | 7.7 |
+ | FLM-Audio | 6.58 | 7.2 | 8.2 | 8.8 | 8.0 |
+
+
+ ## 🙏 Acknowledgements
+ This work is supported by the National Science and Technology Major Project (No. 2022ZD0116314).
+
+ ## 🗨️ Citation
+ If you find our work helpful, please consider citing the following papers.
+ ```
+ @article{embodied-agi,
+   title={Toward Embodied AGI: A Review of Embodied AI and the Road Ahead},
+   author={Wang, Yequan and Sun, Aixin},
+   journal={arXiv preprint arXiv:2505.14235},
+   year={2025}
+ }
+ @article{roboego,
+   title={RoboEgo System Card: An Omnimodal Model with Native Full Duplexity},
+   author={Yao, Yiqun and Li, Xiang and Jiang, Xin and Fang, Xuezhi and Yu, Naitong and Sun, Aixin and Wang, Yequan},
+   journal={arXiv preprint arXiv:2506.01934},
+   year={2025}
+ }
+ ```
+
+
+
added_tokens.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "</tool_call>": 151658,
+   "<tool_call>": 151657,
+   "<|answer|>": 151667,
+   "<|asr|>": 151666,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|text_wait|>": 151665,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
config.json ADDED
@@ -0,0 +1,64 @@
+ {
+   "architectures": [
+     "FLMAudioForCausalLM"
+   ],
+   "attention_bias": true,
+   "attention_dropout": 0.0,
+   "aud_channel": 8,
+   "aud_depthgpt": {
+     "bias": false,
+     "dropout": 0.0,
+     "n_embd": 1024,
+     "n_head": 16,
+     "n_layer": 6,
+     "use_cmlp": true,
+     "use_rmsnorm": true,
+     "use_swiglu": true
+   },
+   "aud_vocab_size": 2050,
+   "auto_map": {
+     "AutoConfig": "configuration_flmaudio.FLMAudioConfig",
+     "AutoModel": "modeling_flmaudio.FLMAudioModel",
+     "AutoModelForCausalLM": "modeling_flmaudio.FLMAudioForCausalLM"
+   },
+   "bos_token_id": 151643,
+   "disable_att_o_bias": true,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 3584,
+   "initializer_range": 0.02,
+   "input_mult": 1.0,
+   "intermediate_size": 18944,
+   "max_position_embeddings": 8192,
+   "mm_token_info": {
+     "aud_emp_token_id": 2049,
+     "aud_pad_token_id": 2048,
+     "text_wait_token_id": 151665
+   },
+   "model_type": "FLMAudio",
+   "mup_scale_factor": 28.0,
+   "num_attention_heads": 28,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 4,
+   "output_mult": 28.0,
+   "pretraining_tp": 1,
+   "rms_norm_eps": 1e-05,
+   "rope_scaling": {
+     "mrope_section": [
+       16,
+       24,
+       24
+     ],
+     "rope_type": "default",
+     "type": "default"
+   },
+   "rope_theta": 1000000,
+   "sliding_window": 32768,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.53.1",
+   "use_cache": true,
+   "use_mup": true,
+   "use_sliding_window": false,
+   "vocab_size": 151668
+ }
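For orientation, here is a quick back-of-the-envelope check of the backbone geometry implied by the values above (a derivation, not part of the repository):

```python
# Derived from config.json above: hidden_size=3584, 28 attention heads, 4 KV heads.
hidden_size, n_heads, n_kv_heads = 3584, 28, 4
head_dim = hidden_size // n_heads      # 128 per head
gqa_group = n_heads // n_kv_heads      # 7 query heads share each key/value head (GQA)
mlp_ratio = 18944 / 3584               # ~5.3x feed-forward expansion
print(head_dim, gqa_group, round(mlp_ratio, 1))
```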
configuration_flmaudio.py ADDED
@@ -0,0 +1,222 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """FLM-Audio model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+ from dataclasses import dataclass
25
+ from transformers.modeling_rope_utils import rope_config_validation
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ FLMAUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
30
+
31
+
32
+ @dataclass
33
+ class TokenInfo(dict):
34
+ text_wait_token_id: int
35
+ aud_pad_token_id: int
36
+ aud_emp_token_id: int
37
+
38
+ def __post_init__(self):
39
+ super().__init__(self, **self.__dict__)
40
+
41
+
42
+ @dataclass
43
+ class DepthGPTConfig(dict):
44
+ n_layer: int
45
+ n_head: int
46
+ n_embd: int
47
+ dropout: float
48
+ bias: bool
49
+ use_cmlp: bool
50
+ use_rmsnorm: bool
51
+ use_swiglu: bool
52
+
53
+ def __post_init__(self):
54
+ super().__init__(self, **self.__dict__)
55
+
56
+
57
+ class FLMAudioConfig(PretrainedConfig):
58
+ r"""
59
+ This is the configuration class to store the configuration of a [`FLMAudio`]. It is used to instantiate an FLMAudio
60
+ model according to the specified arguments, defining the model architecture.
61
+
62
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
63
+ documentation from [`PretrainedConfig`] for more information.
64
+
65
+
66
+ Args:
67
+ vocab_size (`int`, *optional*, defaults to 32000):
68
+ Vocabulary size of the FLM-Audio model. Defines the number of different tokens that can be represented by the
69
+ `input_ids` passed when calling [`FLMAudioModel`]
70
+ hidden_size (`int`, *optional*, defaults to 4096):
71
+ Dimension of the hidden representations.
72
+ intermediate_size (`int`, *optional*, defaults to 11008):
73
+ Dimension of the MLP representations.
74
+ num_hidden_layers (`int`, *optional*, defaults to 32):
75
+ Number of hidden layers in the Transformer decoder.
76
+ num_attention_heads (`int`, *optional*, defaults to 32):
77
+ Number of attention heads for each attention layer in the Transformer decoder.
78
+ num_key_value_heads (`int`, *optional*):
79
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
80
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
81
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
82
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
83
+ by meanpooling all the original heads within that group. For more details checkout [this
84
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
85
+ `num_attention_heads`.
86
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
87
+ The non-linear activation function (function or string) in the decoder.
88
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
89
+ The maximum sequence length that this model might ever be used with.
90
+ initializer_range (`float`, *optional*, defaults to 0.02):
91
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
92
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
93
+ The epsilon used by the rms normalization layers.
94
+ use_cache (`bool`, *optional*, defaults to `True`):
95
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
96
+ relevant if `config.is_decoder=True`.
97
+ pad_token_id (`int`, *optional*):
98
+ Padding token id.
99
+ bos_token_id (`int`, *optional*, defaults to 1):
100
+ Beginning of stream token id.
101
+ eos_token_id (`int`, *optional*, defaults to 2):
102
+ End of stream token id.
103
+ pretraining_tp (`int`, *optional*, defaults to 1):
104
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
105
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
106
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
107
+ issue](https://github.com/pytorch/pytorch/issues/76232).
108
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
109
+ Whether to tie weight embeddings
110
+ rope_theta (`float`, *optional*, defaults to 10000.0):
111
+ The base period of the RoPE embeddings.
112
+ rope_scaling (`Dict`, *optional*):
113
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
114
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
115
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
116
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
117
+ these scaling strategies behave:
118
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
119
+ experimental feature, subject to breaking API changes in future versions.
120
+ attention_bias (`bool`, *optional*, defaults to `False`):
121
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
122
+ attention_dropout (`float`, *optional*, defaults to 0.0):
123
+ The dropout ratio for the attention probabilities.
124
+
125
+ ```python
126
+ >>> from transformers import FLMAudioModel, FLMAudioConfig
127
+
128
+ >>> # Initializing a FLMAudio configuration
129
+ >>> configuration = FLMAudioConfig()
130
+
131
+ >>> # Initializing a model from FLMAudio configuration
132
+ >>> model = FLMAudioModel(configuration)
133
+
134
+ >>> # Accessing the model configuration
135
+ >>> configuration = model.config
136
+ ```"""
137
+
138
+ model_type = "FLMAudio"
139
+ keys_to_ignore_at_inference = ["past_key_values"]
140
+
141
+ def __init__(
142
+ self,
143
+ vocab_size=32000,
144
+ aud_vocab_size=2048,
145
+ aud_channel=8,
146
+ hidden_size=4096,
147
+ intermediate_size=11008,
148
+ num_hidden_layers=32,
149
+ num_attention_heads=32,
150
+ num_key_value_heads=None,
151
+ hidden_act="silu",
152
+ max_position_embeddings=2048,
153
+ initializer_range=0.02,
154
+ rms_norm_eps=1e-6,
155
+ use_cache=True,
156
+ pad_token_id=None,
157
+ bos_token_id=1,
158
+ eos_token_id=2,
159
+ mm_token_info=None,
160
+ aud_depthgpt=None,
161
+ pretraining_tp=1,
162
+ tie_word_embeddings=False,
163
+ rope_theta=10000.0,
164
+ rope_scaling=None,
165
+ attention_bias=False,
166
+ disable_att_o_bias=False,
167
+ attention_dropout=0.0,
168
+ use_mup=False,
169
+ mup_scale_factor=1.0,
170
+ output_mult=1.0,
171
+ input_mult=1.0,
172
+ **kwargs,
173
+ ):
174
+ self.vocab_size = vocab_size
175
+ self.aud_vocab_size = aud_vocab_size
176
+ self.aud_channel = aud_channel
177
+
178
+ self.max_position_embeddings = max_position_embeddings
179
+ self.hidden_size = hidden_size
180
+ self.intermediate_size = intermediate_size
181
+ self.num_hidden_layers = num_hidden_layers
182
+ self.num_attention_heads = num_attention_heads
183
+
184
+ # for backward compatibility
185
+ if num_key_value_heads is None:
186
+ num_key_value_heads = num_attention_heads
187
+
188
+ self.num_key_value_heads = num_key_value_heads
189
+ self.hidden_act = hidden_act
190
+ self.initializer_range = initializer_range
191
+ self.rms_norm_eps = rms_norm_eps
192
+ self.pretraining_tp = pretraining_tp
193
+ self.use_cache = use_cache
194
+ self.rope_theta = rope_theta
195
+ self.rope_scaling = rope_scaling
196
+ self.attention_bias = attention_bias
197
+ self.disable_att_o_bias = disable_att_o_bias
198
+ self.attention_dropout = attention_dropout
199
+ self.use_mup = use_mup
200
+ self.mup_scale_factor = mup_scale_factor
201
+ self.output_mult = output_mult
202
+ self.input_mult = input_mult
203
+
204
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
205
+ if self.rope_scaling["type"] == "mrope":
206
+ self.rope_scaling["type"] = "default"
207
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
208
+ rope_config_validation(self, ignore_keys={"mrope_section"})
209
+
210
+ if mm_token_info is not None:
211
+ self.mm_token_info = TokenInfo(**mm_token_info)
212
+
213
+ if aud_depthgpt is not None:
214
+ self.aud_depthgpt = DepthGPTConfig(**aud_depthgpt)
215
+
216
+ super().__init__(
217
+ pad_token_id=pad_token_id,
218
+ bos_token_id=bos_token_id,
219
+ eos_token_id=eos_token_id,
220
+ tie_word_embeddings=tie_word_embeddings,
221
+ **kwargs,
222
+ )
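As a usage note (not part of the repository): the nested `mm_token_info` and `aud_depthgpt` blocks in `config.json` map directly onto the `TokenInfo` and `DepthGPTConfig` dataclasses above. A minimal sketch, assuming the repository files are importable from the working directory:

```python
# Sketch: construct FLMAudioConfig programmatically with the nested blocks from
# config.json; the dicts are wrapped into TokenInfo / DepthGPTConfig by __init__.
from configuration_flmaudio import FLMAudioConfig

config = FLMAudioConfig(
    vocab_size=151668,
    aud_vocab_size=2050,
    aud_channel=8,
    hidden_size=3584,
    num_hidden_layers=28,
    num_attention_heads=28,
    num_key_value_heads=4,
    mm_token_info={
        "text_wait_token_id": 151665,
        "aud_pad_token_id": 2048,
        "aud_emp_token_id": 2049,
    },
    aud_depthgpt={
        "n_layer": 6, "n_head": 16, "n_embd": 1024,
        "dropout": 0.0, "bias": False,
        "use_cmlp": True, "use_rmsnorm": True, "use_swiglu": True,
    },
)
print(config.mm_token_info["aud_pad_token_id"])  # 2048
```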
depth_gpt.py ADDED
@@ -0,0 +1,326 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
+ class DepthGPTConfig(PretrainedConfig):
10
+ def __init__(
11
+ self,
12
+ block_size: int = 8,
13
+ vocab_size: int = 2049, # 2048 audio codebook entries plus the padding token (id 2048)
14
+ n_layer: int = 6,
15
+ n_head: int = 16,
16
+ n_embd: int = 1024,
17
+ dropout: float = 0.0,
18
+ bias: bool = False, # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
19
+ main_hidden_size = 1536,
20
+ pad_token_id = 2048,
21
+ use_cmlp = True,
22
+ use_rmsnorm = False,
23
+ use_swiglu = False
24
+ ):
25
+ """
26
+ {
27
+ "block_size": 8,
28
+ "vocab_size": 2049,
29
+ "n_layer": 6,
30
+ "n_head": 16,
31
+ "n_embd": 1024,
32
+ "dropout": 0.0,
33
+ "bias": false,
34
+ "main_hidden_size": 1536,
35
+ "pad_token_id": 2048,
36
+ "use_cmlp": true
37
+ }
38
+ """
39
+ # super().__init__(**kwargs)
40
+ self.block_size = block_size
41
+ self.vocab_size = vocab_size
42
+ self.n_layer = n_layer
43
+ self.n_head = n_head
44
+ self.n_embd = n_embd
45
+ self.dropout = dropout
46
+ self.bias = bias
47
+ self.main_hidden_size = main_hidden_size
48
+ self.pad_token_id = pad_token_id
49
+ self.use_cmlp = use_cmlp
50
+ self.use_rmsnorm = use_rmsnorm
51
+ self.use_swiglu = use_swiglu
52
+
53
+ ################################################################################################
54
+ # GPT style
55
+ ################################################################################################
56
+
57
+ class LayerNorm(nn.Module):
58
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
59
+
60
+ def __init__(self, ndim, bias):
61
+ super().__init__()
62
+ self.weight = nn.Parameter(torch.ones(ndim))
63
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
64
+
65
+ def forward(self, input):
66
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
67
+
68
+
69
+ class RMSNorm(nn.Module):
70
+ def __init__(self, dim: int, eps: float = 1e-6):
71
+ super(RMSNorm, self).__init__()
72
+ self.eps = eps
73
+ self.weight = nn.Parameter(torch.ones(dim))
74
+
75
+ def _norm(self, x):
76
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
77
+
78
+ def forward(self, x):
79
+ output = self._norm(x.float()).type_as(x)
80
+ return output * self.weight
81
+
82
+
83
+ class CausalSelfAttention(nn.Module):
84
+
85
+ def __init__(self, config):
86
+ super().__init__()
87
+ assert config.n_embd % config.n_head == 0
88
+ # key, query, value projections for all heads, but in a batch
89
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
90
+ # output projection
91
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
92
+ # regularization
93
+ self.attn_dropout = nn.Dropout(config.dropout)
94
+ self.resid_dropout = nn.Dropout(config.dropout)
95
+ self.n_head = config.n_head
96
+ self.n_embd = config.n_embd
97
+ self.dropout = config.dropout
98
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
99
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
100
+ if not self.flash:
101
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
102
+ # causal mask to ensure that attention is only applied to the left in the input sequence
103
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
104
+ .view(1, 1, config.block_size, config.block_size))
105
+
106
+ def forward(self, x):
107
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
108
+
109
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
110
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
111
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
112
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
113
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
114
+
115
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
116
+ if self.flash:
117
+ # efficient attention using Flash Attention CUDA kernels
118
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
119
+ else:
120
+ # manual implementation of attention
121
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
122
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
123
+ att = F.softmax(att, dim=-1)
124
+ att = self.attn_dropout(att)
125
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
126
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
127
+
128
+ # output projection
129
+ y = self.resid_dropout(self.c_proj(y))
130
+ return y
131
+
132
+
133
+ class MLP(nn.Module):
134
+ def __init__(self, config):
135
+ super().__init__()
136
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
137
+ self.gelu = nn.GELU()
138
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
139
+ self.dropout = nn.Dropout(config.dropout)
140
+
141
+ def forward(self, x):
142
+ x = self.c_fc(x)
143
+ x = self.gelu(x)
144
+ x = self.c_proj(x)
145
+ x = self.dropout(x)
146
+ return x
147
+
148
+
149
+ class MLP_swiglu(nn.Module):
150
+ def __init__(self, config):
151
+ super().__init__()
152
+ self.intermediate_size = int(8 * config.n_embd / 3)
153
+ self.gate_proj = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
154
+ self.up_proj = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
155
+ self.down_proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
156
+ self.act_fn = F.silu
157
+ self.dropout = nn.Dropout(config.dropout)
158
+
159
+ def forward(self, x):
160
+ x = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
161
+ x = self.dropout(x)
162
+ return x
163
+
164
+ class Block(nn.Module):
165
+
166
+ def __init__(self, config):
167
+ super().__init__()
168
+ self.ln_1 = RMSNorm(config.n_embd) if config.use_rmsnorm else LayerNorm(config.n_embd, bias=config.bias)
169
+ self.attn = CausalSelfAttention(config)
170
+ self.ln_2 = RMSNorm(config.n_embd) if config.use_rmsnorm else LayerNorm(config.n_embd, bias=config.bias)
171
+ mlp_cls = MLP_swiglu if config.use_swiglu else MLP
172
+ self.mlp = mlp_cls(config)
173
+
174
+ def forward(self, x):
175
+ x = x + self.attn(self.ln_1(x))
176
+ x = x + self.mlp(self.ln_2(x))
177
+ return x
178
+
179
+
180
+ class BlockCMLP(nn.Module):
181
+
182
+ def __init__(self, config):
183
+ super().__init__()
184
+ self.channel_size = config.block_size
185
+ self.ln_1 = RMSNorm(config.n_embd) if config.use_rmsnorm else LayerNorm(config.n_embd, bias=config.bias)
186
+ self.attn = CausalSelfAttention(config)
187
+ self.ln_2 = RMSNorm(config.n_embd) if config.use_rmsnorm else LayerNorm(config.n_embd, bias=config.bias)
188
+ mlp_cls = MLP_swiglu if config.use_swiglu else MLP
189
+ self.mlps = nn.ModuleList([mlp_cls(config) for _ in range(self.channel_size)])
190
+
191
+ assert self.channel_size == 8, f"DEBUG, self.channel_size={self.channel_size} != 8"
192
+
193
+ def forward(self, x):
194
+ _, channel_size, _ = x.shape
195
+ # assert channel_size == self.channel_size
196
+ x = x + self.attn(self.ln_1(x))
197
+
198
+ xl = self.ln_2(x)
199
+ x = x + torch.cat(
200
+ [self.mlps[c](xl[:, c:c+1, :]) for c in range(self.channel_size)],
201
+ dim=1
202
+ )
203
+ return x
204
+
205
+
206
+ class DepthGPT(nn.Module):
207
+
208
+ def __init__(self, config):
209
+ super().__init__()
210
+ assert config.vocab_size is not None
211
+ assert config.block_size is not None
212
+ self.config = config
213
+ self.num_channel = config.block_size
214
+
215
+ self.linear_in = nn.Linear(config.main_hidden_size, config.n_embd * config.block_size, bias=False)
216
+
217
+ block_cls = BlockCMLP if config.use_cmlp else Block
218
+ self.transformer = nn.ModuleDict(dict(
219
+ wtes = nn.ModuleList([nn.Embedding(config.vocab_size, config.n_embd) for _ in range(self.num_channel)]),
220
+ wpe = nn.Embedding(self.num_channel, config.n_embd),
221
+ drop = nn.Dropout(config.dropout),
222
+ h = nn.ModuleList([block_cls(config) for _ in range(config.n_layer)]),
223
+ ln_f = RMSNorm(config.n_embd) if config.use_rmsnorm else LayerNorm(config.n_embd, bias=config.bias),
224
+ ))
225
+ self.lm_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(self.num_channel)])
226
+
227
+ # with weight tying when using torch.compile() some warnings get generated:
228
+ # "UserWarning: functional_call was passed multiple values for tied weights.
229
+ # This behavior is deprecated and will be an error in future versions"
230
+ # not 100% sure what this is, so far seems to be harmless. TODO investigate
231
+ # self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
232
+
233
+ # init all weights
234
+ self.apply(self._init_weights)
235
+ # apply special scaled init to the residual projections, per GPT-2 paper
236
+ for pn, p in self.named_parameters():
237
+ if pn.endswith('c_proj.weight'):
238
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
239
+
240
+ # report number of parameters
241
+ print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
242
+
243
+ def get_num_params(self, non_embedding=False):
244
+ """
245
+ Return the number of parameters in the model.
246
+ For non-embedding count (default), the position embeddings get subtracted.
247
+ The token embeddings would too, except due to the parameter sharing these
248
+ params are actually used as weights in the final layer, so we include them.
249
+ """
250
+ n_params = sum(p.numel() for p in self.parameters())
251
+ if non_embedding:
252
+ n_params -= self.transformer.wpe.weight.numel()
253
+ return n_params
254
+
255
+ def _init_weights(self, module):
256
+ if isinstance(module, nn.Linear):
257
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
258
+ if module.bias is not None:
259
+ torch.nn.init.zeros_(module.bias)
260
+ elif isinstance(module, nn.Embedding):
261
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
262
+
263
+ def forward(self,
264
+ main_hidden_states, # [seq, main_dim]
265
+ audio_token_ids # [seq, 7]
266
+ ):
267
+
268
+ assert main_hidden_states.shape[0] == audio_token_ids.shape[0]
269
+ in_audio_token_num = audio_token_ids.shape[-1]
270
+
271
+ device = audio_token_ids.device
272
+
273
+ audio_token_ids = F.pad(audio_token_ids, (1, 0), value=self.config.pad_token_id)
274
+
275
+ x = torch.stack(
276
+ [self.transformer.wtes[c](audio_token_ids[:, c]) for c in range(in_audio_token_num + 1)]
277
+ ).transpose(0, 1) # [seq, in_audio_token_num + 1, n_embd]
278
+
279
+ x += self.transformer.wpe(
280
+ torch.arange(0, in_audio_token_num + 1, dtype=torch.long, device=device)
281
+ ).unsqueeze(0) # position embeddings of shape (1, 8, depth_dim)
282
+
283
+ main_hidden = self.linear_in(main_hidden_states).view(main_hidden_states.shape[0], self.config.block_size, -1)[:, :in_audio_token_num+1, :]
284
+ x += main_hidden
285
+
286
+ x = self.transformer.drop(x)
287
+ for block in self.transformer.h:
288
+ x = block(x)
289
+
290
+ # [seq, 8, hidden]
291
+ x = self.transformer.ln_f(x)
292
+
293
+ # [seq, 8, hidden] (linear)-> [8, seq, vocab]
294
+ x = torch.stack([self.lm_heads[c](x[:, c, :]) for c in range(x.shape[1])])
295
+
296
+ # [8, seq, vocab] -> [seq, 8, vocab]
297
+ x = x.transpose(0,1)
298
+
299
+ return x
300
+ def _initialize_weights(self, module):
301
+ if isinstance(module, nn.Linear):
302
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
303
+ if module.bias is not None:
304
+ torch.nn.init.zeros_(module.bias)
305
+ elif isinstance(module, nn.Embedding):
306
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
307
+
308
+
309
+ if __name__ == "__main__":
310
+ config = {
311
+ "bias": False,
312
+ "dropout": 0.0,
313
+ "n_embd": 1024,
314
+ "n_head": 16,
315
+ "n_layer": 6,
316
+ "use_cmlp": True,
317
+ "use_rmsnorm": True,
318
+ "use_swiglu": True,
319
+ "main_hidden_size": 4096
320
+ }
321
+ model_config = DepthGPTConfig(**config)
322
+ model = DepthGPT(config=model_config)
323
+
324
+ main_hidden_states = torch.rand((1, 4096))
325
+ decoded_audio_tokens = torch.empty((1, 0), dtype=torch.long, device=main_hidden_states.device)
326
+ outputs = model(main_hidden_states, decoded_audio_tokens)
merges.txt ADDED
The diff for this file is too large to render.
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab5cc9261500e9a9cee44e656e97c0f05ce002021bcda732ef3d355a174c2763
+ size 4915399120
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99c7ff8c992c5f3f71b6d7c83a415a13de4bc051af71bf446114c30bcd6ddd17
+ size 4991495848
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0128fbc7e5107cf0d7cc37bec37ad0f7bffd2cd812f4dec6cafabbd82020429c
+ size 4466655904
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a4115bbe71b25661e725312c6573ba6a149640d655d2b14339021084cffc7793
+ size 2068559752
model.safetensors.index.json ADDED
@@ -0,0 +1,550 @@
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 8221022720,
4
+ "total_size": 16442045440
5
+ },
6
+ "weight_map": {
7
+ "aud_output_layers.linear_in.weight": "model-00004-of-00004.safetensors",
8
+ "aud_output_layers.lm_heads.0.weight": "model-00004-of-00004.safetensors",
9
+ "aud_output_layers.lm_heads.1.weight": "model-00004-of-00004.safetensors",
10
+ "aud_output_layers.lm_heads.2.weight": "model-00004-of-00004.safetensors",
11
+ "aud_output_layers.lm_heads.3.weight": "model-00004-of-00004.safetensors",
12
+ "aud_output_layers.lm_heads.4.weight": "model-00004-of-00004.safetensors",
13
+ "aud_output_layers.lm_heads.5.weight": "model-00004-of-00004.safetensors",
14
+ "aud_output_layers.lm_heads.6.weight": "model-00004-of-00004.safetensors",
15
+ "aud_output_layers.lm_heads.7.weight": "model-00004-of-00004.safetensors",
16
+ "aud_output_layers.transformer.h.0.attn.c_attn.weight": "model-00004-of-00004.safetensors",
17
+ "aud_output_layers.transformer.h.0.attn.c_proj.weight": "model-00004-of-00004.safetensors",
18
+ "aud_output_layers.transformer.h.0.ln_1.weight": "model-00004-of-00004.safetensors",
19
+ "aud_output_layers.transformer.h.0.ln_2.weight": "model-00004-of-00004.safetensors",
20
+ "aud_output_layers.transformer.h.0.mlps.0.down_proj.weight": "model-00004-of-00004.safetensors",
21
+ "aud_output_layers.transformer.h.0.mlps.0.gate_proj.weight": "model-00004-of-00004.safetensors",
22
+ "aud_output_layers.transformer.h.0.mlps.0.up_proj.weight": "model-00004-of-00004.safetensors",
23
+ "aud_output_layers.transformer.h.0.mlps.1.down_proj.weight": "model-00004-of-00004.safetensors",
24
+ "aud_output_layers.transformer.h.0.mlps.1.gate_proj.weight": "model-00004-of-00004.safetensors",
25
+ "aud_output_layers.transformer.h.0.mlps.1.up_proj.weight": "model-00004-of-00004.safetensors",
26
+ "aud_output_layers.transformer.h.0.mlps.2.down_proj.weight": "model-00004-of-00004.safetensors",
27
+ "aud_output_layers.transformer.h.0.mlps.2.gate_proj.weight": "model-00004-of-00004.safetensors",
28
+ "aud_output_layers.transformer.h.0.mlps.2.up_proj.weight": "model-00004-of-00004.safetensors",
29
+ "aud_output_layers.transformer.h.0.mlps.3.down_proj.weight": "model-00004-of-00004.safetensors",
30
+ "aud_output_layers.transformer.h.0.mlps.3.gate_proj.weight": "model-00004-of-00004.safetensors",
31
+ "aud_output_layers.transformer.h.0.mlps.3.up_proj.weight": "model-00004-of-00004.safetensors",
32
+ "aud_output_layers.transformer.h.0.mlps.4.down_proj.weight": "model-00004-of-00004.safetensors",
33
+ "aud_output_layers.transformer.h.0.mlps.4.gate_proj.weight": "model-00004-of-00004.safetensors",
34
+ "aud_output_layers.transformer.h.0.mlps.4.up_proj.weight": "model-00004-of-00004.safetensors",
35
+ "aud_output_layers.transformer.h.0.mlps.5.down_proj.weight": "model-00004-of-00004.safetensors",
36
+ "aud_output_layers.transformer.h.0.mlps.5.gate_proj.weight": "model-00004-of-00004.safetensors",
37
+ "aud_output_layers.transformer.h.0.mlps.5.up_proj.weight": "model-00004-of-00004.safetensors",
38
+ "aud_output_layers.transformer.h.0.mlps.6.down_proj.weight": "model-00004-of-00004.safetensors",
39
+ "aud_output_layers.transformer.h.0.mlps.6.gate_proj.weight": "model-00004-of-00004.safetensors",
40
+ "aud_output_layers.transformer.h.0.mlps.6.up_proj.weight": "model-00004-of-00004.safetensors",
41
+ "aud_output_layers.transformer.h.0.mlps.7.down_proj.weight": "model-00004-of-00004.safetensors",
42
+ "aud_output_layers.transformer.h.0.mlps.7.gate_proj.weight": "model-00004-of-00004.safetensors",
43
+ "aud_output_layers.transformer.h.0.mlps.7.up_proj.weight": "model-00004-of-00004.safetensors",
44
+ "aud_output_layers.transformer.h.1.attn.c_attn.weight": "model-00004-of-00004.safetensors",
45
+ "aud_output_layers.transformer.h.1.attn.c_proj.weight": "model-00004-of-00004.safetensors",
46
+ "aud_output_layers.transformer.h.1.ln_1.weight": "model-00004-of-00004.safetensors",
47
+ "aud_output_layers.transformer.h.1.ln_2.weight": "model-00004-of-00004.safetensors",
48
+ "aud_output_layers.transformer.h.1.mlps.0.down_proj.weight": "model-00004-of-00004.safetensors",
49
+ "aud_output_layers.transformer.h.1.mlps.0.gate_proj.weight": "model-00004-of-00004.safetensors",
50
+ "aud_output_layers.transformer.h.1.mlps.0.up_proj.weight": "model-00004-of-00004.safetensors",
51
+ "aud_output_layers.transformer.h.1.mlps.1.down_proj.weight": "model-00004-of-00004.safetensors",
52
+ "aud_output_layers.transformer.h.1.mlps.1.gate_proj.weight": "model-00004-of-00004.safetensors",
53
+ "aud_output_layers.transformer.h.1.mlps.1.up_proj.weight": "model-00004-of-00004.safetensors",
54
+ "aud_output_layers.transformer.h.1.mlps.2.down_proj.weight": "model-00004-of-00004.safetensors",
55
+ "aud_output_layers.transformer.h.1.mlps.2.gate_proj.weight": "model-00004-of-00004.safetensors",
56
+ "aud_output_layers.transformer.h.1.mlps.2.up_proj.weight": "model-00004-of-00004.safetensors",
57
+ "aud_output_layers.transformer.h.1.mlps.3.down_proj.weight": "model-00004-of-00004.safetensors",
58
+ "aud_output_layers.transformer.h.1.mlps.3.gate_proj.weight": "model-00004-of-00004.safetensors",
59
+ "aud_output_layers.transformer.h.1.mlps.3.up_proj.weight": "model-00004-of-00004.safetensors",
60
+ "aud_output_layers.transformer.h.1.mlps.4.down_proj.weight": "model-00004-of-00004.safetensors",
61
+ "aud_output_layers.transformer.h.1.mlps.4.gate_proj.weight": "model-00004-of-00004.safetensors",
62
+ "aud_output_layers.transformer.h.1.mlps.4.up_proj.weight": "model-00004-of-00004.safetensors",
63
+ "aud_output_layers.transformer.h.1.mlps.5.down_proj.weight": "model-00004-of-00004.safetensors",
64
+ "aud_output_layers.transformer.h.1.mlps.5.gate_proj.weight": "model-00004-of-00004.safetensors",
65
+ "aud_output_layers.transformer.h.1.mlps.5.up_proj.weight": "model-00004-of-00004.safetensors",
66
+ "aud_output_layers.transformer.h.1.mlps.6.down_proj.weight": "model-00004-of-00004.safetensors",
67
+ "aud_output_layers.transformer.h.1.mlps.6.gate_proj.weight": "model-00004-of-00004.safetensors",
68
+ "aud_output_layers.transformer.h.1.mlps.6.up_proj.weight": "model-00004-of-00004.safetensors",
69
+ "aud_output_layers.transformer.h.1.mlps.7.down_proj.weight": "model-00004-of-00004.safetensors",
70
+ "aud_output_layers.transformer.h.1.mlps.7.gate_proj.weight": "model-00004-of-00004.safetensors",
71
+ "aud_output_layers.transformer.h.1.mlps.7.up_proj.weight": "model-00004-of-00004.safetensors",
72
+ "aud_output_layers.transformer.h.2.attn.c_attn.weight": "model-00004-of-00004.safetensors",
73
+ "aud_output_layers.transformer.h.2.attn.c_proj.weight": "model-00004-of-00004.safetensors",
74
+ "aud_output_layers.transformer.h.2.ln_1.weight": "model-00004-of-00004.safetensors",
75
+ "aud_output_layers.transformer.h.2.ln_2.weight": "model-00004-of-00004.safetensors",
76
+ "aud_output_layers.transformer.h.2.mlps.0.down_proj.weight": "model-00004-of-00004.safetensors",
77
+ "aud_output_layers.transformer.h.2.mlps.0.gate_proj.weight": "model-00004-of-00004.safetensors",
78
+ "aud_output_layers.transformer.h.2.mlps.0.up_proj.weight": "model-00004-of-00004.safetensors",
79
+ "aud_output_layers.transformer.h.2.mlps.1.down_proj.weight": "model-00004-of-00004.safetensors",
80
+ "aud_output_layers.transformer.h.2.mlps.1.gate_proj.weight": "model-00004-of-00004.safetensors",
81
+ "aud_output_layers.transformer.h.2.mlps.1.up_proj.weight": "model-00004-of-00004.safetensors",
82
+ "aud_output_layers.transformer.h.2.mlps.2.down_proj.weight": "model-00004-of-00004.safetensors",
83
+ "aud_output_layers.transformer.h.2.mlps.2.gate_proj.weight": "model-00004-of-00004.safetensors",
84
+ "aud_output_layers.transformer.h.2.mlps.2.up_proj.weight": "model-00004-of-00004.safetensors",
85
+ "aud_output_layers.transformer.h.2.mlps.3.down_proj.weight": "model-00004-of-00004.safetensors",
86
+ "aud_output_layers.transformer.h.2.mlps.3.gate_proj.weight": "model-00004-of-00004.safetensors",
87
+ "aud_output_layers.transformer.h.2.mlps.3.up_proj.weight": "model-00004-of-00004.safetensors",
88
+ "aud_output_layers.transformer.h.2.mlps.4.down_proj.weight": "model-00004-of-00004.safetensors",
89
+ "aud_output_layers.transformer.h.2.mlps.4.gate_proj.weight": "model-00004-of-00004.safetensors",
90
+ "aud_output_layers.transformer.h.2.mlps.4.up_proj.weight": "model-00004-of-00004.safetensors",
91
+ "aud_output_layers.transformer.h.2.mlps.5.down_proj.weight": "model-00004-of-00004.safetensors",
92
+ "aud_output_layers.transformer.h.2.mlps.5.gate_proj.weight": "model-00004-of-00004.safetensors",
93
+ "aud_output_layers.transformer.h.2.mlps.5.up_proj.weight": "model-00004-of-00004.safetensors",
94
+ "aud_output_layers.transformer.h.2.mlps.6.down_proj.weight": "model-00004-of-00004.safetensors",
95
+ "aud_output_layers.transformer.h.2.mlps.6.gate_proj.weight": "model-00004-of-00004.safetensors",
96
+ "aud_output_layers.transformer.h.2.mlps.6.up_proj.weight": "model-00004-of-00004.safetensors",
97
+ "aud_output_layers.transformer.h.2.mlps.7.down_proj.weight": "model-00004-of-00004.safetensors",
98
+ "aud_output_layers.transformer.h.2.mlps.7.gate_proj.weight": "model-00004-of-00004.safetensors",
99
+ "aud_output_layers.transformer.h.2.mlps.7.up_proj.weight": "model-00004-of-00004.safetensors",
100
+ "aud_output_layers.transformer.h.3.attn.c_attn.weight": "model-00004-of-00004.safetensors",
101
+ "aud_output_layers.transformer.h.3.attn.c_proj.weight": "model-00004-of-00004.safetensors",
102
+ "aud_output_layers.transformer.h.3.ln_1.weight": "model-00004-of-00004.safetensors",
103
+ "aud_output_layers.transformer.h.3.ln_2.weight": "model-00004-of-00004.safetensors",
104
+ "aud_output_layers.transformer.h.3.mlps.0.down_proj.weight": "model-00004-of-00004.safetensors",
105
+ "aud_output_layers.transformer.h.3.mlps.0.gate_proj.weight": "model-00004-of-00004.safetensors",
106
+ "aud_output_layers.transformer.h.3.mlps.0.up_proj.weight": "model-00004-of-00004.safetensors",
107
+ "aud_output_layers.transformer.h.3.mlps.1.down_proj.weight": "model-00004-of-00004.safetensors",
108
+ "aud_output_layers.transformer.h.3.mlps.1.gate_proj.weight": "model-00004-of-00004.safetensors",
109
+ "aud_output_layers.transformer.h.3.mlps.1.up_proj.weight": "model-00004-of-00004.safetensors",
110
+ "aud_output_layers.transformer.h.3.mlps.2.down_proj.weight": "model-00004-of-00004.safetensors",
111
+ "aud_output_layers.transformer.h.3.mlps.2.gate_proj.weight": "model-00004-of-00004.safetensors",
112
+ "aud_output_layers.transformer.h.3.mlps.2.up_proj.weight": "model-00004-of-00004.safetensors",
113
+ "aud_output_layers.transformer.h.3.mlps.3.down_proj.weight": "model-00004-of-00004.safetensors",
114
+ "aud_output_layers.transformer.h.3.mlps.3.gate_proj.weight": "model-00004-of-00004.safetensors",
115
+ "aud_output_layers.transformer.h.3.mlps.3.up_proj.weight": "model-00004-of-00004.safetensors",
116
+ "aud_output_layers.transformer.h.3.mlps.4.down_proj.weight": "model-00004-of-00004.safetensors",
117
+ "aud_output_layers.transformer.h.3.mlps.4.gate_proj.weight": "model-00004-of-00004.safetensors",
118
+ "aud_output_layers.transformer.h.3.mlps.4.up_proj.weight": "model-00004-of-00004.safetensors",
119
+ "aud_output_layers.transformer.h.3.mlps.5.down_proj.weight": "model-00004-of-00004.safetensors",
120
+ "aud_output_layers.transformer.h.3.mlps.5.gate_proj.weight": "model-00004-of-00004.safetensors",
121
+ "aud_output_layers.transformer.h.3.mlps.5.up_proj.weight": "model-00004-of-00004.safetensors",
122
+ "aud_output_layers.transformer.h.3.mlps.6.down_proj.weight": "model-00004-of-00004.safetensors",
123
+ "aud_output_layers.transformer.h.3.mlps.6.gate_proj.weight": "model-00004-of-00004.safetensors",
124
+ "aud_output_layers.transformer.h.3.mlps.6.up_proj.weight": "model-00004-of-00004.safetensors",
125
+ "aud_output_layers.transformer.h.3.mlps.7.down_proj.weight": "model-00004-of-00004.safetensors",
126
+ "aud_output_layers.transformer.h.3.mlps.7.gate_proj.weight": "model-00004-of-00004.safetensors",
127
+ "aud_output_layers.transformer.h.3.mlps.7.up_proj.weight": "model-00004-of-00004.safetensors",
128
+ "aud_output_layers.transformer.h.4.attn.c_attn.weight": "model-00004-of-00004.safetensors",
129
+ "aud_output_layers.transformer.h.4.attn.c_proj.weight": "model-00004-of-00004.safetensors",
130
+ "aud_output_layers.transformer.h.4.ln_1.weight": "model-00004-of-00004.safetensors",
131
+ "aud_output_layers.transformer.h.4.ln_2.weight": "model-00004-of-00004.safetensors",
132
+ "aud_output_layers.transformer.h.4.mlps.0.down_proj.weight": "model-00004-of-00004.safetensors",
133
+ "aud_output_layers.transformer.h.4.mlps.0.gate_proj.weight": "model-00004-of-00004.safetensors",
134
+ "aud_output_layers.transformer.h.4.mlps.0.up_proj.weight": "model-00004-of-00004.safetensors",
135
+ "aud_output_layers.transformer.h.4.mlps.1.down_proj.weight": "model-00004-of-00004.safetensors",
136
+ "aud_output_layers.transformer.h.4.mlps.1.gate_proj.weight": "model-00004-of-00004.safetensors",
137
+ "aud_output_layers.transformer.h.4.mlps.1.up_proj.weight": "model-00004-of-00004.safetensors",
138
+ "aud_output_layers.transformer.h.4.mlps.2.down_proj.weight": "model-00004-of-00004.safetensors",
139
+ "aud_output_layers.transformer.h.4.mlps.2.gate_proj.weight": "model-00004-of-00004.safetensors",
140
+ "aud_output_layers.transformer.h.4.mlps.2.up_proj.weight": "model-00004-of-00004.safetensors",
141
+ "aud_output_layers.transformer.h.4.mlps.3.down_proj.weight": "model-00004-of-00004.safetensors",
142
+ "aud_output_layers.transformer.h.4.mlps.3.gate_proj.weight": "model-00004-of-00004.safetensors",
143
+ "aud_output_layers.transformer.h.4.mlps.3.up_proj.weight": "model-00004-of-00004.safetensors",
144
+ "aud_output_layers.transformer.h.4.mlps.4.down_proj.weight": "model-00004-of-00004.safetensors",
145
+ "aud_output_layers.transformer.h.4.mlps.4.gate_proj.weight": "model-00004-of-00004.safetensors",
146
+ "aud_output_layers.transformer.h.4.mlps.4.up_proj.weight": "model-00004-of-00004.safetensors",
147
+ "aud_output_layers.transformer.h.4.mlps.5.down_proj.weight": "model-00004-of-00004.safetensors",
148
+ "aud_output_layers.transformer.h.4.mlps.5.gate_proj.weight": "model-00004-of-00004.safetensors",
149
+ "aud_output_layers.transformer.h.4.mlps.5.up_proj.weight": "model-00004-of-00004.safetensors",
150
+ "aud_output_layers.transformer.h.4.mlps.6.down_proj.weight": "model-00004-of-00004.safetensors",
151
+ "aud_output_layers.transformer.h.4.mlps.6.gate_proj.weight": "model-00004-of-00004.safetensors",
152
+ "aud_output_layers.transformer.h.4.mlps.6.up_proj.weight": "model-00004-of-00004.safetensors",
153
+ "aud_output_layers.transformer.h.4.mlps.7.down_proj.weight": "model-00004-of-00004.safetensors",
154
+ "aud_output_layers.transformer.h.4.mlps.7.gate_proj.weight": "model-00004-of-00004.safetensors",
155
+ "aud_output_layers.transformer.h.4.mlps.7.up_proj.weight": "model-00004-of-00004.safetensors",
156
+ "aud_output_layers.transformer.h.5.attn.c_attn.weight": "model-00004-of-00004.safetensors",
157
+ "aud_output_layers.transformer.h.5.attn.c_proj.weight": "model-00004-of-00004.safetensors",
158
+ "aud_output_layers.transformer.h.5.ln_1.weight": "model-00004-of-00004.safetensors",
159
+ "aud_output_layers.transformer.h.5.ln_2.weight": "model-00004-of-00004.safetensors",
160
+ "aud_output_layers.transformer.h.5.mlps.0.down_proj.weight": "model-00004-of-00004.safetensors",
161
+ "aud_output_layers.transformer.h.5.mlps.0.gate_proj.weight": "model-00004-of-00004.safetensors",
162
+ "aud_output_layers.transformer.h.5.mlps.0.up_proj.weight": "model-00004-of-00004.safetensors",
163
+ "aud_output_layers.transformer.h.5.mlps.1.down_proj.weight": "model-00004-of-00004.safetensors",
164
+ "aud_output_layers.transformer.h.5.mlps.1.gate_proj.weight": "model-00004-of-00004.safetensors",
165
+ "aud_output_layers.transformer.h.5.mlps.1.up_proj.weight": "model-00004-of-00004.safetensors",
166
+ "aud_output_layers.transformer.h.5.mlps.2.down_proj.weight": "model-00004-of-00004.safetensors",
167
+ "aud_output_layers.transformer.h.5.mlps.2.gate_proj.weight": "model-00004-of-00004.safetensors",
168
+ "aud_output_layers.transformer.h.5.mlps.2.up_proj.weight": "model-00004-of-00004.safetensors",
169
+ "aud_output_layers.transformer.h.5.mlps.3.down_proj.weight": "model-00004-of-00004.safetensors",
170
+ "aud_output_layers.transformer.h.5.mlps.3.gate_proj.weight": "model-00004-of-00004.safetensors",
171
+ "aud_output_layers.transformer.h.5.mlps.3.up_proj.weight": "model-00004-of-00004.safetensors",
172
+ "aud_output_layers.transformer.h.5.mlps.4.down_proj.weight": "model-00004-of-00004.safetensors",
173
+ "aud_output_layers.transformer.h.5.mlps.4.gate_proj.weight": "model-00004-of-00004.safetensors",
174
+ "aud_output_layers.transformer.h.5.mlps.4.up_proj.weight": "model-00004-of-00004.safetensors",
175
+ "aud_output_layers.transformer.h.5.mlps.5.down_proj.weight": "model-00004-of-00004.safetensors",
176
+ "aud_output_layers.transformer.h.5.mlps.5.gate_proj.weight": "model-00004-of-00004.safetensors",
177
+ "aud_output_layers.transformer.h.5.mlps.5.up_proj.weight": "model-00004-of-00004.safetensors",
178
+ "aud_output_layers.transformer.h.5.mlps.6.down_proj.weight": "model-00004-of-00004.safetensors",
179
+ "aud_output_layers.transformer.h.5.mlps.6.gate_proj.weight": "model-00004-of-00004.safetensors",
180
+ "aud_output_layers.transformer.h.5.mlps.6.up_proj.weight": "model-00004-of-00004.safetensors",
181
+ "aud_output_layers.transformer.h.5.mlps.7.down_proj.weight": "model-00004-of-00004.safetensors",
182
+ "aud_output_layers.transformer.h.5.mlps.7.gate_proj.weight": "model-00004-of-00004.safetensors",
183
+ "aud_output_layers.transformer.h.5.mlps.7.up_proj.weight": "model-00004-of-00004.safetensors",
184
+ "aud_output_layers.transformer.ln_f.weight": "model-00004-of-00004.safetensors",
185
+ "aud_output_layers.transformer.wpe.weight": "model-00004-of-00004.safetensors",
186
+ "aud_output_layers.transformer.wtes.0.weight": "model-00004-of-00004.safetensors",
187
+ "aud_output_layers.transformer.wtes.1.weight": "model-00004-of-00004.safetensors",
188
+ "aud_output_layers.transformer.wtes.2.weight": "model-00004-of-00004.safetensors",
189
+ "aud_output_layers.transformer.wtes.3.weight": "model-00004-of-00004.safetensors",
190
+ "aud_output_layers.transformer.wtes.4.weight": "model-00004-of-00004.safetensors",
191
+ "aud_output_layers.transformer.wtes.5.weight": "model-00004-of-00004.safetensors",
192
+ "aud_output_layers.transformer.wtes.6.weight": "model-00004-of-00004.safetensors",
193
+ "aud_output_layers.transformer.wtes.7.weight": "model-00004-of-00004.safetensors",
194
+ "lm_head.weight": "model-00004-of-00004.safetensors",
195
+ "model.embed_tokens.aud_listen_embeddings.0.weight": "model-00001-of-00004.safetensors",
196
+ "model.embed_tokens.aud_listen_embeddings.1.weight": "model-00001-of-00004.safetensors",
197
+ "model.embed_tokens.aud_listen_embeddings.2.weight": "model-00001-of-00004.safetensors",
198
+ "model.embed_tokens.aud_listen_embeddings.3.weight": "model-00001-of-00004.safetensors",
199
+ "model.embed_tokens.aud_listen_embeddings.4.weight": "model-00001-of-00004.safetensors",
200
+ "model.embed_tokens.aud_listen_embeddings.5.weight": "model-00001-of-00004.safetensors",
201
+ "model.embed_tokens.aud_listen_embeddings.6.weight": "model-00001-of-00004.safetensors",
202
+ "model.embed_tokens.aud_listen_embeddings.7.weight": "model-00001-of-00004.safetensors",
203
+ "model.embed_tokens.aud_speak_embeddings.0.weight": "model-00001-of-00004.safetensors",
204
+ "model.embed_tokens.aud_speak_embeddings.1.weight": "model-00001-of-00004.safetensors",
205
+ "model.embed_tokens.aud_speak_embeddings.2.weight": "model-00001-of-00004.safetensors",
206
+ "model.embed_tokens.aud_speak_embeddings.3.weight": "model-00001-of-00004.safetensors",
207
+ "model.embed_tokens.aud_speak_embeddings.4.weight": "model-00001-of-00004.safetensors",
208
+ "model.embed_tokens.aud_speak_embeddings.5.weight": "model-00001-of-00004.safetensors",
209
+ "model.embed_tokens.aud_speak_embeddings.6.weight": "model-00001-of-00004.safetensors",
210
+ "model.embed_tokens.aud_speak_embeddings.7.weight": "model-00001-of-00004.safetensors",
211
+ "model.embed_tokens.text_embeddings.weight": "model-00001-of-00004.safetensors",
212
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
213
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
214
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
215
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
216
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
217
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
218
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
219
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
220
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
221
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
222
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
223
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
224
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
225
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
226
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
227
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
228
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
229
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
230
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
231
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
232
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
233
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
234
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
235
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
236
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
237
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
238
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
239
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
240
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
241
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
242
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
243
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
244
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
245
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
246
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
247
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
248
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
249
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
250
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
251
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
252
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
253
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
254
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
255
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
256
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
257
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
258
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
259
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
260
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
261
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
262
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
263
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
264
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
265
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
266
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
267
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
268
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
269
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
270
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
271
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
272
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
273
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
274
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
275
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
276
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
277
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
278
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
279
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
280
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
281
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
282
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
283
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
284
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
285
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
286
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
287
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
288
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
289
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
290
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
291
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
292
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
293
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
294
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
295
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
296
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
297
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
298
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
299
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
300
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
301
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
302
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
303
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
304
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
305
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
306
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
307
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
308
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
309
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
310
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
311
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
312
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
313
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
314
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
315
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
316
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
317
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
318
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
319
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
320
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
321
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
322
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
323
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
324
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
325
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
326
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
327
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
328
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
329
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
330
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
331
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
332
+ "model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
333
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
334
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
335
+ "model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
336
+ "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
337
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
338
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
339
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
340
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
341
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
342
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
343
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
344
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
345
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
346
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
347
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
348
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
349
+ "model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
350
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
351
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
352
+ "model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
353
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
354
+ "model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
355
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
356
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
357
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
358
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
359
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
360
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
361
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
362
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
363
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
364
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
365
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
366
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
367
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
368
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
369
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
370
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
371
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
372
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
373
+ "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
374
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
375
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
376
+ "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
377
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
378
+ "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
379
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
380
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
381
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
382
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
383
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
384
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
385
+ "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
386
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
387
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
388
+ "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
389
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
390
+ "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
391
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
392
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
393
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
394
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
395
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
396
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
397
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
398
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
399
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
400
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
401
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
402
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
403
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
404
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
405
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
406
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
407
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
408
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
409
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
410
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
411
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
412
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
413
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
414
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
415
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
416
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
417
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
418
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
419
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
420
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
421
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
422
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
423
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
424
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
425
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
426
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
427
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
428
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
429
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
430
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
431
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
432
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
433
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
434
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
435
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
436
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
437
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
438
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
439
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
440
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
441
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
442
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
443
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
444
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
445
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
446
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
447
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
448
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
449
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
450
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
451
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
452
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
453
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
454
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
455
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
456
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
457
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
458
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
459
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
460
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
461
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
462
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
463
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
464
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
465
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
466
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
467
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
468
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
469
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
470
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
471
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
472
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
473
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
474
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
475
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
476
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
477
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
478
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
479
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
480
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
481
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
482
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
483
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
484
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
485
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
486
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
487
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
488
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
489
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
490
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
491
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
492
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
493
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
494
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
495
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
496
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
497
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
498
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
499
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
500
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
501
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
502
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
503
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
504
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
505
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
506
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
507
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
508
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
509
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
510
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
511
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
512
+ "model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
513
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
514
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
515
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
516
+ "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
517
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
518
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
519
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
520
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
521
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
522
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
523
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
524
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
525
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
526
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
527
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
528
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
529
+ "model.layers.8.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
530
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
531
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
532
+ "model.layers.8.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
533
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
534
+ "model.layers.8.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
535
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
536
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
537
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
538
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
539
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
540
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
541
+ "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
542
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
543
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
544
+ "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
545
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
546
+ "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
547
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
548
+ "model.norm.weight": "model-00003-of-00004.safetensors"
549
+ }
550
+ }
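The `weight_map` above is the standard sharded-checkpoint index: every parameter name points at the shard file that stores it. As a rough illustration of how such an index is consumed, the sketch below resolves one tensor name to its shard and reads it lazily with the `safetensors` library; the file paths and the chosen key simply mirror entries listed above and are not an official loading recipe.

```python
# Minimal sketch: look a parameter up in model.safetensors.index.json and read it
# from its shard without loading the whole checkpoint (assumes the files sit locally).
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.embed_tokens.text_embeddings.weight"   # key from the weight map above
shard = index["weight_map"][name]                    # e.g. "model-00001-of-00004.safetensors"

with safe_open(shard, framework="pt", device="cpu") as st:
    tensor = st.get_tensor(name)
print(tensor.shape)
```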
modeling_flmaudio.py ADDED
@@ -0,0 +1,1524 @@
1
+ # coding=utf-8
2
+ """PyTorch FLM-Audio model, based on the LLaMA implementation."""
3
+
4
+ import math
5
+ import warnings
6
+ from typing import List, Optional, Tuple, Union
7
+ from dataclasses import dataclass
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
15
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
16
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
17
+ from transformers.modeling_outputs import (
18
+ ModelOutput,
19
+ BaseModelOutputWithPast,
20
+ CausalLMOutputWithPast,
21
+ )
22
+ from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
24
+ from transformers.utils import (
25
+ add_start_docstrings,
26
+ add_start_docstrings_to_model_forward,
27
+ is_flash_attn_2_available,
28
+ is_flash_attn_greater_or_equal_2_10,
29
+ logging,
30
+ replace_return_docstrings,
31
+ )
32
+ from .configuration_flmaudio import FLMAudioConfig
33
+ from .depth_gpt import DepthGPT, DepthGPTConfig
34
+
35
+ if is_flash_attn_2_available():
36
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
37
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CONFIG_FOR_DOC = "FLMAudioConfig"
43
+
44
+
45
+ def _get_unpad_data(attention_mask):
46
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
47
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
48
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
49
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
50
+ return (
51
+ indices,
52
+ cu_seqlens,
53
+ max_seqlen_in_batch,
54
+ )
55
+
56
+
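`_get_unpad_data` converts a padding mask into the flat token indices, cumulative sequence lengths, and maximum length that the variable-length FlashAttention kernels expect. A small worked example of those three outputs, traced by hand through the logic above (the mask values are made up):

```python
import torch
import torch.nn.functional as F

# Two sequences of real lengths 2 and 3; 0 marks padding.
attention_mask = torch.tensor([[1, 1, 0],
                               [1, 1, 1]])

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                       # tensor([2, 3])
indices = torch.nonzero(attention_mask.flatten()).flatten()                   # tensor([0, 1, 3, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # tensor([0, 2, 5])
max_seqlen = seqlens.max().item()                                             # 3
```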
57
+ class FLMAudioRMSNorm(nn.Module):
58
+ def __init__(self, hidden_size, eps=1e-6):
59
+ """
60
+ FLMAudioRMSNorm is equivalent to T5LayerNorm
61
+ """
62
+ super().__init__()
63
+ self.weight = nn.Parameter(torch.ones(hidden_size))
64
+ self.variance_epsilon = eps
65
+
66
+ def forward(self, hidden_states):
67
+ input_dtype = hidden_states.dtype
68
+ hidden_states = hidden_states.to(torch.float32)
69
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
70
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
71
+ return self.weight * hidden_states.to(input_dtype)
72
+
73
+
74
+ ALL_LAYERNORM_LAYERS.append(FLMAudioRMSNorm)
75
+
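In closed form the layer above computes y = w ⊙ x / sqrt(mean(x²) + ε): RMS normalization with a learned gain, no mean subtraction and no bias, applied in float32. A quick sanity sketch against that formula (hidden size and inputs are arbitrary, and it assumes the class above is in scope):

```python
import torch

norm = FLMAudioRMSNorm(hidden_size=8, eps=1e-6)   # class defined above
x = torch.randn(2, 3, 8)
reference = norm.weight * x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), reference, atol=1e-6)
```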
76
+ class FLMAudioRotaryEmbedding(nn.Module):
77
+ def __init__(self, config, device=None):
78
+ super().__init__()
79
+ # BC: "rope_type" was originally "type"
80
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
81
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
82
+ else:
83
+ self.rope_type = "default"
84
+ self.max_seq_len_cached = config.max_position_embeddings
85
+ self.original_max_seq_len = config.max_position_embeddings
86
+
87
+ self.config = config
88
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
89
+
90
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
91
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
92
+ self.original_inv_freq = self.inv_freq
93
+
94
+ @torch.no_grad()
95
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
96
+ def forward(self, x, position_ids):
97
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
98
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
99
+
100
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
101
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
102
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
103
+ emb = torch.cat((freqs, freqs), dim=-1)
104
+ cos = emb.cos() * self.attention_scaling
105
+ sin = emb.sin() * self.attention_scaling
106
+
107
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
108
+
109
+
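For the `default` rope type, the initialization function referenced above boils down to the usual RoPE frequency table, inv_freq_j = base^(-2j / head_dim), with cos/sin taken over the outer product of positions and inverse frequencies and duplicated to cover the full head dimension; the extra `expand(3, ...)` in the forward only replicates this for the three position-id streams. A standalone sketch of that default computation, with placeholder base and head_dim:

```python
import torch

base, head_dim = 10000.0, 8                                   # placeholders
inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
positions = torch.arange(5).float()

freqs = torch.outer(positions, inv_freq)                      # (positions, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # duplicated, as in the forward above
cos, sin = emb.cos(), emb.sin()                               # (positions, head_dim)
```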
110
+ def rotate_half(x):
111
+ """Rotates half the hidden dims of the input."""
112
+ x1 = x[..., : x.shape[-1] // 2]
113
+ x2 = x[..., x.shape[-1] // 2 :]
114
+ return torch.cat((-x2, x1), dim=-1)
115
+
116
+
117
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
118
+ mrope_section = mrope_section * 2
119
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
120
+ unsqueeze_dim
121
+ )
122
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
123
+ unsqueeze_dim
124
+ )
125
+
126
+ q_embed = (q * cos) + (rotate_half(q) * sin)
127
+ k_embed = (k * cos) + (rotate_half(k) * sin)
128
+ return q_embed, k_embed
129
+
130
+
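`apply_multimodal_rotary_pos_emb` is the multimodal RoPE variant: `cos`/`sin` carry a leading dimension of three position-id streams, and `mrope_section` says how many rotary channel pairs each stream owns; the chunks are reassembled round-robin before the standard `q*cos + rotate_half(q)*sin` rotation. A sketch of just the channel-assembly step, with made-up section sizes:

```python
import torch

mrope_section = [2, 1, 1]                 # illustrative sections (channel pairs per stream)
head_dim = 2 * sum(mrope_section)
cos = torch.randn(3, 1, 5, head_dim)      # (streams, batch, seq, head_dim)

# Same assembly as above: split the last dim into sections of size mrope_section * 2,
# then take chunk i from stream i % 3 and concatenate.
chunks = cos.split(mrope_section * 2, dim=-1)
merged = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
print(merged.shape)                       # torch.Size([1, 5, 8])
```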
131
+ class FLMAudioMLP(nn.Module):
132
+ def __init__(self, config):
133
+ super().__init__()
134
+ self.config = config
135
+ self.hidden_size = config.hidden_size
136
+ self.intermediate_size = config.intermediate_size
137
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
138
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
139
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
140
+ self.act_fn = ACT2FN[config.hidden_act]
141
+
142
+ def forward(self, x):
143
+ if self.config.pretraining_tp > 1:
144
+ slice = self.intermediate_size // self.config.pretraining_tp
145
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
146
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
147
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
148
+
149
+ gate_proj = torch.cat(
150
+ [
151
+ F.linear(x, gate_proj_slices[i])
152
+ for i in range(self.config.pretraining_tp)
153
+ ],
154
+ dim=-1,
155
+ )
156
+ up_proj = torch.cat(
157
+ [
158
+ F.linear(x, up_proj_slices[i])
159
+ for i in range(self.config.pretraining_tp)
160
+ ],
161
+ dim=-1,
162
+ )
163
+
164
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
165
+ down_proj = [
166
+ F.linear(intermediate_states[i], down_proj_slices[i])
167
+ for i in range(self.config.pretraining_tp)
168
+ ]
169
+ down_proj = sum(down_proj)
170
+ else:
171
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
172
+
173
+ return down_proj
174
+
175
+
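The MLP above is the gated (SwiGLU-style) feed-forward common to LLaMA-family models, `down_proj(act(gate_proj(x)) * up_proj(x))`; the `pretraining_tp > 1` branch only reproduces the same result through column/row slices of the weights. A functional sketch of the default path, assuming SiLU as the activation and placeholder sizes:

```python
import torch
import torch.nn.functional as F

hidden, intermediate = 16, 64             # placeholder sizes
gate = torch.nn.Linear(hidden, intermediate, bias=False)
up = torch.nn.Linear(hidden, intermediate, bias=False)
down = torch.nn.Linear(intermediate, hidden, bias=False)

x = torch.randn(2, 5, hidden)
y = down(F.silu(gate(x)) * up(x))         # gated feed-forward, output shape matches the input
print(y.shape)                            # torch.Size([2, 5, 16])
```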
176
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
177
+ """
178
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
179
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
180
+ """
181
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
182
+ if n_rep == 1:
183
+ return hidden_states
184
+ hidden_states = hidden_states[:, :, None, :, :].expand(
185
+ batch, num_key_value_heads, n_rep, slen, head_dim
186
+ )
187
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
188
+
189
+
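As its docstring says, `repeat_kv` is a reshape-based stand-in for `torch.repeat_interleave` used for grouped-query attention: each key/value head is expanded into `n_rep` consecutive copies so it can be matched against the full set of query heads. A quick equivalence check (shapes are arbitrary, and it assumes the function above is in scope):

```python
import torch

x = torch.randn(2, 4, 6, 8)                       # (batch, kv_heads, seq_len, head_dim)
expanded = repeat_kv(x, n_rep=3)                  # -> (2, 12, 6, 8)
assert torch.equal(expanded, x.repeat_interleave(3, dim=1))
```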
190
+ class FLMAudioAttention(nn.Module):
191
+ """Multi-headed attention from the 'Attention Is All You Need' paper."""
192
+
193
+ def __init__(self, config: FLMAudioConfig, layer_idx: Optional[int] = None):
194
+ super().__init__()
195
+ self.config = config
196
+ self.layer_idx = layer_idx
197
+ if layer_idx is None:
198
+ logger.warning_once(
199
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
200
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
201
+ "when creating this class."
202
+ )
203
+
204
+ self.attention_dropout = config.attention_dropout
205
+ self.hidden_size = config.hidden_size
206
+ self.num_heads = config.num_attention_heads
207
+ self.head_dim = self.hidden_size // self.num_heads
208
+ self.num_key_value_heads = config.num_key_value_heads
209
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
210
+ self.is_causal = True
211
+ self.rope_scaling = config.rope_scaling
212
+
213
+ if (self.head_dim * self.num_heads) != self.hidden_size:
214
+ raise ValueError(
215
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
216
+ f" and `num_heads`: {self.num_heads})."
217
+ )
218
+
219
+ self.q_proj = nn.Linear(
220
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
221
+ )
222
+ self.k_proj = nn.Linear(
223
+ self.hidden_size,
224
+ self.num_key_value_heads * self.head_dim,
225
+ bias=config.attention_bias,
226
+ )
227
+ self.v_proj = nn.Linear(
228
+ self.hidden_size,
229
+ self.num_key_value_heads * self.head_dim,
230
+ bias=config.attention_bias,
231
+ )
232
+ self.o_proj = nn.Linear(
233
+ self.hidden_size, self.hidden_size, bias=config.attention_bias and not config.disable_att_o_bias
234
+ )
235
+
236
+
237
+ def forward(
238
+ self,
239
+ hidden_states: torch.Tensor,
240
+ attention_mask: Optional[torch.Tensor] = None,
241
+ position_ids: Optional[torch.LongTensor] = None,
242
+ past_key_value: Optional[Cache] = None,
243
+ output_attentions: bool = False,
244
+ use_cache: bool = False,
245
+ cache_position: Optional[torch.LongTensor] = None,
246
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
247
+ **kwargs,
248
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
249
+ bsz, q_len, _ = hidden_states.size()
250
+
251
+ if self.config.pretraining_tp > 1:
252
+ key_value_slicing = (
253
+ self.num_key_value_heads * self.head_dim
254
+ ) // self.config.pretraining_tp
255
+ query_slices = self.q_proj.weight.split(
256
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
257
+ )
258
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
259
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
260
+
261
+ query_states = [
262
+ F.linear(hidden_states, query_slices[i])
263
+ for i in range(self.config.pretraining_tp)
264
+ ]
265
+ query_states = torch.cat(query_states, dim=-1)
266
+
267
+ key_states = [
268
+ F.linear(hidden_states, key_slices[i])
269
+ for i in range(self.config.pretraining_tp)
270
+ ]
271
+ key_states = torch.cat(key_states, dim=-1)
272
+
273
+ value_states = [
274
+ F.linear(hidden_states, value_slices[i])
275
+ for i in range(self.config.pretraining_tp)
276
+ ]
277
+ value_states = torch.cat(value_states, dim=-1)
278
+
279
+ else:
280
+ query_states = self.q_proj(hidden_states)
281
+ key_states = self.k_proj(hidden_states)
282
+ value_states = self.v_proj(hidden_states)
283
+
284
+ query_states = query_states.view(
285
+ bsz, q_len, self.num_heads, self.head_dim
286
+ ).transpose(1, 2)
287
+ key_states = key_states.view(
288
+ bsz, q_len, self.num_key_value_heads, self.head_dim
289
+ ).transpose(1, 2)
290
+ value_states = value_states.view(
291
+ bsz, q_len, self.num_key_value_heads, self.head_dim
292
+ ).transpose(1, 2)
293
+
294
+ cos, sin = position_embeddings
295
+
296
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
297
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
298
+ )
299
+
300
+ if past_key_value is not None:
301
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
302
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
303
+ key_states, value_states = past_key_value.update(
304
+ key_states, value_states, self.layer_idx, cache_kwargs
305
+ )
306
+
307
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
308
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
309
+
310
+ attn_weights = torch.matmul(
311
+ query_states, key_states.transpose(2, 3)
312
+ ) / math.sqrt(self.head_dim)
313
+
314
+ if attention_mask is not None: # no matter the length, we just slice it
315
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
316
+ attn_weights = attn_weights + causal_mask
317
+
318
+ # upcast attention to fp32
319
+ attn_weights = nn.functional.softmax(
320
+ attn_weights, dim=-1, dtype=torch.float32
321
+ ).to(query_states.dtype)
322
+ attn_weights = nn.functional.dropout(
323
+ attn_weights, p=self.attention_dropout, training=self.training
324
+ )
325
+ attn_output = torch.matmul(attn_weights, value_states)
326
+
327
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
328
+ raise ValueError(
329
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
330
+ f" {attn_output.size()}"
331
+ )
332
+
333
+ attn_output = attn_output.transpose(1, 2).contiguous()
334
+
335
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
336
+
337
+ if self.config.pretraining_tp > 1:
338
+ attn_output = attn_output.split(
339
+ self.hidden_size // self.config.pretraining_tp, dim=2
340
+ )
341
+ o_proj_slices = self.o_proj.weight.split(
342
+ self.hidden_size // self.config.pretraining_tp, dim=1
343
+ )
344
+ attn_output = sum(
345
+ [
346
+ F.linear(attn_output[i], o_proj_slices[i])
347
+ for i in range(self.config.pretraining_tp)
348
+ ]
349
+ )
350
+ else:
351
+ attn_output = self.o_proj(attn_output)
352
+
353
+ if not output_attentions:
354
+ attn_weights = None
355
+
356
+ return attn_output, attn_weights, past_key_value
357
+
358
+
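Stripped of the caching, tensor-parallel slicing and GQA bookkeeping, the eager forward above is plain scaled dot-product attention, softmax(Q·Kᵀ/√d + mask)·V with the softmax taken in float32. A toy-shape sketch of that core step only (no RoPE, no mask; sizes are illustrative):

```python
import math
import torch

b, h, s, d = 1, 2, 4, 8                                  # illustrative sizes
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

scores = q @ k.transpose(-2, -1) / math.sqrt(d)          # (b, h, s, s)
weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
output = weights @ v                                     # (b, h, s, d)
print(output.shape)
```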
359
+ class FLMAudioFlashAttention2(FLMAudioAttention):
360
+ """
361
+ FLM-Audio flash attention module. This module inherits from `FLMAudioAttention` as the weights of the module stay
362
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
363
+ flash attention and deal with padding tokens in case the input contains any of them.
364
+ """
365
+
366
+ def __init__(self, *args, **kwargs):
367
+ super().__init__(*args, **kwargs)
368
+
369
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
370
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
371
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
372
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
373
+
374
+ def forward(
375
+ self,
376
+ hidden_states: torch.Tensor,
377
+ attention_mask: Optional[torch.LongTensor] = None,
378
+ position_ids: Optional[torch.LongTensor] = None,
379
+ past_key_value: Optional[Cache] = None,
380
+ output_attentions: bool = False,
381
+ use_cache: bool = False,
382
+ cache_position: Optional[torch.LongTensor] = None,
383
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
384
+ **kwargs,
385
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
386
+ output_attentions = False
387
+
388
+ bsz, q_len, _ = hidden_states.size()
389
+
390
+ query_states = self.q_proj(hidden_states)
391
+ key_states = self.k_proj(hidden_states)
392
+ value_states = self.v_proj(hidden_states)
393
+
394
+ # Flash attention requires the input to have the shape
395
+ # batch_size x seq_length x num_heads x head_dim
396
+ # therefore we just need to keep the original shape
397
+ query_states = query_states.view(
398
+ bsz, q_len, self.num_heads, self.head_dim
399
+ ).transpose(1, 2)
400
+ key_states = key_states.view(
401
+ bsz, q_len, self.num_key_value_heads, self.head_dim
402
+ ).transpose(1, 2)
403
+ value_states = value_states.view(
404
+ bsz, q_len, self.num_key_value_heads, self.head_dim
405
+ ).transpose(1, 2)
406
+
407
+ # cos, sin = self.rotary_emb(value_states, position_ids)
408
+ cos, sin = position_embeddings
409
+ # query_states, key_states = apply_rotary_pos_emb(
410
+ # query_states, key_states, cos, sin
411
+ # )
412
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
413
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
414
+ )
415
+
416
+ past_key_value = getattr(self, "past_key_value", past_key_value)
417
+
418
+ if past_key_value is not None:
419
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
420
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
421
+ key_states, value_states = past_key_value.update(
422
+ key_states, value_states, self.layer_idx, cache_kwargs
423
+ )
424
+
425
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
426
+ # to be able to avoid many of these transpose/reshape/view.
427
+ query_states = query_states.transpose(1, 2)
428
+ key_states = key_states.transpose(1, 2)
429
+ value_states = value_states.transpose(1, 2)
430
+
431
+ dropout_rate = self.attention_dropout if self.training else 0.0
432
+
433
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
434
+ # therefore the input hidden states get silently cast to float32. Hence, we need
435
+ # cast them back in the correct dtype just to be sure everything works as expected.
436
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
437
+ # in fp32. (FLMAudioRMSNorm handles it correctly)
438
+
439
+ input_dtype = query_states.dtype
440
+ if input_dtype == torch.float32:
441
+ if torch.is_autocast_enabled():
442
+ target_dtype = torch.get_autocast_gpu_dtype()
443
+ # Handle the case where the model is quantized
444
+ elif hasattr(self.config, "_pre_quantization_dtype"):
445
+ target_dtype = self.config._pre_quantization_dtype
446
+ else:
447
+ target_dtype = self.q_proj.weight.dtype
448
+
449
+ logger.warning_once(
450
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
451
+ f" the fact you have upcast embedding or layer norm layers to float32. We will cast the input back to"
452
+ f" {target_dtype}."
453
+ )
454
+
455
+ query_states = query_states.to(target_dtype)
456
+ key_states = key_states.to(target_dtype)
457
+ value_states = value_states.to(target_dtype)
458
+
459
+ attn_output = self._flash_attention_forward(
460
+ query_states,
461
+ key_states,
462
+ value_states,
463
+ attention_mask,
464
+ q_len,
465
+ dropout=dropout_rate,
466
+ )
467
+
468
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
469
+ attn_output = self.o_proj(attn_output)
470
+
471
+ if not output_attentions:
472
+ attn_weights = None
473
+
474
+ return attn_output, attn_weights, past_key_value
475
+
476
+ def _flash_attention_forward(
477
+ self,
478
+ query_states,
479
+ key_states,
480
+ value_states,
481
+ attention_mask,
482
+ query_length,
483
+ dropout=0.0,
484
+ softmax_scale=None,
485
+ ):
486
+ """
487
+ Calls the forward method of Flash Attention: if the input hidden states contain at least one padding token,
488
+ first unpads the input, then computes the attention scores and re-pads the final attention output.
489
+
490
+ Args:
491
+ query_states (`torch.Tensor`):
492
+ Input query states to be passed to Flash Attention API
493
+ key_states (`torch.Tensor`):
494
+ Input key states to be passed to Flash Attention API
495
+ value_states (`torch.Tensor`):
496
+ Input value states to be passed to Flash Attention API
497
+ attention_mask (`torch.Tensor`):
498
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
499
+ position of padding tokens and 1 for the position of non-padding tokens.
500
+ dropout (`float`):
501
+ Attention dropout
502
+ softmax_scale (`float`, *optional*):
503
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
504
+ """
505
+ if not self._flash_attn_uses_top_left_mask:
506
+ causal = self.is_causal
507
+ else:
508
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in FLMAudioFlashAttention2 __init__.
509
+ causal = self.is_causal and query_length != 1
510
+
511
+ # Contains at least one padding token in the sequence
512
+ if attention_mask is not None:
513
+ batch_size = query_states.shape[0]
514
+ (
515
+ query_states,
516
+ key_states,
517
+ value_states,
518
+ indices_q,
519
+ cu_seq_lens,
520
+ max_seq_lens,
521
+ ) = self._upad_input(
522
+ query_states, key_states, value_states, attention_mask, query_length
523
+ )
524
+
525
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
526
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
527
+
528
+ attn_output_unpad = flash_attn_varlen_func(
529
+ query_states,
530
+ key_states,
531
+ value_states,
532
+ cu_seqlens_q=cu_seqlens_q,
533
+ cu_seqlens_k=cu_seqlens_k,
534
+ max_seqlen_q=max_seqlen_in_batch_q,
535
+ max_seqlen_k=max_seqlen_in_batch_k,
536
+ dropout_p=dropout,
537
+ softmax_scale=softmax_scale,
538
+ causal=causal,
539
+ )
540
+
541
+ attn_output = pad_input(
542
+ attn_output_unpad, indices_q, batch_size, query_length
543
+ )
544
+ else:
545
+ attn_output = flash_attn_func(
546
+ query_states,
547
+ key_states,
548
+ value_states,
549
+ dropout,
550
+ softmax_scale=softmax_scale,
551
+ causal=causal,
552
+ )
553
+
554
+ return attn_output
555
+
556
+ def _upad_input(
557
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
558
+ ):
559
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
560
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
561
+
562
+ key_layer = index_first_axis(
563
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
564
+ indices_k,
565
+ )
566
+ value_layer = index_first_axis(
567
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
568
+ indices_k,
569
+ )
570
+ if query_length == kv_seq_len:
571
+ query_layer = index_first_axis(
572
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
573
+ indices_k,
574
+ )
575
+ cu_seqlens_q = cu_seqlens_k
576
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
577
+ indices_q = indices_k
578
+ elif query_length == 1:
579
+ max_seqlen_in_batch_q = 1
580
+ cu_seqlens_q = torch.arange(
581
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
582
+ ) # There is a memcpy here, that is very bad.
583
+ indices_q = cu_seqlens_q[:-1]
584
+ query_layer = query_layer.squeeze(1)
585
+ else:
586
+ # The -q_len: slice assumes left padding.
587
+ attention_mask = attention_mask[:, -query_length:]
588
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
589
+ query_layer, attention_mask
590
+ )
591
+
592
+ return (
593
+ query_layer,
594
+ key_layer,
595
+ value_layer,
596
+ indices_q,
597
+ (cu_seqlens_q, cu_seqlens_k),
598
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
599
+ )
600
+
601
+
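The FlashAttention path above removes padding before calling the varlen kernel and scatters the result back afterwards (`unpad_input` / `pad_input`). A pure-torch stand-in for that round trip, shown only to illustrate the bookkeeping, not as a replacement for `flash_attn.bert_padding`:

```python
import torch

hidden = torch.randn(2, 3, 4)                        # (batch, seq, dim)
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])          # 0 marks padding
indices = torch.nonzero(mask.flatten()).flatten()    # flat positions of real tokens

packed = hidden.reshape(-1, 4)[indices]              # what the varlen kernel would consume
restored = torch.zeros(6, 4)
restored[indices] = packed                           # scatter back; padded slots stay zero
restored = restored.reshape(2, 3, 4)
```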
602
+ class FLMAudioSdpaAttention(FLMAudioAttention):
603
+ """
604
+ FLM-Audio attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
605
+ `FLMAudioAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
606
+ the SDPA API.
607
+ """
608
+
609
+ # Adapted from FLMAudioAttention.forward
610
+ def forward(
611
+ self,
612
+ hidden_states: torch.Tensor,
613
+ attention_mask: Optional[torch.Tensor] = None,
614
+ position_ids: Optional[torch.LongTensor] = None,
615
+ past_key_value: Optional[Cache] = None,
616
+ output_attentions: bool = False,
617
+ use_cache: bool = False,
618
+ cache_position: Optional[torch.LongTensor] = None,
619
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
620
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
621
+ if output_attentions:
622
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
623
+ logger.warning_once(
624
+ "FLMAudioModel is using FLMAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
625
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
626
+ )
627
+ return super().forward(
628
+ hidden_states=hidden_states,
629
+ attention_mask=attention_mask,
630
+ position_ids=position_ids,
631
+ past_key_value=past_key_value,
632
+ output_attentions=output_attentions,
633
+ use_cache=use_cache,
634
+ cache_position=cache_position,
635
+ )
636
+
637
+ bsz, q_len, _ = hidden_states.size()
638
+
639
+ query_states = self.q_proj(hidden_states)
640
+ key_states = self.k_proj(hidden_states)
641
+ value_states = self.v_proj(hidden_states)
642
+
643
+ query_states = query_states.view(
644
+ bsz, q_len, self.num_heads, self.head_dim
645
+ ).transpose(1, 2)
646
+ key_states = key_states.view(
647
+ bsz, q_len, self.num_key_value_heads, self.head_dim
648
+ ).transpose(1, 2)
649
+ value_states = value_states.view(
650
+ bsz, q_len, self.num_key_value_heads, self.head_dim
651
+ ).transpose(1, 2)
652
+
653
+ cos, sin = position_embeddings
654
+
655
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
656
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
657
+ )
658
+
659
+ if past_key_value is not None:
660
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
661
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
662
+ key_states, value_states = past_key_value.update(
663
+ key_states, value_states, self.layer_idx, cache_kwargs
664
+ )
665
+
666
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
667
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
668
+
669
+ causal_mask = attention_mask
670
+ # if attention_mask is not None and cache_position is not None:
671
+ if attention_mask is not None:
672
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
673
+
674
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
675
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
676
+ if query_states.device.type == "cuda" and causal_mask is not None:
677
+ query_states = query_states.contiguous()
678
+ key_states = key_states.contiguous()
679
+ value_states = value_states.contiguous()
680
+
681
+ attn_output = F.scaled_dot_product_attention(
682
+ query_states,
683
+ key_states,
684
+ value_states,
685
+ attn_mask=causal_mask,
686
+ dropout_p=self.attention_dropout if self.training else 0.0,
687
+ )
688
+
689
+ attn_output = attn_output.transpose(1, 2).contiguous()
690
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
691
+
692
+ attn_output = self.o_proj(attn_output)
693
+
694
+ return attn_output, None, past_key_value
695
+
696
+
697
+ FLMAUDIO_ATTENTION_CLASSES = {
698
+ "eager": FLMAudioAttention,
699
+ "flash_attention_2": FLMAudioFlashAttention2,
700
+ "sdpa": FLMAudioSdpaAttention,
701
+ }
702
+
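+ # Note: `FLMAudioDecoderLayer` selects one of the classes above via `config._attn_implementation`.
+ # A minimal usage sketch (illustrative, not executed here; the checkpoint id mirrors the example
+ # later in this file):
+ #
+ #   model = FLMAudioForCausalLM.from_pretrained(
+ #       "CofeAI/FLM-Audio",
+ #       attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
+ #   )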
703
+
704
+ class FLMAudioDecoderLayer(nn.Module):
705
+ def __init__(self, config: FLMAudioConfig, layer_idx: int):
706
+ super().__init__()
707
+ self.hidden_size = config.hidden_size
708
+ self.self_attn = FLMAUDIO_ATTENTION_CLASSES.get(
709
+ config._attn_implementation, FLMAudioAttention
710
+ )(config=config, layer_idx=layer_idx)
711
+ self.mlp = FLMAudioMLP(config)
712
+ self.input_layernorm = FLMAudioRMSNorm(
713
+ config.hidden_size, eps=config.rms_norm_eps
714
+ )
715
+ self.post_attention_layernorm = FLMAudioRMSNorm(
716
+ config.hidden_size, eps=config.rms_norm_eps
717
+ )
718
+
719
+ def forward(
720
+ self,
721
+ hidden_states: torch.Tensor,
722
+ attention_mask: Optional[torch.Tensor] = None,
723
+ position_ids: Optional[torch.LongTensor] = None,
724
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
725
+ output_attentions: Optional[bool] = False,
726
+ use_cache: Optional[bool] = False,
727
+ cache_position: Optional[torch.LongTensor] = None,
728
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
729
+ **kwargs,
730
+ ) -> Tuple[
731
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
732
+ ]:
733
+ """
734
+ Args:
735
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
736
+ attention_mask (`torch.FloatTensor`, *optional*):
737
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
738
+ query_sequence_length, key_sequence_length)` if default attention is used.
739
+ output_attentions (`bool`, *optional*):
740
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
741
+ returned tensors for more detail.
742
+ use_cache (`bool`, *optional*):
743
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
744
+ (see `past_key_values`).
745
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
746
+ """
747
+ if "padding_mask" in kwargs:
748
+ warnings.warn(
749
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
750
+ )
751
+
752
+ residual = hidden_states
753
+
754
+ hidden_states = self.input_layernorm(hidden_states)
755
+
756
+ # Self Attention
757
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
758
+ hidden_states=hidden_states,
759
+ attention_mask=attention_mask,
760
+ position_ids=position_ids,
761
+ past_key_value=past_key_value,
762
+ output_attentions=output_attentions,
763
+ use_cache=use_cache,
764
+ cache_position=cache_position,
765
+ position_embeddings=position_embeddings,
766
+ **kwargs,
767
+ )
768
+ hidden_states = residual + hidden_states
769
+
770
+ # Fully Connected
771
+ residual = hidden_states
772
+ hidden_states = self.post_attention_layernorm(hidden_states)
773
+ hidden_states = self.mlp(hidden_states)
774
+ hidden_states = residual + hidden_states
775
+
776
+ outputs = (hidden_states,)
777
+
778
+ if output_attentions:
779
+ outputs += (self_attn_weights,)
780
+
781
+ if use_cache:
782
+ outputs += (present_key_value,)
783
+
784
+ return outputs
785
+
786
+
787
+ FLMAUDIO_START_DOCSTRING = r"""
788
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
789
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
790
+ etc.)
791
+
792
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
793
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
794
+ and behavior.
795
+
796
+ Parameters:
797
+ config ([`FLMAudioConfig`]):
798
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
799
+ load the weights associated with the model, only the configuration. Check out the
800
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
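+
+ Example (an illustrative sketch, reusing the checkpoint id from the example later in this file):
+
+ ```python
+ >>> # configuration only -- no weights are loaded
+ >>> config = FLMAudioConfig.from_pretrained("CofeAI/FLM-Audio")
+ >>> # configuration + weights
+ >>> model = FLMAudioForCausalLM.from_pretrained("CofeAI/FLM-Audio")
+ ```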
801
+ """
802
+
803
+
804
+ @add_start_docstrings(
805
+ "The bare FLM-Audio Model outputting raw hidden-states without any specific head on top.",
806
+ FLMAUDIO_START_DOCSTRING,
807
+ )
808
+ class FLMAudioPreTrainedModel(PreTrainedModel):
809
+ config_class = FLMAudioConfig
810
+ base_model_prefix = "model"
811
+ supports_gradient_checkpointing = True
812
+ _no_split_modules = ["FLMAudioDecoderLayer"]
813
+ _skip_keys_device_placement = ["past_key_values"]
814
+ _supports_flash_attn_2 = True
815
+ _supports_sdpa = True
816
+ _supports_cache_class = True
817
+
818
+ def _init_weights(self, module):
819
+ std = self.config.initializer_range
820
+ if isinstance(module, nn.Linear):
821
+ module.weight.data.normal_(mean=0.0, std=std)
822
+ if module.bias is not None:
823
+ module.bias.data.zero_()
824
+ elif isinstance(module, nn.Embedding):
825
+ module.weight.data.normal_(mean=0.0, std=std)
826
+ if module.padding_idx is not None:
827
+ module.weight.data[module.padding_idx].zero_()
828
+
829
+ def _setup_cache(
830
+ self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None
831
+ ):
832
+ if (
833
+ self.config._attn_implementation == "flash_attention_2"
834
+ and cache_cls == StaticCache
835
+ ):
836
+ raise ValueError(
837
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
838
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
839
+ )
840
+
841
+ for layer in self.model.layers:
842
+ device = layer.input_layernorm.weight.device
843
+ if hasattr(self.config, "_pre_quantization_dtype"):
844
+ dtype = self.config._pre_quantization_dtype
845
+ else:
846
+ dtype = layer.self_attn.o_proj.weight.dtype
847
+ layer.self_attn.past_key_value = cache_cls(
848
+ self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
849
+ )
850
+
851
+ def _reset_cache(self):
852
+ for layer in self.model.layers:
853
+ layer.self_attn.past_key_value = None
854
+
855
+
856
+ class MultiModalEmbedding(nn.Module):
857
+ def __init__(self, config):
858
+ super().__init__()
859
+ self.config = config
860
+ self.use_mup = config.use_mup
861
+ self.input_mult = config.input_mult
862
+ self.hidden_size = config.hidden_size
863
+
864
+ self.vocab_size = config.vocab_size
865
+ self.aud_vocab_size = config.aud_vocab_size
866
+
867
+ self.aud_channel = config.aud_channel
868
+
869
+ self.aud_emp_token_id = config.mm_token_info.aud_emp_token_id
870
+
871
+ self.text_embeddings = nn.Embedding(self.vocab_size, self.hidden_size)
872
+
873
+ self.aud_listen_embeddings = nn.ModuleList(
874
+ [
875
+ nn.Embedding(self.aud_vocab_size, self.hidden_size)
876
+ for _ in range(self.aud_channel)
877
+ ]
878
+ )
879
+ self.aud_speak_embeddings = nn.ModuleList(
880
+ [
881
+ nn.Embedding(self.aud_vocab_size, self.hidden_size)
882
+ for _ in range(self.aud_channel)
883
+ ]
884
+ )
885
+
886
+ @staticmethod
887
+ def merge_multichannel_embeddings(
888
+ token_ids, embedding_layer, emp_token_id, embeddings
889
+ ):
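+ # For each audio channel, look up that channel's embedding, zero it wherever the channel
+ # holds the "empty" token (emp_token_id), and accumulate the result into `embeddings`.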
890
+ if token_ids is not None and embedding_layer is not None:
891
+ assert token_ids.shape[2] == len(embedding_layer)
892
+ for c in range(token_ids.shape[2]):
893
+ _emb_state = embedding_layer[c](token_ids[:, :, c])
894
+ _emb_state[token_ids[:, :, c] == emp_token_id] = 0.0
895
+ embeddings += _emb_state
896
+ _emb_state = None
897
+ del _emb_state
898
+ return embeddings
899
+
900
+ def forward(
901
+ self,
902
+ text_ids,
903
+ speak_ids,
904
+ listen_ids,
905
+ ):
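+ # Descriptive note (inferred from the indexing below): `text_ids` indexes the text vocabulary;
+ # `speak_ids` / `listen_ids` carry one audio codebook id per channel in their last dimension
+ # (size `aud_channel`). Audio embeddings are added only at positions whose text token is not
+ # the pad token; with use_mup enabled the summed embedding is scaled by `input_mult`.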
906
+ assert text_ids is not None
907
+ embeddings = self.text_embeddings(text_ids)
908
+ mask = ~(text_ids == self.config.pad_token_id)
909
+
910
+ for aud_chn_idx in range(self.aud_channel):
911
+ aud_speak_embed = self.aud_speak_embeddings[aud_chn_idx](
912
+ speak_ids[..., aud_chn_idx]
913
+ ).squeeze(0)
914
+ aud_listen_embed = self.aud_listen_embeddings[aud_chn_idx](
915
+ listen_ids[..., aud_chn_idx]
916
+ ).squeeze(0)
917
+ embeddings[mask] += aud_speak_embed + aud_listen_embed
918
+
919
+ if self.use_mup:
920
+ embeddings = embeddings * self.input_mult
921
+
922
+ return embeddings
923
+
924
+
925
+ FLMAUDIO_INPUTS_DOCSTRING = r"""
926
+ Args:
927
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
928
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
929
+ it.
930
+
931
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
932
+ [`PreTrainedTokenizer.__call__`] for details.
933
+
934
+ [What are input IDs?](../glossary#input-ids)
935
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
936
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
937
+
938
+ - 1 for tokens that are **not masked**,
939
+ - 0 for tokens that are **masked**.
940
+
941
+ [What are attention masks?](../glossary#attention-mask)
942
+
943
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
944
+ [`PreTrainedTokenizer.__call__`] for details.
945
+
946
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
947
+ `past_key_values`).
948
+
949
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
950
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
951
+ information on the default strategy.
952
+
953
+ - 1 indicates the head is **not masked**,
954
+ - 0 indicates the head is **masked**.
955
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
956
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
957
+ config.n_positions - 1]`.
958
+
959
+ [What are position IDs?](../glossary#position-ids)
960
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
961
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
962
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
963
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
964
+
965
+ Two formats are allowed:
966
+ - a [`~cache_utils.Cache`] instance;
967
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
968
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
969
+ cache format.
970
+
971
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
972
+ legacy cache format will be returned.
973
+
974
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
975
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
976
+ of shape `(batch_size, sequence_length)`.
977
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
978
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
979
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
980
+ model's internal embedding lookup matrix.
981
+ use_cache (`bool`, *optional*):
982
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
983
+ `past_key_values`).
984
+ output_attentions (`bool`, *optional*):
985
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
986
+ tensors for more detail.
987
+ output_hidden_states (`bool`, *optional*):
988
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
989
+ more detail.
990
+ return_dict (`bool`, *optional*):
991
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
992
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
993
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
994
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
995
+ the complete sequence length.
996
+ """
997
+
998
+
999
+ @add_start_docstrings(
1000
+ "The bare FLM-Audio Model outputting raw hidden-states without any specific head on top.",
1001
+ FLMAUDIO_START_DOCSTRING,
1002
+ )
1003
+ class FLMAudioModel(FLMAudioPreTrainedModel):
1004
+ """
1005
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FLMAudioDecoderLayer`]
1006
+
1007
+ Args:
1008
+ config: FLMAudioConfig
1009
+ """
1010
+
1011
+ def __init__(self, config: FLMAudioConfig):
1012
+ super().__init__(config)
1013
+ self.padding_idx = config.pad_token_id
1014
+ self.vocab_size = config.vocab_size
1015
+
1016
+ self.embed_tokens = MultiModalEmbedding(config)
1017
+ self.layers = nn.ModuleList(
1018
+ [
1019
+ FLMAudioDecoderLayer(config, layer_idx)
1020
+ for layer_idx in range(config.num_hidden_layers)
1021
+ ]
1022
+ )
1023
+ self.norm = FLMAudioRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1024
+ self.rotary_emb = FLMAudioRotaryEmbedding(config=config)
1025
+ self.gradient_checkpointing = False
1026
+ self.rope_deltas = None # cache rope_deltas here
1027
+
1028
+ # Initialize weights and apply final processing
1029
+ self.post_init()
1030
+
1031
+ def get_input_embeddings(self) -> MultiModalEmbedding:
1032
+ return self.embed_tokens
1033
+
1034
+ def set_input_embeddings(self, value: MultiModalEmbedding):
1035
+ self.embed_tokens = value
1036
+
1037
+ def get_rope_index(
1038
+ self,
1039
+ input_ids: Optional[torch.LongTensor] = None,
1040
+ second_per_grid_ts: Optional[torch.Tensor] = None,
1041
+ attention_mask: Optional[torch.Tensor] = None,
1042
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
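+ # Descriptive note: returns `position_ids` expanded to shape (3, batch, seq) -- three identical
+ # rows consumed by the multimodal rotary embedding -- plus per-sample `mrope_position_deltas`
+ # used later to offset positions during cached decoding (see the `rope_deltas` handling in `forward`).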
1043
+
1044
+ mrope_position_deltas = []
1045
+
1046
+ if attention_mask is not None:
1047
+ position_ids = attention_mask.long().cumsum(-1) - 1
1048
+ position_ids.masked_fill_(attention_mask == 0, 1)
1049
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
1050
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1051
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1052
+ else:
1053
+ position_ids = (
1054
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1055
+ .view(1, 1, -1)
1056
+ .expand(3, input_ids.shape[0], -1)
1057
+ )
1058
+ mrope_position_deltas = torch.zeros(
1059
+ [input_ids.shape[0], 1],
1060
+ device=input_ids.device,
1061
+ dtype=input_ids.dtype,
1062
+ )
1063
+
1064
+ return position_ids, mrope_position_deltas
1065
+
1066
+
1067
+ @add_start_docstrings_to_model_forward(FLMAUDIO_INPUTS_DOCSTRING)
1068
+ def forward(
1069
+ self,
1070
+ text_ids: torch.LongTensor = None,
1071
+ listen_ids: torch.LongTensor = None,
1072
+ speak_ids: torch.LongTensor = None,
1073
+ attention_mask: Optional[torch.Tensor] = None,
1074
+ position_ids: Optional[torch.LongTensor] = None,
1075
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1076
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1077
+ use_cache: Optional[bool] = None,
1078
+ output_attentions: Optional[bool] = None,
1079
+ output_hidden_states: Optional[bool] = None,
1080
+ return_dict: Optional[bool] = None,
1081
+ rope_deltas: Optional[torch.LongTensor] = None,
1082
+ cache_position: Optional[torch.LongTensor] = None,
1083
+ second_per_grid_ts: Optional[torch.Tensor] = None,
1084
+ **kwargs,
1085
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1086
+ output_attentions = (
1087
+ output_attentions
1088
+ if output_attentions is not None
1089
+ else self.config.output_attentions
1090
+ )
1091
+ output_hidden_states = (
1092
+ output_hidden_states
1093
+ if output_hidden_states is not None
1094
+ else self.config.output_hidden_states
1095
+ )
1096
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1097
+ return_dict = (
1098
+ return_dict if return_dict is not None else self.config.use_return_dict
1099
+ )
1100
+
1101
+ if (text_ids is None) ^ (inputs_embeds is not None):
1102
+ raise ValueError(
1103
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1104
+ )
1105
+
1106
+ if self.gradient_checkpointing and self.training and use_cache:
1107
+ logger.warning_once(
1108
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1109
+ )
1110
+ use_cache = False
1111
+
1112
+ if inputs_embeds is None:
1113
+ inputs_embeds = self.embed_tokens(
1114
+ text_ids,
1115
+ speak_ids,
1116
+ listen_ids,
1117
+ )
1118
+
1119
+ past_seen_tokens = 0
1120
+ if use_cache: # kept for BC (cache positions)
1121
+ if not isinstance(past_key_values, StaticCache):
1122
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1123
+ past_seen_tokens = past_key_values.get_seq_length()
1124
+
1125
+ if cache_position is None:
1126
+ if isinstance(past_key_values, StaticCache):
1127
+ raise ValueError(
1128
+ "cache_position is a required argument when using StaticCache."
1129
+ )
1130
+ cache_position = torch.arange(
1131
+ past_seen_tokens,
1132
+ past_seen_tokens + inputs_embeds.shape[1],
1133
+ device=inputs_embeds.device,
1134
+ )
1135
+
1136
+ # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
1137
+ if position_ids is None:
1138
+ # calculate RoPE index once per generation in the pre-fill stage only
1139
+ if (
1140
+ (cache_position is not None and cache_position[0] == 0)
1141
+ or self.rope_deltas is None
1142
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
1143
+ ):
1144
+ position_ids, rope_deltas = self.get_rope_index(
1145
+ text_ids,
1146
+ second_per_grid_ts,
1147
+ attention_mask,
1148
+ )
1149
+ self.rope_deltas = rope_deltas
1150
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
1151
+ else:
1152
+ batch_size, seq_length, _ = inputs_embeds.shape
1153
+ delta = (
1154
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
1155
+ if cache_position is not None
1156
+ else 0
1157
+ )
1158
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1159
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
1160
+ if cache_position is not None: # otherwise `deltas` is an int `0`
1161
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1162
+ position_ids = position_ids.add(delta)
1163
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
1164
+
1165
+ causal_mask = self._update_causal_mask(
1166
+ attention_mask, inputs_embeds, cache_position
1167
+ )
1168
+
1169
+ # embed positions
1170
+ hidden_states = inputs_embeds
1171
+
1172
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1173
+
1174
+ # decoder layers
1175
+ all_hidden_states = () if output_hidden_states else None
1176
+ all_self_attns = () if output_attentions else None
1177
+ next_decoder_cache = None
1178
+
1179
+ for decoder_layer in self.layers:
1180
+ if output_hidden_states:
1181
+ all_hidden_states += (hidden_states,)
1182
+
1183
+ if self.gradient_checkpointing and self.training:
1184
+ layer_outputs = self._gradient_checkpointing_func(
1185
+ decoder_layer.__call__,
1186
+ hidden_states,
1187
+ causal_mask,
1188
+ position_ids,
1189
+ past_key_values,
1190
+ output_attentions,
1191
+ use_cache,
1192
+ cache_position,
1193
+ position_embeddings,
1194
+ )
1195
+ else:
1196
+ layer_outputs = decoder_layer(
1197
+ hidden_states,
1198
+ attention_mask=causal_mask,
1199
+ position_ids=position_ids,
1200
+ past_key_value=past_key_values,
1201
+ output_attentions=output_attentions,
1202
+ use_cache=use_cache,
1203
+ cache_position=cache_position,
1204
+ position_embeddings=position_embeddings,
1205
+ )
1206
+
1207
+ hidden_states = layer_outputs[0]
1208
+
1209
+ if use_cache:
1210
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1211
+
1212
+ if output_attentions:
1213
+ all_self_attns += (layer_outputs[1],)
1214
+
1215
+ hidden_states = self.norm(hidden_states)
1216
+
1217
+ # add hidden states from the last decoder layer
1218
+ if output_hidden_states:
1219
+ all_hidden_states += (hidden_states,)
1220
+
1221
+ next_cache = None
1222
+ if use_cache:
1223
+ next_cache = (
1224
+ next_decoder_cache.to_legacy_cache()
1225
+ if isinstance(next_decoder_cache, Cache)
1226
+ else next_decoder_cache
1227
+ )
1228
+ if not return_dict:
1229
+ return tuple(
1230
+ v
1231
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1232
+ if v is not None
1233
+ )
1234
+ return BaseModelOutputWithPast(
1235
+ last_hidden_state=hidden_states,
1236
+ past_key_values=next_cache,
1237
+ hidden_states=all_hidden_states,
1238
+ attentions=all_self_attns,
1239
+ )
1240
+
1241
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1242
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
1243
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1244
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1245
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
1246
+ if self.config._attn_implementation == "flash_attention_2":
1247
+ if attention_mask is not None and 0.0 in attention_mask:
1248
+ return attention_mask
1249
+ return None
1250
+
1251
+ dtype, device = input_tensor.dtype, input_tensor.device
1252
+ min_dtype = torch.finfo(dtype).min
1253
+ sequence_length = input_tensor.shape[1]
1254
+ if hasattr(
1255
+ getattr(self.layers[0], "self_attn", {}), "past_key_value"
1256
+ ): # static cache
1257
+ target_length = self.config.max_position_embeddings
1258
+ else: # dynamic cache
1259
+ target_length = (
1260
+ attention_mask.shape[-1]
1261
+ if isinstance(attention_mask, torch.Tensor)
1262
+ else cache_position[-1] + 1
1263
+ )
1264
+
1265
+ causal_mask = torch.full(
1266
+ (sequence_length, target_length),
1267
+ fill_value=min_dtype,
1268
+ dtype=dtype,
1269
+ device=device,
1270
+ )
1271
+ if sequence_length != 1:
1272
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1273
+ causal_mask *= torch.arange(
1274
+ target_length, device=device
1275
+ ) > cache_position.reshape(-1, 1)
1276
+ causal_mask = causal_mask[None, None, :, :].expand(
1277
+ input_tensor.shape[0], 1, -1, -1
1278
+ )
1279
+ if attention_mask is not None:
1280
+ causal_mask = (
1281
+ causal_mask.clone()
1282
+ ) # copy to contiguous memory for in-place edit
1283
+ if attention_mask.dim() == 2:
1284
+ mask_length = attention_mask.shape[-1]
1285
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
1286
+ :, None, None, :
1287
+ ].eq(0.0)
1288
+ causal_mask[..., :mask_length] = causal_mask[
1289
+ ..., :mask_length
1290
+ ].masked_fill(padding_mask, min_dtype)
1291
+ elif attention_mask.dim() == 4:
1292
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1293
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
1294
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1295
+ offset = cache_position[0]
1296
+ else:
1297
+ offset = 0
1298
+ mask_shape = attention_mask.shape
1299
+ mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1300
+ causal_mask[
1301
+ : mask_shape[0],
1302
+ : mask_shape[1],
1303
+ offset : mask_shape[2] + offset,
1304
+ : mask_shape[3],
1305
+ ] = mask_slice
1306
+
1307
+ if (
1308
+ self.config._attn_implementation == "sdpa"
1309
+ and attention_mask is not None
1310
+ and attention_mask.device.type == "cuda"
1311
+ ):
1312
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1313
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1314
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1315
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1316
+ causal_mask, min_dtype
1317
+ )
1318
+
1319
+ return causal_mask
1320
+
1321
+
1322
+ @dataclass
1323
+ class FLMAudioCausalLMOutputWithPast(ModelOutput):
1324
+ loss: Optional[torch.FloatTensor] = None
1325
+ logits: torch.FloatTensor = None
1326
+ audio_logits: torch.FloatTensor = None
1327
+ past_key_values: Optional[List[torch.FloatTensor]] = None
1328
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1329
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1330
+ rope_deltas: Optional[torch.LongTensor] = None
1331
+
1332
+
1333
+ class FLMAudioForCausalLM(FLMAudioPreTrainedModel):
1334
+ _tied_weights_keys = ["lm_head.weight"]
1335
+
1336
+ def __init__(self, config):
1337
+ super().__init__(config)
1338
+ self.model = FLMAudioModel(config)
1339
+ self.vocab_size = config.vocab_size
1340
+ self.output_mult = config.output_mult
1341
+
1342
+ depth_config = DepthGPTConfig(
1343
+ block_size=config.aud_channel,
1344
+ vocab_size=config.aud_vocab_size,
1345
+ n_layer=config.aud_depthgpt.n_layer,
1346
+ n_head=config.aud_depthgpt.n_head,
1347
+ n_embd=config.aud_depthgpt.n_embd,
1348
+ dropout=config.aud_depthgpt.dropout,
1349
+ bias=config.aud_depthgpt.bias,
1350
+ main_hidden_size=config.hidden_size,
1351
+ pad_token_id=config.mm_token_info.aud_emp_token_id,
1352
+ use_cmlp=config.aud_depthgpt.use_cmlp,
1353
+ use_rmsnorm=config.aud_depthgpt.use_rmsnorm,
1354
+ use_swiglu=config.aud_depthgpt.use_swiglu,
1355
+ )
1356
+
1357
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1358
+
1359
+ self.aud_output_layers = DepthGPT(depth_config)
1360
+
1361
+ self.use_mup = config.use_mup
1362
+ if self.use_mup:
1363
+ self.mup_scale_factor = config.mup_scale_factor
1364
+ self.output_mult = config.output_mult / self.mup_scale_factor
1365
+ # Initialize weights and apply final processing
1366
+ self.post_init()
1367
+
1368
+ def get_input_embeddings(self):
1369
+ return self.model.embed_tokens
1370
+
1371
+ def set_input_embeddings(self, value):
1372
+ self.model.embed_tokens = value
1373
+
1374
+ def get_output_embeddings(self):
1375
+ return self.lm_head
1376
+
1377
+ def set_output_embeddings(self, new_embeddings):
1378
+ self.lm_head = new_embeddings
1379
+
1380
+ def set_decoder(self, decoder):
1381
+ self.model = decoder
1382
+
1383
+ def get_decoder(self):
1384
+ return self.model
1385
+
1386
+ def _forward_text(self, outputs, labels, return_dict):
1387
+
1388
+ logits = self.lm_head(outputs[0])
1389
+ logits = logits.float()
1390
+ # Mup
1391
+ if self.use_mup:
1392
+ logits = logits * self.output_mult
1393
+
1394
+ loss = None
1395
+ if labels is not None:
1396
+ raise NotImplementedError
1397
+
1398
+ if not return_dict:
1399
+ output = (logits,) + outputs[1:]
1400
+ return (loss,) + output if loss is not None else output
1401
+
1402
+ return FLMAudioCausalLMOutputWithPast(
1403
+ loss=loss,
1404
+ logits=logits,
1405
+ past_key_values=outputs.past_key_values,
1406
+ hidden_states=outputs.last_hidden_state,
1407
+ attentions=outputs.attentions,
1408
+ )
1409
+
1410
+ def forward_audio(self, transformer_output_states, audio_input_ids):
1411
+ return self.aud_output_layers(transformer_output_states, audio_input_ids)
1412
+
1413
+ @add_start_docstrings_to_model_forward(FLMAUDIO_INPUTS_DOCSTRING)
1414
+ @replace_return_docstrings(
1415
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1416
+ )
1417
+ def forward(
1418
+ self,
1419
+ input_ids: torch.LongTensor = None,
1420
+ listen_ids: torch.LongTensor = None,
1421
+ speak_ids: torch.LongTensor = None,
1422
+ attention_mask: Optional[torch.Tensor] = None,
1423
+ position_ids: Optional[torch.LongTensor] = None,
1424
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1425
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1426
+ labels: Optional[torch.LongTensor] = None,
1427
+ use_cache: Optional[bool] = None,
1428
+ output_attentions: Optional[bool] = None,
1429
+ output_hidden_states: Optional[bool] = None,
1430
+ return_dict: Optional[bool] = None,
1431
+ rope_deltas: Optional[torch.LongTensor] = None,
1432
+ cache_position: Optional[torch.LongTensor] = None,
1433
+ second_per_grid_ts: Optional[torch.Tensor] = None,
1434
+ **kwargs,
1435
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1436
+ r"""
1437
+ Args:
1438
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1439
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1440
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1441
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1442
+
1443
+ Returns:
1444
+
1445
+ Example:
1446
+
1447
+ ```python
1448
+ >>> from transformers import AutoTokenizer, FLMAudioForCausalLM
1449
+
1450
+ >>> model = FLMAudioForCausalLM.from_pretrained("CofeAI/FLM-Audio")
1451
+ >>> tokenizer = AutoTokenizer.from_pretrained("CofeAI/FLM-Audio")
1452
+
1453
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1454
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1455
+
1456
+ >>> # Generate
1457
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1458
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1459
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1460
+ ```"""
1461
+ output_attentions = (
1462
+ output_attentions
1463
+ if output_attentions is not None
1464
+ else self.config.output_attentions
1465
+ )
1466
+ output_hidden_states = (
1467
+ output_hidden_states
1468
+ if output_hidden_states is not None
1469
+ else self.config.output_hidden_states
1470
+ )
1471
+ return_dict = (
1472
+ return_dict if return_dict is not None else self.config.use_return_dict
1473
+ )
1474
+
1475
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1476
+
1477
+ if listen_ids is None and speak_ids is None:
1478
+ batch_size, seq_len = input_ids.shape[:2]
1479
+ listen_ids = torch.full((seq_len*batch_size, 8), self.model.config.mm_token_info.aud_pad_token_id, device=input_ids.device, dtype=input_ids.dtype)
1480
+ speak_ids = torch.full((seq_len*batch_size, 8), self.model.config.mm_token_info.aud_pad_token_id, device=input_ids.device, dtype=input_ids.dtype)
1481
+ outputs = self.model(
1482
+ text_ids=input_ids,
1483
+ listen_ids=listen_ids,
1484
+ speak_ids=speak_ids,
1485
+ attention_mask=attention_mask,
1486
+ position_ids=position_ids,
1487
+ past_key_values=past_key_values,
1488
+ inputs_embeds=inputs_embeds,
1489
+ use_cache=use_cache,
1490
+ output_attentions=output_attentions,
1491
+ output_hidden_states=output_hidden_states,
1492
+ return_dict=return_dict,
1493
+ cache_position=cache_position,
1494
+ second_per_grid_ts=second_per_grid_ts
1495
+ )
1496
+ return self._forward_text(outputs, labels, return_dict)
1497
+
1498
+ @staticmethod
1499
+ def _reorder_cache(past_key_values, beam_idx):
1500
+ reordered_past = ()
1501
+ for layer_past in past_key_values:
1502
+ reordered_past += (
1503
+ tuple(
1504
+ past_state.index_select(0, beam_idx.to(past_state.device))
1505
+ for past_state in layer_past
1506
+ ),
1507
+ )
1508
+ return reordered_past
1509
+
1510
+ def _get_initial_token(self) -> torch.Tensor:
1511
+ # Returns the initial token that will be fed to the model to predict the very first timestep.
1512
+ # The output shape will be [B, K, 1].
1513
+ device = next(iter(self.parameters())).device
1514
+ zero = torch.full([1, 1, 1], 0, device=device, dtype=torch.long)
1515
+ special = torch.full_like(zero, self.config.mm_token_info.aud_pad_token_id)
1516
+
1517
+ text_special = torch.full_like(
1518
+ zero, self.config.mm_token_info.text_wait_token_id
1519
+ )
1520
+ audio_token = special
1521
+ text_token = text_special
1522
+ audio_token = audio_token.expand(-1, 2 * self.config.aud_channel, -1).clone()
1523
+ token = torch.cat([text_token, audio_token], dim=1)
1524
+ return token
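+
+
+ # A minimal single-step inference sketch (illustrative only; reusing `speak_ids` as the DepthGPT
+ # input below is an assumption, not the served duplex pipeline):
+ #
+ #   out = model(
+ #       input_ids=text_ids,      # [batch, seq] text / monologue tokens
+ #       listen_ids=listen_ids,   # incoming audio codes, one id per audio channel in the last dim
+ #       speak_ids=speak_ids,     # the model's own speech codes, same layout
+ #       use_cache=True,
+ #   )
+ #   text_logits = out.logits                                           # next-text-token scores
+ #   audio_logits = model.forward_audio(out.hidden_states, speak_ids)   # audio head (DepthGPT)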
special_tokens_map.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>",
16
+ "<|text_wait|>",
17
+ "<|asr|>",
18
+ "<|answer|>"
19
+ ],
20
+ "eos_token": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ },
27
+ "pad_token": {
28
+ "content": "<|endoftext|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
tokenizer-e351c8d8-checkpoint125.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09b782f0629851a271227fb9d36db65c041790365f11bbe5d3d59369cf863f50
3
+ size 384644900
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:805c30668219574cf25bb2e2d361be36f68910bf290bd574ba9bc4e73150169a
3
+ size 11422457
tokenizer_config.json ADDED
@@ -0,0 +1,234 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<|text_wait|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "151666": {
190
+ "content": "<|asr|>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "151667": {
198
+ "content": "<|answer|>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": true
204
+ }
205
+ },
206
+ "additional_special_tokens": [
207
+ "<|im_start|>",
208
+ "<|im_end|>",
209
+ "<|object_ref_start|>",
210
+ "<|object_ref_end|>",
211
+ "<|box_start|>",
212
+ "<|box_end|>",
213
+ "<|quad_start|>",
214
+ "<|quad_end|>",
215
+ "<|vision_start|>",
216
+ "<|vision_end|>",
217
+ "<|vision_pad|>",
218
+ "<|image_pad|>",
219
+ "<|video_pad|>",
220
+ "<|text_wait|>",
221
+ "<|asr|>",
222
+ "<|answer|>"
223
+ ],
224
+ "bos_token": null,
225
+ "clean_up_tokenization_spaces": false,
226
+ "eos_token": "<|im_end|>",
227
+ "errors": "replace",
228
+ "extra_special_tokens": {},
229
+ "model_max_length": 131072,
230
+ "pad_token": "<|endoftext|>",
231
+ "split_special_tokens": false,
232
+ "tokenizer_class": "Qwen2Tokenizer",
233
+ "unk_token": null
234
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff