ThomasTheMaker committed on
Commit 9f8789c · verified · 1 Parent(s): 80a9a5e

Delete pico-decoder-tiny-dolma29k

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50):
  1. pico-decoder-tiny-dolma29k/checkpoints/step_0/config.json +0 -22
  2. pico-decoder-tiny-dolma29k/checkpoints/step_0/fabric_state/checkpoint.pt +0 -3
  3. pico-decoder-tiny-dolma29k/checkpoints/step_0/generation_config.json +0 -4
  4. pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_activations.pt +0 -3
  5. pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow +0 -3
  6. pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_data/dataset_info.json +0 -19
  7. pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_data/state.json +0 -13
  8. pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_gradients.pt +0 -3
  9. pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_weights.pt +0 -3
  10. pico-decoder-tiny-dolma29k/checkpoints/step_0/model.safetensors +0 -3
  11. pico-decoder-tiny-dolma29k/checkpoints/step_0/pico_decoder.py +0 -871
  12. pico-decoder-tiny-dolma29k/checkpoints/step_0/special_tokens_map.json +0 -16
  13. pico-decoder-tiny-dolma29k/checkpoints/step_0/tokenizer.json +0 -0
  14. pico-decoder-tiny-dolma29k/checkpoints/step_0/tokenizer_config.json +0 -239
  15. pico-decoder-tiny-dolma29k/checkpoints/step_1000/config.json +0 -22
  16. pico-decoder-tiny-dolma29k/checkpoints/step_1000/fabric_state/checkpoint.pt +0 -3
  17. pico-decoder-tiny-dolma29k/checkpoints/step_1000/generation_config.json +0 -4
  18. pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_activations.pt +0 -3
  19. pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow +0 -3
  20. pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_data/dataset_info.json +0 -19
  21. pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_data/state.json +0 -13
  22. pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_gradients.pt +0 -3
  23. pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_weights.pt +0 -3
  24. pico-decoder-tiny-dolma29k/checkpoints/step_1000/model.safetensors +0 -3
  25. pico-decoder-tiny-dolma29k/checkpoints/step_1000/pico_decoder.py +0 -871
  26. pico-decoder-tiny-dolma29k/checkpoints/step_1000/special_tokens_map.json +0 -16
  27. pico-decoder-tiny-dolma29k/checkpoints/step_1000/tokenizer.json +0 -0
  28. pico-decoder-tiny-dolma29k/checkpoints/step_1000/tokenizer_config.json +0 -239
  29. pico-decoder-tiny-dolma29k/checkpoints/step_2000/config.json +0 -22
  30. pico-decoder-tiny-dolma29k/checkpoints/step_2000/fabric_state/checkpoint.pt +0 -3
  31. pico-decoder-tiny-dolma29k/checkpoints/step_2000/generation_config.json +0 -4
  32. pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_activations.pt +0 -3
  33. pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_data/data-00000-of-00001.arrow +0 -3
  34. pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_data/dataset_info.json +0 -19
  35. pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_data/state.json +0 -13
  36. pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_gradients.pt +0 -3
  37. pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_weights.pt +0 -3
  38. pico-decoder-tiny-dolma29k/checkpoints/step_2000/model.safetensors +0 -3
  39. pico-decoder-tiny-dolma29k/checkpoints/step_2000/pico_decoder.py +0 -871
  40. pico-decoder-tiny-dolma29k/checkpoints/step_2000/special_tokens_map.json +0 -16
  41. pico-decoder-tiny-dolma29k/checkpoints/step_2000/tokenizer.json +0 -0
  42. pico-decoder-tiny-dolma29k/checkpoints/step_2000/tokenizer_config.json +0 -239
  43. pico-decoder-tiny-dolma29k/checkpoints/step_3000/config.json +0 -22
  44. pico-decoder-tiny-dolma29k/checkpoints/step_3000/fabric_state/checkpoint.pt +0 -3
  45. pico-decoder-tiny-dolma29k/checkpoints/step_3000/generation_config.json +0 -4
  46. pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_activations.pt +0 -3
  47. pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_data/data-00000-of-00001.arrow +0 -3
  48. pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_data/dataset_info.json +0 -19
  49. pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_data/state.json +0 -13
  50. pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_gradients.pt +0 -3
pico-decoder-tiny-dolma29k/checkpoints/step_0/config.json DELETED
@@ -1,22 +0,0 @@
- {
-   "activation_hidden_dim": 384,
-   "architectures": [
-     "PicoDecoderHF"
-   ],
-   "attention_n_heads": 12,
-   "attention_n_kv_heads": 4,
-   "auto_map": {
-     "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
-     "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
-   },
-   "batch_size": 1024,
-   "d_model": 96,
-   "max_seq_len": 2048,
-   "model_type": "pico_decoder",
-   "n_layers": 12,
-   "norm_eps": 1e-06,
-   "position_emb_theta": 10000.0,
-   "torch_dtype": "float32",
-   "transformers_version": "4.48.3",
-   "vocab_size": 50304
- }
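The auto_map above registers the bundled pico_decoder.py classes with the transformers auto classes, so a local copy of such a checkpoint directory could, in principle, be loaded through the standard transformers API. A minimal sketch, assuming a hypothetical local path (illustrative only, not part of this repository):

# Minimal loading sketch: the auto_map routes AutoConfig / AutoModelForCausalLM
# to the bundled pico_decoder.py, so trust_remote_code is required.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint_dir = "pico-decoder-tiny-dolma29k/checkpoints/step_0"  # hypothetical local copy
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))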
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/fabric_state/checkpoint.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:e128c1959fae7aaab5deafca1e0ec75c66cc544622e116b0ed209a8709b03ae1
- size 45187997
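The binary files removed by this commit are stored as Git LFS pointer stubs like the one above: a spec version, a sha256 object id, and the payload size in bytes (here roughly 45 MB of Fabric training state). A minimal sketch of reading such a pointer, assuming a hypothetical local path:

# Sketch: parse a Git LFS pointer file of the form shown in this diff
# (version / oid / size lines). Path is hypothetical.
def read_lfs_pointer(path):
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return {
        "version": fields.get("version"),
        "oid": fields.get("oid", "").removeprefix("sha256:"),
        "size_bytes": int(fields.get("size", "0")),
    }

# e.g. read_lfs_pointer("checkpoints/step_0/fabric_state/checkpoint.pt")
# -> {'version': 'https://git-lfs.github.com/spec/v1', 'oid': 'e128c195...', 'size_bytes': 45187997}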
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/generation_config.json DELETED
@@ -1,4 +0,0 @@
- {
-   "transformers_version": "4.48.3",
-   "vocab_size": 50304
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_activations.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9f02d57ee3f9eda5191db5435eb8de6a9464ba63205377de24ca75268969c58c
- size 33819
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:254891345b1a9d809e9c5c0a1532693b94d769025a317ad82bc418dfa3f7b40b
- size 71640
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_data/dataset_info.json DELETED
@@ -1,19 +0,0 @@
- {
-   "citation": "",
-   "description": "",
-   "features": {
-     "input_ids": {
-       "feature": {
-         "dtype": "int32",
-         "_type": "Value"
-       },
-       "_type": "Sequence"
-     },
-     "text": {
-       "dtype": "string",
-       "_type": "Value"
-     }
-   },
-   "homepage": "",
-   "license": ""
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_data/state.json DELETED
@@ -1,13 +0,0 @@
- {
-   "_data_files": [
-     {
-       "filename": "data-00000-of-00001.arrow"
-     }
-   ],
-   "_fingerprint": "3da9a89786e6494d",
-   "_format_columns": null,
-   "_format_kwargs": {},
-   "_format_type": null,
-   "_output_all_columns": false,
-   "_split": null
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_gradients.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:37b86ac40a6afa81d719cba3f4b98a0cd62a5bb0276e410fad79405dc7c3603b
- size 2371527
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/learning_dynamics/train_weights.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c029ef92a6494ae121c847e432e52e6a8ff3bf7d9fef3e61bef871c1e9a9aa02
- size 2371443
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:1852515eb5c8556533445f22edf523884b9f8cc44812379a6a951668a4ffa3a3
- size 45143592
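For reference, the reported model.safetensors size is consistent with the configuration above: with untied embeddings, d_model = 96, n_layers = 12, activation_hidden_dim = 384, 12 query heads with 4 KV heads, and a 50,304-token vocabulary, the architecture defined in the bundled pico_decoder.py comes to roughly 11.28 M parameters, or about 45.1 MB in float32, with the small remainder attributable to the safetensors header. A quick sketch of that count (an estimate, not taken from the repository):

# Back-of-the-envelope parameter count for the config above, assuming the
# layer shapes defined in the bundled pico_decoder.py (untied embeddings).
d_model, n_layers, vocab = 96, 12, 50304
n_heads, n_kv_heads, act_hidden = 12, 4, 384
head_dim = d_model // n_heads                                                    # 8
attn = d_model * (n_heads * head_dim) * 2 + d_model * (n_kv_heads * head_dim) * 2  # q,o + k,v projections
swiglu = 2 * d_model * act_hidden + act_hidden * d_model                            # w_0, w_1, w_2
norms = 2 * d_model                                                                 # two RMSNorms per block
params = vocab * d_model * 2 + d_model + n_layers * (attn + swiglu + norms)
print(params, params * 4)  # ~11.28M parameters, ~45.1 MB in float32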
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/pico_decoder.py DELETED
@@ -1,871 +0,0 @@
1
- """
2
- Pico Decoder: A Lightweight Causal Transformer Language Model
3
-
4
- Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
-
6
- Everything is written with a modular design for easy modification and experimentation.
7
-
8
- Key features:
9
- - RMSNorm for layer normalization
10
- - Rotary Positional Embeddings (RoPE)
11
- - Multi-head attention with KV-cache support
12
- - SwiGLU activation function
13
- - Residual connections throughout
14
-
15
- - KV-cache for faster autoregressive generation
16
-
17
- References:
18
- - RoPE: https://arxiv.org/abs/2104.09864
19
- - SwiGLU: https://arxiv.org/abs/2002.05202
20
- - LLAMA: https://arxiv.org/abs/2302.13971
21
-
22
- Adapted from:
23
- - OLMO: https://github.com/allenai/OLMo
24
- - LLAMA: https://github.com/meta/llama
25
- """
26
-
27
- from dataclasses import asdict
28
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
-
30
- import torch
31
- import torch.nn as nn
32
- import torch.nn.functional as F
33
- from torch.nn.attention import SDPBackend, sdpa_kernel
34
- from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
35
- from transformers.generation import GenerationConfig
36
- from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
37
-
38
- try:
39
- if TYPE_CHECKING:
40
- # We need to do this to avoid importing these when creating the HF-compatible models
41
- from src.config import ModelConfig
42
- except ImportError:
43
- pass
44
-
45
- ########################################################
46
- #
47
- # Layer Normalization
48
- #
49
- ########################################################
50
-
51
-
52
- class RMSNorm(torch.nn.Module):
53
- """Root Mean Square Layer Normalization.
54
-
55
- A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
56
- resulting in improved stability and performance.
57
-
58
- Args:
59
- config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
60
- - config.norm_eps: Small constant for numerical stability
61
- - config.d_model: Model dimension for the weight parameter
62
-
63
- References:
64
- https://arxiv.org/abs/1910.07467
65
- """
66
-
67
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
68
- super().__init__()
69
- self.eps = config.norm_eps
70
- self.weight = nn.Parameter(torch.ones(config.d_model))
71
-
72
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
73
- """
74
- Normalizes the input tensor by its RMS value.
75
- """
76
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
77
-
78
- def forward(self, x: torch.Tensor) -> torch.Tensor:
79
- """
80
- Applies RMS normalization to the input tensor and scales it by the weight parameter.
81
- """
82
- output = self._norm(x.float()).type_as(x)
83
- return output * self.weight
84
-
85
-
86
- ########################################################
87
- #
88
- # Positional Embedding
89
- #
90
- ########################################################
91
-
92
-
93
- class RoPE(nn.Module):
94
- """Rotary Positional Embeddings (RoPE).
95
-
96
- Implements position-dependent rotation of keys and queries in attention mechanism,
97
- allowing better modeling of relative positions in sequences. Uses complex number
98
- operations for efficient rotation.
99
-
100
- Args:
101
- config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
102
- - config.position_emb_theta: Base for frequency computation
103
- - config.d_model: Model dimension
104
- - config.attention_n_heads: Number of attention heads
105
- - config.max_seq_len: Maximum sequence length
106
-
107
- References:
108
- https://arxiv.org/abs/2104.09864
109
- """
110
-
111
- _freqs_cis_tensor: torch.Tensor | None = None
112
-
113
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
114
- super().__init__()
115
-
116
- self.theta = config.position_emb_theta
117
- self.dim = config.d_model // config.attention_n_heads
118
-
119
- max_seq_len = config.max_seq_len
120
-
121
- # only gets set once, and then reused for all RoPE instances
122
- if RoPE._freqs_cis_tensor is None:
123
- RoPE._freqs_cis_tensor = self._setup_freqs_cis(
124
- max_seq_len, self.theta, self.dim
125
- )
126
-
127
- # register _freqs_cis buffer
128
- # can be easily recomputed so persistent=False
129
- self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
130
-
131
- @classmethod
132
- def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
133
- """Setup Frequency Tensor for RoPE Embeddings
134
-
135
- Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
136
-
137
- Note other implementations will use cos and sin directly, but using the complex
138
- number representation is (probably) more efficient:
139
-
140
- e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
141
- """
142
- _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
143
- positions = torch.arange(seq_len)
144
- freqs = torch.outer(positions, _freqs)
145
- return torch.polar(torch.ones_like(freqs), freqs) # complex64
146
-
147
- def get_freqs_cis(
148
- self, input_shape: torch.Size, start_pos: int, end_pos: int
149
- ) -> torch.Tensor:
150
- """Reshape Frequency Tensor for RoPE Embeddings
151
-
152
- Makes the frequency tensor broadcastable with the input tensor.
153
- """
154
- _freqs_cis = self._freqs_cis[start_pos:end_pos]
155
- ndim = len(input_shape)
156
- assert 0 <= 1 < ndim
157
- assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
158
-
159
- # TODO: Check whether this is correct (might be able to remove this)
160
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
161
- return _freqs_cis.view(*shape)
162
-
163
- def forward(
164
- self,
165
- queries: torch.Tensor,
166
- keys: torch.Tensor,
167
- start_pos: int = 0,
168
- ) -> Tuple[torch.Tensor, torch.Tensor]:
169
- """Apply RoPE Embeddings to Queries and Keys
170
-
171
- Applies the rotary positional embeddings to the input tensors via complex num multiplication
172
-
173
- NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
174
- """
175
- queries_ = torch.view_as_complex(
176
- queries.float().reshape(*queries.shape[:-1], -1, 2)
177
- )
178
- keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
179
-
180
- input_shape = (
181
- queries_.shape
182
- ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
183
- freqs_start_pos = start_pos
184
- freqs_end_pos = freqs_start_pos + queries_.shape[1]
185
-
186
- freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
187
-
188
- queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
189
- keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
190
- return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
191
-
192
-
193
- ########################################################
194
- #
195
- # Attention
196
- #
197
- ########################################################
198
-
199
-
200
- class Attention(nn.Module):
201
- """Multi-head Attention with Group Query Attention support.
202
-
203
- Implements scaled dot-product attention and supports:
204
- - Grouped Query Attention (GQA)
205
- - Key-Value caching for efficient inference
206
- - RoPE integration
207
-
208
- Args:
209
- config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
210
- - config.attention_n_heads: Number of attention heads
211
- - config.attention_n_kv_heads: Number of key/value heads
212
- - config.d_model: Model dimension
213
- - config.batch_size: Maximum batch size
214
- - config.max_seq_len: Maximum sequence length
215
-
216
- Shape:
217
- - Input: (batch_size, seq_len, d_model)
218
- - Output: (batch_size, seq_len, d_model)
219
- """
220
-
221
- def __init__(
222
- self,
223
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
224
- ):
225
- super().__init__()
226
-
227
- self.n_heads = config.attention_n_heads
228
- self.n_kv_heads = config.attention_n_kv_heads
229
-
230
- self.batch_size = config.batch_size
231
- self.max_seq_len = config.max_seq_len
232
-
233
- d_model = config.d_model
234
- self.head_dim = d_model // self.n_heads
235
-
236
- self.n_rep = self.n_heads // self.n_kv_heads
237
-
238
- self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
239
- self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
240
- self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
241
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
242
-
243
- self.rope = RoPE(config)
244
-
245
- def forward(
246
- self,
247
- input: torch.Tensor,
248
- mask: Optional[torch.Tensor] = None,
249
- past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
250
- use_cache: bool = False,
251
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
252
- """Forward pass for the attention mechanism.
253
-
254
- Computes queries, keys, and values for the attention mechanism. Applies rotary positional
255
- embeddings to the queries and keys, and then computes attention scores and outputs.
256
-
257
- For an introduction to the attention mechanism, see:
258
- https://arxiv.org/abs/1706.03762
259
-
260
- A few things to note:
261
- - The past_key_values is used to implement the KV cache, which is used to speed up
262
- generation by caching the KV pairs from previous forward passes. This is useful when doing
263
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
264
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
265
- its own KV cache - this KV cache is implemented as a tuple.
266
- """
267
- bsz, seq_len, _ = input.shape
268
- _queries, _keys, _values = (
269
- self.q_proj(input),
270
- self.k_proj(input),
271
- self.v_proj(input),
272
- )
273
-
274
- # Reshaping for multi-head attention
275
- queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
276
- keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
277
- values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
278
-
279
- # The start position is used to apply the RoPE embeddings to only the new tokens
280
- # when using the kv_cache in the attention mechanism.
281
- # We want to start from the last position in the cache.
282
- start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
283
-
284
- # apply rotary positional embeddings
285
- queries, keys = self.rope(queries, keys, start_pos)
286
-
287
- if past_key_values is not None:
288
- keys = torch.cat([past_key_values[0], keys], dim=1)
289
- values = torch.cat([past_key_values[1], values], dim=1)
290
-
291
- if use_cache:
292
- cached_keys = keys
293
- cached_values = values
294
- else:
295
- cached_keys = None
296
- cached_values = None
297
-
298
- queries = queries.transpose(1, 2)
299
- keys = keys.transpose(1, 2)
300
- values = values.transpose(1, 2)
301
-
302
- apply_gqa = self.n_rep > 1
303
- if apply_gqa and queries.device.type == "mps":
304
- # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
305
- # outside of the kernel to get the same effect.
306
- # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
307
- keys = keys.repeat_interleave(self.n_rep, dim=-3)
308
- values = values.repeat_interleave(self.n_rep, dim=-3)
309
- apply_gqa = False
310
-
311
- backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
312
-
313
- with sdpa_kernel(backends=backends):
314
- attn_output = F.scaled_dot_product_attention(
315
- queries.contiguous(),
316
- keys.contiguous(),
317
- values.contiguous(),
318
- attn_mask=mask.to(queries.dtype) if mask is not None else None,
319
- enable_gqa=apply_gqa,
320
- )
321
-
322
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
323
- output = self.o_proj(attn_output)
324
-
325
- return output, (cached_keys, cached_values)
326
-
327
-
328
- ########################################################
329
- #
330
- # SwiGLU (Combines MLP and Activation)
331
- #
332
- ########################################################
333
-
334
-
335
- class SwiGLU(nn.Module):
336
- """SwiGLU Activation Function with Linear Projections.
337
-
338
- Implements the SwiGLU activation function combined with linear transformations,
339
- serving as the feed-forward network in transformer blocks.
340
-
341
- Args:
342
- config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
343
- - config.d_model: Model dimension
344
- - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
345
-
346
- References:
347
- https://arxiv.org/abs/2002.05202
348
- """
349
-
350
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
351
- super().__init__()
352
-
353
- model_dim = config.d_model
354
- act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
355
-
356
- self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
357
- self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
358
- self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
359
-
360
- def forward(self, x: torch.Tensor) -> torch.Tensor:
361
- return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
362
-
363
-
364
- ########################################################
365
- #
366
- # PicoDecoderBlock
367
- #
368
- ########################################################
369
-
370
-
371
- class PicoDecoderBlock(nn.Module):
372
- """Single Transformer Block with Attention and Feed-forward layers.
373
-
374
- Implements a standard transformer block with:
375
- - Multi-head attention with normalization and residual connection
376
- - SwiGLU feed-forward network with normalization and residual connection
377
-
378
- Args:
379
- config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
380
- a HuggingFace PicoDecoderHFConfig
381
- """
382
-
383
- def __init__(
384
- self,
385
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
386
- ):
387
- super().__init__()
388
-
389
- self.attention = Attention(config)
390
- self.swiglu = SwiGLU(config)
391
- self.attention_norm = RMSNorm(config)
392
- self.swiglu_norm = RMSNorm(config)
393
-
394
- def forward(
395
- self,
396
- input: torch.Tensor,
397
- mask: Optional[torch.Tensor] = None,
398
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
399
- use_cache: bool = False,
400
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
401
- attention_output, cached_key_values = self.attention(
402
- self.attention_norm(input),
403
- mask=mask,
404
- past_key_values=past_key_values,
405
- use_cache=use_cache,
406
- )
407
- # NOTE: cached_key_values is None if use_cache is False
408
-
409
- h = input + attention_output
410
- out = h + self.swiglu(self.swiglu_norm(h))
411
- return out, cached_key_values
412
-
413
-
414
- ########################################################
415
- #
416
- # Pico Decoder (Causal Transformer Model)
417
- #
418
- ########################################################
419
-
420
-
421
- class PicoDecoder(nn.Module):
422
- """
423
- Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
424
- single autoregressive model.
425
-
426
- For more information on the model, see the classes for the modules that make up the model.
427
- """
428
-
429
- def __init__(
430
- self,
431
- model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
432
- ):
433
- super().__init__()
434
- self.config = model_config
435
-
436
- self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
437
- self.layers = nn.ModuleList(
438
- [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
439
- )
440
- self.output_norm = RMSNorm(self.config)
441
- self.de_embedding_proj = nn.Linear(
442
- self.config.d_model, self.config.vocab_size, bias=False
443
- )
444
-
445
- def convert_to_hf_model(self) -> "PicoDecoderHF":
446
- """Convert the Lightning model to a HuggingFace model."""
447
- # Create HF config without fabric-specific settings
448
- hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
449
-
450
- # Create new HF model
451
- hf_model = PicoDecoderHF(hf_config)
452
-
453
- # Copy state dict, excluding fabric-specific keys
454
- hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
455
-
456
- return hf_model
457
-
458
- def forward(
459
- self,
460
- input_ids: torch.Tensor,
461
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
462
- use_cache: bool = False,
463
- ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
464
- """
465
- This is the forward pass for the entire Pico model. It boils down to:
466
- - Embedding the input ids
467
- - Creating a causal mask
468
- - Processing through the pico layers
469
- - Projecting the output to logits
470
-
471
- NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
472
- generation by caching the KV pairs from previous forward passes. This is useful when doing
473
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
474
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
475
- its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
476
- KV caches (so a tuple of tuples).
477
- """
478
-
479
- seq_len = input_ids.shape[-1]
480
- h = self.embedding_proj(input_ids)
481
-
482
- # Calculate start position from past cached KV pairs. Remember that each layer has its
483
- # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
484
- # correct layer and then for either the keys or values.
485
- start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
486
-
487
- # Create causal mask for current sequence
488
- mask = None
489
- if seq_len > 1:
490
- mask = torch.full((seq_len, seq_len), float("-inf"))
491
- mask = torch.triu(mask, diagonal=1)
492
-
493
- # If using KV cache, extend mask to cover cached sequence length
494
- if past_key_values is not None:
495
- # Add zeros for cached tokens (we can attend to all of them)
496
- mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
497
-
498
- mask = mask.to(h.device)
499
-
500
- # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
501
- # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
502
- cached_key_values = () if use_cache else None
503
-
504
- # Process through transformer blocks
505
- for idx, layer in enumerate(self.layers):
506
- layer_past_key_values = (
507
- past_key_values[idx] if past_key_values is not None else None
508
- )
509
-
510
- h, layer_cached_key_values = layer(
511
- h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
512
- )
513
-
514
- if use_cache:
515
- cached_key_values += (layer_cached_key_values,)
516
-
517
- # Final norm and projection
518
- h = self.output_norm(h)
519
- logits = self.de_embedding_proj(h).float()
520
-
521
- return logits, cached_key_values
522
-
523
-
524
- ########################################################
525
- #
526
- # HuggingFace Wrapper for the Pico Decoder model.
527
- #
528
- ########################################################
529
-
530
-
531
- class PicoDecoderHFConfig(PretrainedConfig):
532
- """Config class for the Pico Decoder HuggingFace wrapper."""
533
-
534
- model_type = "pico_decoder"
535
-
536
- @classmethod
537
- def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
538
- """
539
- Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
540
- this is because with some kwargs special handling is required and can make this class
541
- brittle.
542
- """
543
- pico_config = cls(**config_dict)
544
-
545
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
546
- unused_kwargs = {
547
- key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
548
- }
549
-
550
- if return_unused_kwargs:
551
- return pico_config, unused_kwargs
552
- return pico_config
553
-
554
- @classmethod
555
- def from_dataclass(cls, model_config: "ModelConfig"):
556
- """Initialise from our custom config dataclass."""
557
- return cls.from_dict(asdict(model_config))
558
-
559
-
560
- class PicoDecoderHF(PreTrainedModel, GenerationMixin):
561
- """
562
- HuggingFace wrapper for the Pico model with generation support.
563
-
564
- Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
565
- wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
566
- Pico model as well as the model wrapped in this HuggingFace class.
567
-
568
- This also lets you do cool things like:
569
-
570
- `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
571
- """
572
-
573
- config_class = PicoDecoderHFConfig
574
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
575
- main_input_name = "input_ids"
576
-
577
- def __init__(self, config: PicoDecoderHFConfig):
578
- super().__init__(config)
579
- self.pico_decoder = PicoDecoder(config)
580
- # Initialize generation config with defaults
581
- self.generation_config = GenerationConfig()
582
- # Set some reasonable defaults for the model
583
- if hasattr(config, "max_position_embeddings"):
584
- self.generation_config.max_length = config.max_position_embeddings
585
- if hasattr(config, "vocab_size"):
586
- self.generation_config.vocab_size = config.vocab_size
587
-
588
- def forward(
589
- self,
590
- input_ids: torch.Tensor,
591
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
592
- use_cache: bool = False,
593
- **kwargs,
594
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
595
- """HuggingFace forward pass wrapper.
596
-
597
- Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
598
- Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
599
- """
600
- logits, past_key_values = self.pico_decoder(
601
- input_ids, past_key_values, use_cache
602
- )
603
- if use_cache:
604
- return CausalLMOutputWithPast(
605
- logits=logits,
606
- past_key_values=past_key_values,
607
- )
608
- else:
609
- return CausalLMOutput(
610
- logits=logits,
611
- )
612
-
613
- def prepare_inputs_for_generation(
614
- self,
615
- input_ids: torch.LongTensor,
616
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
617
- attention_mask: Optional[torch.LongTensor] = None,
618
- **kwargs,
619
- ) -> Dict[str, Any]:
620
- """
621
- Prepare inputs for generation.
622
-
623
- Args:
624
- input_ids: Input token IDs
625
- past_key_values: Cached key-value pairs from previous forward passes
626
- attention_mask: Attention mask for the input
627
- **kwargs: Additional arguments
628
-
629
- Returns:
630
- Dictionary containing prepared inputs
631
- """
632
- # If we have past_key_values, we only need the last token
633
- if past_key_values is not None:
634
- input_ids = input_ids[:, -1:]
635
-
636
- return {
637
- "input_ids": input_ids,
638
- "past_key_values": past_key_values,
639
- "use_cache": True,
640
- }
641
-
642
- def get_input_embeddings(self):
643
- """Get the input embeddings layer."""
644
- return self.pico_decoder.embedding_proj
645
-
646
- def set_input_embeddings(self, value):
647
- """Set the input embeddings layer."""
648
- self.pico_decoder.embedding_proj = value
649
-
650
- def get_output_embeddings(self):
651
- """Get the output embeddings layer."""
652
- return self.pico_decoder.de_embedding_proj
653
-
654
- def set_output_embeddings(self, value):
655
- """Set the output embeddings layer."""
656
- self.pico_decoder.de_embedding_proj = value
657
-
658
- def get_lm_head(self):
659
- """Get the language model head."""
660
- return self.pico_decoder.de_embedding_proj
661
-
662
- def can_generate(self) -> bool:
663
- """Check if the model can generate text."""
664
- return True
665
-
666
- @property
667
- def is_encoder_decoder(self) -> bool:
668
- """Check if the model is an encoder-decoder model."""
669
- return False
670
-
671
- @property
672
- def can_use_cache(self) -> bool:
673
- """Check if the model can use KV cache."""
674
- return True
675
-
676
- def resize_token_embeddings(
677
- self, new_num_tokens: Optional[int] = None
678
- ) -> torch.nn.Embedding:
679
- """Resize token embeddings."""
680
- old_embeddings = self.get_input_embeddings()
681
- if new_num_tokens is None:
682
- new_num_tokens = old_embeddings.num_embeddings
683
-
684
- new_embeddings = torch.nn.Embedding(
685
- new_num_tokens, old_embeddings.embedding_dim
686
- )
687
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
688
- old_embeddings.weight.data
689
- )
690
-
691
- self.pico_decoder.embedding_proj = new_embeddings
692
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
693
- old_embeddings.embedding_dim, new_num_tokens, bias=False
694
- )
695
-
696
- return new_embeddings
697
-
698
-
699
- # Register for auto classes
700
- PicoDecoderHFConfig.register_for_auto_class()
701
- PicoDecoderHF.register_for_auto_class("AutoModel")
702
- PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
703
-
704
-
705
- ########################################################
706
- #
707
- # New PicoDecoderForCausalLM class for generation support
708
- #
709
- ########################################################
710
-
711
-
712
- class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
713
- """
714
- PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
715
-
716
- This class is designed to work with existing checkpoints and provides full generation support.
717
- It inherits from the right base classes that HuggingFace expects for text generation.
718
- """
719
-
720
- config_class = PicoDecoderHFConfig
721
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
722
- main_input_name = "input_ids"
723
-
724
- def __init__(self, config: PicoDecoderHFConfig):
725
- super().__init__(config)
726
- self.pico_decoder = PicoDecoder(config)
727
- # Initialize generation config with defaults
728
- self.generation_config = GenerationConfig()
729
- # Set some reasonable defaults for the model
730
- if hasattr(config, "max_position_embeddings"):
731
- self.generation_config.max_length = config.max_position_embeddings
732
- if hasattr(config, "vocab_size"):
733
- self.generation_config.vocab_size = config.vocab_size
734
-
735
- def forward(
736
- self,
737
- input_ids: torch.Tensor,
738
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
739
- use_cache: bool = False,
740
- **kwargs,
741
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
742
- """Forward pass for text generation."""
743
- logits, past_key_values = self.pico_decoder(
744
- input_ids, past_key_values, use_cache
745
- )
746
- if use_cache:
747
- return CausalLMOutputWithPast(
748
- logits=logits,
749
- past_key_values=past_key_values,
750
- )
751
- else:
752
- return CausalLMOutput(
753
- logits=logits,
754
- )
755
-
756
- def prepare_inputs_for_generation(
757
- self,
758
- input_ids: torch.LongTensor,
759
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
760
- attention_mask: Optional[torch.LongTensor] = None,
761
- **kwargs,
762
- ) -> Dict[str, Any]:
763
- """Prepare inputs for generation."""
764
- # If we have past_key_values, we only need the last token
765
- if past_key_values is not None:
766
- input_ids = input_ids[:, -1:]
767
-
768
- return {
769
- "input_ids": input_ids,
770
- "past_key_values": past_key_values,
771
- "use_cache": True,
772
- }
773
-
774
- def get_input_embeddings(self):
775
- """Get the input embeddings layer."""
776
- return self.pico_decoder.embedding_proj
777
-
778
- def set_input_embeddings(self, value):
779
- """Set the input embeddings layer."""
780
- self.pico_decoder.embedding_proj = value
781
-
782
- def get_output_embeddings(self):
783
- """Get the output embeddings layer."""
784
- return self.pico_decoder.de_embedding_proj
785
-
786
- def set_output_embeddings(self, value):
787
- """Set the output embeddings layer."""
788
- self.pico_decoder.de_embedding_proj = value
789
-
790
- def get_lm_head(self):
791
- """Get the language model head."""
792
- return self.pico_decoder.de_embedding_proj
793
-
794
- def can_generate(self) -> bool:
795
- """Check if the model can generate text."""
796
- return True
797
-
798
- @property
799
- def is_encoder_decoder(self) -> bool:
800
- """Check if the model is an encoder-decoder model."""
801
- return False
802
-
803
- @property
804
- def can_use_cache(self) -> bool:
805
- """Check if the model can use KV cache."""
806
- return True
807
-
808
- def resize_token_embeddings(
809
- self, new_num_tokens: Optional[int] = None
810
- ) -> torch.nn.Embedding:
811
- """Resize token embeddings."""
812
- old_embeddings = self.get_input_embeddings()
813
- if new_num_tokens is None:
814
- new_num_tokens = old_embeddings.num_embeddings
815
-
816
- new_embeddings = torch.nn.Embedding(
817
- new_num_tokens, old_embeddings.embedding_dim
818
- )
819
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
820
- old_embeddings.weight.data
821
- )
822
-
823
- self.pico_decoder.embedding_proj = new_embeddings
824
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
825
- old_embeddings.embedding_dim, new_num_tokens, bias=False
826
- )
827
-
828
- return new_embeddings
829
-
830
- @classmethod
831
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
832
- """
833
- Load a pretrained model from a checkpoint.
834
-
835
- This method handles loading from both the old PicoDecoderHF format and the new format.
836
- """
837
- # First try to load with the new class
838
- try:
839
- return super().from_pretrained(
840
- pretrained_model_name_or_path, *model_args, **kwargs
841
- )
842
- except Exception as e:
843
- print(f"Failed to load with new class: {e}")
844
- print("Attempting to load with legacy class and convert...")
845
-
846
- # Try to load with the old class and convert
847
- try:
848
- from transformers import AutoModel
849
-
850
- old_model = AutoModel.from_pretrained(
851
- pretrained_model_name_or_path,
852
- trust_remote_code=True,
853
- *model_args,
854
- **kwargs,
855
- )
856
-
857
- # Create new model instance
858
- new_model = cls(old_model.config)
859
-
860
- # Copy state dict
861
- new_model.load_state_dict(old_model.state_dict(), strict=False)
862
-
863
- return new_model
864
-
865
- except Exception as e2:
866
- print(f"Failed to convert from legacy format: {e2}")
867
- raise e
868
-
869
-
870
- # Register the new class
871
- PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/special_tokens_map.json DELETED
@@ -1,16 +0,0 @@
- {
-   "eos_token": {
-     "content": "<|endoftext|>",
-     "lstrip": false,
-     "normalized": false,
-     "rstrip": false,
-     "single_word": false
-   },
-   "pad_token": {
-     "content": "<|padding|>",
-     "lstrip": false,
-     "normalized": false,
-     "rstrip": false,
-     "single_word": false
-   }
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma29k/checkpoints/step_0/tokenizer_config.json DELETED
@@ -1,239 +0,0 @@
1
- {
2
- "add_bos_token": false,
3
- "add_eos_token": false,
4
- "add_prefix_space": false,
5
- "added_tokens_decoder": {
6
- "0": {
7
- "content": "|||IP_ADDRESS|||",
8
- "lstrip": false,
9
- "normalized": true,
10
- "rstrip": false,
11
- "single_word": false,
12
- "special": false
13
- },
14
- "1": {
15
- "content": "<|padding|>",
16
- "lstrip": false,
17
- "normalized": false,
18
- "rstrip": false,
19
- "single_word": false,
20
- "special": true
21
- },
22
- "50254": {
23
- "content": " ",
24
- "lstrip": false,
25
- "normalized": true,
26
- "rstrip": false,
27
- "single_word": false,
28
- "special": false
29
- },
30
- "50255": {
31
- "content": " ",
32
- "lstrip": false,
33
- "normalized": true,
34
- "rstrip": false,
35
- "single_word": false,
36
- "special": false
37
- },
38
- "50256": {
39
- "content": " ",
40
- "lstrip": false,
41
- "normalized": true,
42
- "rstrip": false,
43
- "single_word": false,
44
- "special": false
45
- },
46
- "50257": {
47
- "content": " ",
48
- "lstrip": false,
49
- "normalized": true,
50
- "rstrip": false,
51
- "single_word": false,
52
- "special": false
53
- },
54
- "50258": {
55
- "content": " ",
56
- "lstrip": false,
57
- "normalized": true,
58
- "rstrip": false,
59
- "single_word": false,
60
- "special": false
61
- },
62
- "50259": {
63
- "content": " ",
64
- "lstrip": false,
65
- "normalized": true,
66
- "rstrip": false,
67
- "single_word": false,
68
- "special": false
69
- },
70
- "50260": {
71
- "content": " ",
72
- "lstrip": false,
73
- "normalized": true,
74
- "rstrip": false,
75
- "single_word": false,
76
- "special": false
77
- },
78
- "50261": {
79
- "content": " ",
80
- "lstrip": false,
81
- "normalized": true,
82
- "rstrip": false,
83
- "single_word": false,
84
- "special": false
85
- },
86
- "50262": {
87
- "content": " ",
88
- "lstrip": false,
89
- "normalized": true,
90
- "rstrip": false,
91
- "single_word": false,
92
- "special": false
93
- },
94
- "50263": {
95
- "content": " ",
96
- "lstrip": false,
97
- "normalized": true,
98
- "rstrip": false,
99
- "single_word": false,
100
- "special": false
101
- },
102
- "50264": {
103
- "content": " ",
104
- "lstrip": false,
105
- "normalized": true,
106
- "rstrip": false,
107
- "single_word": false,
108
- "special": false
109
- },
110
- "50265": {
111
- "content": " ",
112
- "lstrip": false,
113
- "normalized": true,
114
- "rstrip": false,
115
- "single_word": false,
116
- "special": false
117
- },
118
- "50266": {
119
- "content": " ",
120
- "lstrip": false,
121
- "normalized": true,
122
- "rstrip": false,
123
- "single_word": false,
124
- "special": false
125
- },
126
- "50267": {
127
- "content": " ",
128
- "lstrip": false,
129
- "normalized": true,
130
- "rstrip": false,
131
- "single_word": false,
132
- "special": false
133
- },
134
- "50268": {
135
- "content": " ",
136
- "lstrip": false,
137
- "normalized": true,
138
- "rstrip": false,
139
- "single_word": false,
140
- "special": false
141
- },
142
- "50269": {
143
- "content": " ",
144
- "lstrip": false,
145
- "normalized": true,
146
- "rstrip": false,
147
- "single_word": false,
148
- "special": false
149
- },
150
- "50270": {
151
- "content": " ",
152
- "lstrip": false,
153
- "normalized": true,
154
- "rstrip": false,
155
- "single_word": false,
156
- "special": false
157
- },
158
- "50271": {
159
- "content": " ",
160
- "lstrip": false,
161
- "normalized": true,
162
- "rstrip": false,
163
- "single_word": false,
164
- "special": false
165
- },
166
- "50272": {
167
- "content": " ",
168
- "lstrip": false,
169
- "normalized": true,
170
- "rstrip": false,
171
- "single_word": false,
172
- "special": false
173
- },
174
- "50273": {
175
- "content": " ",
176
- "lstrip": false,
177
- "normalized": true,
178
- "rstrip": false,
179
- "single_word": false,
180
- "special": false
181
- },
182
- "50274": {
183
- "content": " ",
184
- "lstrip": false,
185
- "normalized": true,
186
- "rstrip": false,
187
- "single_word": false,
188
- "special": false
189
- },
190
- "50275": {
191
- "content": " ",
192
- "lstrip": false,
193
- "normalized": true,
194
- "rstrip": false,
195
- "single_word": false,
196
- "special": false
197
- },
198
- "50276": {
199
- "content": " ",
200
- "lstrip": false,
201
- "normalized": true,
202
- "rstrip": false,
203
- "single_word": false,
204
- "special": false
205
- },
206
- "50277": {
207
- "content": "|||EMAIL_ADDRESS|||",
208
- "lstrip": false,
209
- "normalized": true,
210
- "rstrip": false,
211
- "single_word": false,
212
- "special": false
213
- },
214
- "50278": {
215
- "content": "|||PHONE_NUMBER|||",
216
- "lstrip": false,
217
- "normalized": true,
218
- "rstrip": false,
219
- "single_word": false,
220
- "special": false
221
- },
222
- "50279": {
223
- "content": "<|endoftext|>",
224
- "lstrip": false,
225
- "normalized": false,
226
- "rstrip": false,
227
- "single_word": false,
228
- "special": true
229
- }
230
- },
231
- "bos_token": null,
232
- "clean_up_tokenization_spaces": true,
233
- "eos_token": "<|endoftext|>",
234
- "extra_special_tokens": {},
235
- "model_max_length": 1000000000000000019884624838656,
236
- "pad_token": "<|padding|>",
237
- "tokenizer_class": "GPTNeoXTokenizer",
238
- "unk_token": null
239
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/config.json DELETED
@@ -1,22 +0,0 @@
- {
-   "activation_hidden_dim": 384,
-   "architectures": [
-     "PicoDecoderHF"
-   ],
-   "attention_n_heads": 12,
-   "attention_n_kv_heads": 4,
-   "auto_map": {
-     "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
-     "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
-   },
-   "batch_size": 1024,
-   "d_model": 96,
-   "max_seq_len": 2048,
-   "model_type": "pico_decoder",
-   "n_layers": 12,
-   "norm_eps": 1e-06,
-   "position_emb_theta": 10000.0,
-   "torch_dtype": "float32",
-   "transformers_version": "4.48.3",
-   "vocab_size": 50304
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/fabric_state/checkpoint.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:17ff5a63c62f790e672d56533953307ebf889aaf949cb56ac6808c3c7ad3b766
- size 135543171
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/generation_config.json DELETED
@@ -1,4 +0,0 @@
- {
-   "transformers_version": "4.48.3",
-   "vocab_size": 50304
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_activations.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:86f6894546ad09ade1abd8844ed94f3c6d1eb34b6f3278911777c42c50ddd512
- size 33819
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:908a522ebc31350e7397330dd99c9ad2253e2efdc1bee1ff76f9babef6078d28
- size 66408
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_data/dataset_info.json DELETED
@@ -1,19 +0,0 @@
- {
-   "citation": "",
-   "description": "",
-   "features": {
-     "input_ids": {
-       "feature": {
-         "dtype": "int32",
-         "_type": "Value"
-       },
-       "_type": "Sequence"
-     },
-     "text": {
-       "dtype": "string",
-       "_type": "Value"
-     }
-   },
-   "homepage": "",
-   "license": ""
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_data/state.json DELETED
@@ -1,13 +0,0 @@
- {
-   "_data_files": [
-     {
-       "filename": "data-00000-of-00001.arrow"
-     }
-   ],
-   "_fingerprint": "86e249409514e027",
-   "_format_columns": null,
-   "_format_kwargs": {},
-   "_format_type": null,
-   "_output_all_columns": false,
-   "_split": null
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_gradients.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:49877366118749ef2726caa90137d614bcca1690852e7dcf1d992f0fe93a9189
- size 2371527
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/learning_dynamics/train_weights.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ed0f3b897d6c11ea78de275cfe4e3f7e56b3bc07ed14d9b64ced24f1bce4fda5
- size 2371443
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:e55c6c8d229fb062b568dc93714a06329caa160f5aae42d5394b0976ff3f8b9f
- size 45143592
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/pico_decoder.py DELETED
@@ -1,871 +0,0 @@
1
- """
2
- Pico Decoder: A Lightweight Causal Transformer Language Model
3
-
4
- Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
-
6
- Everything is written with a modular design for easy modification and experimentation.
7
-
8
- Key features:
9
- - RMSNorm for layer normalization
10
- - Rotary Positional Embeddings (RoPE)
11
- - Multi-head attention with KV-cache support
12
- - SwiGLU activation function
13
- - Residual connections throughout
14
-
15
- - KV-cache for faster autoregressive generation
16
-
17
- References:
18
- - RoPE: https://arxiv.org/abs/2104.09864
19
- - SwiGLU: https://arxiv.org/abs/2002.05202
20
- - LLAMA: https://arxiv.org/abs/2302.13971
21
-
22
- Adapted from:
23
- - OLMO: https://github.com/allenai/OLMo
24
- - LLAMA: https://github.com/meta/llama
25
- """
26
-
27
- from dataclasses import asdict
28
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
-
30
- import torch
31
- import torch.nn as nn
32
- import torch.nn.functional as F
33
- from torch.nn.attention import SDPBackend, sdpa_kernel
34
- from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
35
- from transformers.generation import GenerationConfig
36
- from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
37
-
38
- try:
39
- if TYPE_CHECKING:
40
- # We need to do this to avoid importing these when creating the HF-compatible models
41
- from src.config import ModelConfig
42
- except ImportError:
43
- pass
44
-
45
- ########################################################
46
- #
47
- # Layer Normalization
48
- #
49
- ########################################################
50
-
51
-
52
- class RMSNorm(torch.nn.Module):
53
- """Root Mean Square Layer Normalization.
54
-
55
- A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
56
- resulting in improved stability and performance.
57
-
58
- Args:
59
- config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
60
- - config.norm_eps: Small constant for numerical stability
61
- - config.d_model: Model dimension for the weight parameter
62
-
63
- References:
64
- https://arxiv.org/abs/1910.07467
65
- """
66
-
67
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
68
- super().__init__()
69
- self.eps = config.norm_eps
70
- self.weight = nn.Parameter(torch.ones(config.d_model))
71
-
72
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
73
- """
74
- Normalizes the input tensor by its RMS value.
75
- """
76
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
77
-
78
- def forward(self, x: torch.Tensor) -> torch.Tensor:
79
- """
80
- Applies RMS normalization to the input tensor and scales it by the weight parameter.
81
- """
82
- output = self._norm(x.float()).type_as(x)
83
- return output * self.weight
84
-
85
-
86
- ########################################################
87
- #
88
- # Positional Embedding
89
- #
90
- ########################################################
91
-
92
-
93
- class RoPE(nn.Module):
94
- """Rotary Positional Embeddings (RoPE).
95
-
96
- Implements position-dependent rotation of keys and queries in attention mechanism,
97
- allowing better modeling of relative positions in sequences. Uses complex number
98
- operations for efficient rotation.
99
-
100
- Args:
101
- config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
102
- - config.position_emb_theta: Base for frequency computation
103
- - config.d_model: Model dimension
104
- - config.attention_n_heads: Number of attention heads
105
- - config.max_seq_len: Maximum sequence length
106
-
107
- References:
108
- https://arxiv.org/abs/2104.09864
109
- """
110
-
111
- _freqs_cis_tensor: torch.Tensor | None = None
112
-
113
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
114
- super().__init__()
115
-
116
- self.theta = config.position_emb_theta
117
- self.dim = config.d_model // config.attention_n_heads
118
-
119
- max_seq_len = config.max_seq_len
120
-
121
- # only gets set once, and then reused for all RoPE instances
122
- if RoPE._freqs_cis_tensor is None:
123
- RoPE._freqs_cis_tensor = self._setup_freqs_cis(
124
- max_seq_len, self.theta, self.dim
125
- )
126
-
127
- # register _freqs_cis buffer
128
- # can be easily recomputed so persistent=False
129
- self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
130
-
131
- @classmethod
132
- def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
133
- """Setup Frequency Tensor for RoPE Embeddings
134
-
135
- Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
136
-
137
- Note other implementations will use cos and sin directly, but using the complex
138
- number representation is (probably) more efficient:
139
-
140
- e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
141
- """
142
- _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
143
- positions = torch.arange(seq_len)
144
- freqs = torch.outer(positions, _freqs)
145
- return torch.polar(torch.ones_like(freqs), freqs) # complex64
146
-
147
- def get_freqs_cis(
148
- self, input_shape: torch.Size, start_pos: int, end_pos: int
149
- ) -> torch.Tensor:
150
- """Reshape Frequency Tensor for RoPE Embeddings
151
-
152
- Makes the frequency tensor broadcastable with the input tensor.
153
- """
154
- _freqs_cis = self._freqs_cis[start_pos:end_pos]
155
- ndim = len(input_shape)
156
- assert 0 <= 1 < ndim
157
- assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
158
-
159
- # TODO: Check whether this is correct (might be able to remove this)
160
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
161
- return _freqs_cis.view(*shape)
162
-
163
- def forward(
164
- self,
165
- queries: torch.Tensor,
166
- keys: torch.Tensor,
167
- start_pos: int = 0,
168
- ) -> Tuple[torch.Tensor, torch.Tensor]:
169
- """Apply RoPE Embeddings to Queries and Keys
170
-
171
- Applies the rotary positional embeddings to the input tensors via complex num multiplication
172
-
173
- NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
174
- """
175
- queries_ = torch.view_as_complex(
176
- queries.float().reshape(*queries.shape[:-1], -1, 2)
177
- )
178
- keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
179
-
180
- input_shape = (
181
- queries_.shape
182
- ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
183
- freqs_start_pos = start_pos
184
- freqs_end_pos = freqs_start_pos + queries_.shape[1]
185
-
186
- freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
187
-
188
- queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
189
- keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
190
- return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
191
-
192
-
193
- ########################################################
194
- #
195
- # Attention
196
- #
197
- ########################################################
198
-
199
-
200
- class Attention(nn.Module):
201
- """Multi-head Attention with Group Query Attention support.
202
-
203
- Implements scaled dot-product attention and supports:
204
- - Grouped Query Attention (GQA)
205
- - Key-Value caching for efficient inference
206
- - RoPE integration
207
-
208
- Args:
209
- config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
210
- - config.attention_n_heads: Number of attention heads
211
- - config.attention_n_kv_heads: Number of key/value heads
212
- - config.d_model: Model dimension
213
- - config.batch_size: Maximum batch size
214
- - config.max_seq_len: Maximum sequence length
215
-
216
- Shape:
217
- - Input: (batch_size, seq_len, d_model)
218
- - Output: (batch_size, seq_len, d_model)
219
- """
220
-
221
- def __init__(
222
- self,
223
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
224
- ):
225
- super().__init__()
226
-
227
- self.n_heads = config.attention_n_heads
228
- self.n_kv_heads = config.attention_n_kv_heads
229
-
230
- self.batch_size = config.batch_size
231
- self.max_seq_len = config.max_seq_len
232
-
233
- d_model = config.d_model
234
- self.head_dim = d_model // self.n_heads
235
-
236
- self.n_rep = self.n_heads // self.n_kv_heads
237
-
238
- self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
239
- self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
240
- self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
241
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
242
-
243
- self.rope = RoPE(config)
244
-
245
- def forward(
246
- self,
247
- input: torch.Tensor,
248
- mask: Optional[torch.Tensor] = None,
249
- past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
250
- use_cache: bool = False,
251
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
252
- """Forward pass for the attention mechanism.
253
-
254
- Computes queries, keys, and values for the attention mechanism. Applies rotary positional
255
- embeddings to the queries and keys, and then computes attention scores and outputs.
256
-
257
- For an introduction to the attention mechanism, see:
258
- https://arxiv.org/abs/1706.03762
259
-
260
- A few things to note:
261
- - The past_key_values is used to implement the KV cache, which is used to speed up
262
- generation by caching the KV pairs from previous forward passes. This is useful when doing
263
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
264
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
265
- its own KV cache - this KV cache is implemented as a tuple.
266
- """
267
- bsz, seq_len, _ = input.shape
268
- _queries, _keys, _values = (
269
- self.q_proj(input),
270
- self.k_proj(input),
271
- self.v_proj(input),
272
- )
273
-
274
- # Reshaping for multi-head attention
275
- queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
276
- keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
277
- values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
278
-
279
- # The start position is used to apply the RoPE embeddings to only the new tokens
280
- # when using the kv_cache in the attention mechanism.
281
- # We want to start from the last position in the cache.
282
- start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
283
-
284
- # apply rotary positional embeddings
285
- queries, keys = self.rope(queries, keys, start_pos)
286
-
287
- if past_key_values is not None:
288
- keys = torch.cat([past_key_values[0], keys], dim=1)
289
- values = torch.cat([past_key_values[1], values], dim=1)
290
-
291
- if use_cache:
292
- cached_keys = keys
293
- cached_values = values
294
- else:
295
- cached_keys = None
296
- cached_values = None
297
-
298
- queries = queries.transpose(1, 2)
299
- keys = keys.transpose(1, 2)
300
- values = values.transpose(1, 2)
301
-
302
- apply_gqa = self.n_rep > 1
303
- if apply_gqa and queries.device.type == "mps":
304
- # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
305
- # outside of the kernel to get the same effect.
306
- # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
307
- keys = keys.repeat_interleave(self.n_rep, dim=-3)
308
- values = values.repeat_interleave(self.n_rep, dim=-3)
309
- apply_gqa = False
310
-
311
- backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
312
-
313
- with sdpa_kernel(backends=backends):
314
- attn_output = F.scaled_dot_product_attention(
315
- queries.contiguous(),
316
- keys.contiguous(),
317
- values.contiguous(),
318
- attn_mask=mask.to(queries.dtype) if mask is not None else None,
319
- enable_gqa=apply_gqa,
320
- )
321
-
322
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
323
- output = self.o_proj(attn_output)
324
-
325
- return output, (cached_keys, cached_values)
326
-
327
-
328
- ########################################################
329
- #
330
- # SwiGLU (Combines MLP and Activation)
331
- #
332
- ########################################################
333
-
334
-
335
- class SwiGLU(nn.Module):
336
- """SwiGLU Activation Function with Linear Projections.
337
-
338
- Implements the SwiGLU activation function combined with linear transformations,
339
- serving as the feed-forward network in transformer blocks.
340
-
341
- Args:
342
- config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
343
- - config.d_model: Model dimension
344
- - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
345
-
346
- References:
347
- https://arxiv.org/abs/2002.05202
348
- """
349
-
350
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
351
- super().__init__()
352
-
353
- model_dim = config.d_model
354
- act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
355
-
356
- self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
357
- self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
358
- self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
359
-
360
- def forward(self, x: torch.Tensor) -> torch.Tensor:
361
- return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
362
-
363
-
364
- ########################################################
365
- #
366
- # PicoDecoderBlock
367
- #
368
- ########################################################
369
-
370
-
371
- class PicoDecoderBlock(nn.Module):
372
- """Single Transformer Block with Attention and Feed-forward layers.
373
-
374
- Implements a standard transformer block with:
375
- - Multi-head attention with normalization and residual connection
376
- - SwiGLU feed-forward network with normalization and residual connection
377
-
378
- Args:
379
- config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
380
- a HuggingFace PicoDecoderHFConfig
381
- """
382
-
383
- def __init__(
384
- self,
385
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
386
- ):
387
- super().__init__()
388
-
389
- self.attention = Attention(config)
390
- self.swiglu = SwiGLU(config)
391
- self.attention_norm = RMSNorm(config)
392
- self.swiglu_norm = RMSNorm(config)
393
-
394
- def forward(
395
- self,
396
- input: torch.Tensor,
397
- mask: Optional[torch.Tensor] = None,
398
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
399
- use_cache: bool = False,
400
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
401
- attention_output, cached_key_values = self.attention(
402
- self.attention_norm(input),
403
- mask=mask,
404
- past_key_values=past_key_values,
405
- use_cache=use_cache,
406
- )
407
- # NOTE: cached_key_values is None if use_cache is False
408
-
409
- h = input + attention_output
410
- out = h + self.swiglu(self.swiglu_norm(h))
411
- return out, cached_key_values
412
-
413
-
414
- ########################################################
415
- #
416
- # Pico Decoder (Causal Transformer Model)
417
- #
418
- ########################################################
419
-
420
-
421
- class PicoDecoder(nn.Module):
422
- """
423
- Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
424
- single autoregressive model.
425
-
426
- For more information on the model, see the classes for the modules that make up the model.
427
- """
428
-
429
- def __init__(
430
- self,
431
- model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
432
- ):
433
- super().__init__()
434
- self.config = model_config
435
-
436
- self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
437
- self.layers = nn.ModuleList(
438
- [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
439
- )
440
- self.output_norm = RMSNorm(self.config)
441
- self.de_embedding_proj = nn.Linear(
442
- self.config.d_model, self.config.vocab_size, bias=False
443
- )
444
-
445
- def convert_to_hf_model(self) -> "PicoDecoderHF":
446
- """Convert the Lightning model to a HuggingFace model."""
447
- # Create HF config without fabric-specific settings
448
- hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
449
-
450
- # Create new HF model
451
- hf_model = PicoDecoderHF(hf_config)
452
-
453
- # Copy state dict, excluding fabric-specific keys
454
- hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
455
-
456
- return hf_model
457
-
458
- def forward(
459
- self,
460
- input_ids: torch.Tensor,
461
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
462
- use_cache: bool = False,
463
- ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
464
- """
465
- This is the forward pass for the entire Pico model. It boils down to:
466
- - Embedding the input ids
467
- - Creating a causal mask
468
- - Processing through the pico layers
469
- - Projecting the output to logits
470
-
471
- NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
472
- generation by caching the KV pairs from previous forward passes. This is useful when doing
473
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
474
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
475
- its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
476
- KV caches (so a tuple of tuples).
477
- """
478
-
479
- seq_len = input_ids.shape[-1]
480
- h = self.embedding_proj(input_ids)
481
-
482
- # Calculate start position from past cached KV pairs. Remember that each layer has its
483
- # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
484
- # correct layer and then for either the keys or values.
485
- start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
486
-
487
- # Create causal mask for current sequence
488
- mask = None
489
- if seq_len > 1:
490
- mask = torch.full((seq_len, seq_len), float("-inf"))
491
- mask = torch.triu(mask, diagonal=1)
492
-
493
- # If using KV cache, extend mask to cover cached sequence length
494
- if past_key_values is not None:
495
- # Add zeros for cached tokens (we can attend to all of them)
496
- mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
497
-
498
- mask = mask.to(h.device)
499
-
500
- # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
501
- # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
502
- cached_key_values = () if use_cache else None
503
-
504
- # Process through transformer blocks
505
- for idx, layer in enumerate(self.layers):
506
- layer_past_key_values = (
507
- past_key_values[idx] if past_key_values is not None else None
508
- )
509
-
510
- h, layer_cached_key_values = layer(
511
- h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
512
- )
513
-
514
- if use_cache:
515
- cached_key_values += (layer_cached_key_values,)
516
-
517
- # Final norm and projection
518
- h = self.output_norm(h)
519
- logits = self.de_embedding_proj(h).float()
520
-
521
- return logits, cached_key_values
522
-
523
-
524
- ########################################################
525
- #
526
- # HuggingFace Wrapper for the Pico Decoder model.
527
- #
528
- ########################################################
529
-
530
-
531
- class PicoDecoderHFConfig(PretrainedConfig):
532
- """Config class for the Pico Decoder HuggingFace wrapper."""
533
-
534
- model_type = "pico_decoder"
535
-
536
- @classmethod
537
- def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
538
- """
539
- Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
540
- this is because with some kwargs special handling is required and can make this class
541
- brittle.
542
- """
543
- pico_config = cls(**config_dict)
544
-
545
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
546
- unused_kwargs = {
547
- key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
548
- }
549
-
550
- if return_unused_kwargs:
551
- return pico_config, unused_kwargs
552
- return pico_config
553
-
554
- @classmethod
555
- def from_dataclass(cls, model_config: "ModelConfig"):
556
- """Initialise from our custom config dataclass."""
557
- return cls.from_dict(asdict(model_config))
558
-
559
-
560
- class PicoDecoderHF(PreTrainedModel, GenerationMixin):
561
- """
562
- HuggingFace wrapper for the Pico model with generation support.
563
-
564
- Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
565
- wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
566
- Pico model as well as the model wrapped in this HuggingFace class.
567
-
568
- This also lets you do cool things like:
569
-
570
- `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
571
- """
572
-
573
- config_class = PicoDecoderHFConfig
574
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
575
- main_input_name = "input_ids"
576
-
577
- def __init__(self, config: PicoDecoderHFConfig):
578
- super().__init__(config)
579
- self.pico_decoder = PicoDecoder(config)
580
- # Initialize generation config with defaults
581
- self.generation_config = GenerationConfig()
582
- # Set some reasonable defaults for the model
583
- if hasattr(config, "max_position_embeddings"):
584
- self.generation_config.max_length = config.max_position_embeddings
585
- if hasattr(config, "vocab_size"):
586
- self.generation_config.vocab_size = config.vocab_size
587
-
588
- def forward(
589
- self,
590
- input_ids: torch.Tensor,
591
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
592
- use_cache: bool = False,
593
- **kwargs,
594
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
595
- """HuggingFace forward pass wrapper.
596
-
597
- Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
598
- Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
599
- """
600
- logits, past_key_values = self.pico_decoder(
601
- input_ids, past_key_values, use_cache
602
- )
603
- if use_cache:
604
- return CausalLMOutputWithPast(
605
- logits=logits,
606
- past_key_values=past_key_values,
607
- )
608
- else:
609
- return CausalLMOutput(
610
- logits=logits,
611
- )
612
-
613
- def prepare_inputs_for_generation(
614
- self,
615
- input_ids: torch.LongTensor,
616
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
617
- attention_mask: Optional[torch.LongTensor] = None,
618
- **kwargs,
619
- ) -> Dict[str, Any]:
620
- """
621
- Prepare inputs for generation.
622
-
623
- Args:
624
- input_ids: Input token IDs
625
- past_key_values: Cached key-value pairs from previous forward passes
626
- attention_mask: Attention mask for the input
627
- **kwargs: Additional arguments
628
-
629
- Returns:
630
- Dictionary containing prepared inputs
631
- """
632
- # If we have past_key_values, we only need the last token
633
- if past_key_values is not None:
634
- input_ids = input_ids[:, -1:]
635
-
636
- return {
637
- "input_ids": input_ids,
638
- "past_key_values": past_key_values,
639
- "use_cache": True,
640
- }
641
-
642
- def get_input_embeddings(self):
643
- """Get the input embeddings layer."""
644
- return self.pico_decoder.embedding_proj
645
-
646
- def set_input_embeddings(self, value):
647
- """Set the input embeddings layer."""
648
- self.pico_decoder.embedding_proj = value
649
-
650
- def get_output_embeddings(self):
651
- """Get the output embeddings layer."""
652
- return self.pico_decoder.de_embedding_proj
653
-
654
- def set_output_embeddings(self, value):
655
- """Set the output embeddings layer."""
656
- self.pico_decoder.de_embedding_proj = value
657
-
658
- def get_lm_head(self):
659
- """Get the language model head."""
660
- return self.pico_decoder.de_embedding_proj
661
-
662
- def can_generate(self) -> bool:
663
- """Check if the model can generate text."""
664
- return True
665
-
666
- @property
667
- def is_encoder_decoder(self) -> bool:
668
- """Check if the model is an encoder-decoder model."""
669
- return False
670
-
671
- @property
672
- def can_use_cache(self) -> bool:
673
- """Check if the model can use KV cache."""
674
- return True
675
-
676
- def resize_token_embeddings(
677
- self, new_num_tokens: Optional[int] = None
678
- ) -> torch.nn.Embedding:
679
- """Resize token embeddings."""
680
- old_embeddings = self.get_input_embeddings()
681
- if new_num_tokens is None:
682
- new_num_tokens = old_embeddings.num_embeddings
683
-
684
- new_embeddings = torch.nn.Embedding(
685
- new_num_tokens, old_embeddings.embedding_dim
686
- )
687
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
688
- old_embeddings.weight.data
689
- )
690
-
691
- self.pico_decoder.embedding_proj = new_embeddings
692
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
693
- old_embeddings.embedding_dim, new_num_tokens, bias=False
694
- )
695
-
696
- return new_embeddings
697
-
698
-
699
- # Register for auto classes
700
- PicoDecoderHFConfig.register_for_auto_class()
701
- PicoDecoderHF.register_for_auto_class("AutoModel")
702
- PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
703
-
704
-
705
- ########################################################
706
- #
707
- # New PicoDecoderForCausalLM class for generation support
708
- #
709
- ########################################################
710
-
711
-
712
- class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
713
- """
714
- PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
715
-
716
- This class is designed to work with existing checkpoints and provides full generation support.
717
- It inherits from the right base classes that HuggingFace expects for text generation.
718
- """
719
-
720
- config_class = PicoDecoderHFConfig
721
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
722
- main_input_name = "input_ids"
723
-
724
- def __init__(self, config: PicoDecoderHFConfig):
725
- super().__init__(config)
726
- self.pico_decoder = PicoDecoder(config)
727
- # Initialize generation config with defaults
728
- self.generation_config = GenerationConfig()
729
- # Set some reasonable defaults for the model
730
- if hasattr(config, "max_position_embeddings"):
731
- self.generation_config.max_length = config.max_position_embeddings
732
- if hasattr(config, "vocab_size"):
733
- self.generation_config.vocab_size = config.vocab_size
734
-
735
- def forward(
736
- self,
737
- input_ids: torch.Tensor,
738
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
739
- use_cache: bool = False,
740
- **kwargs,
741
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
742
- """Forward pass for text generation."""
743
- logits, past_key_values = self.pico_decoder(
744
- input_ids, past_key_values, use_cache
745
- )
746
- if use_cache:
747
- return CausalLMOutputWithPast(
748
- logits=logits,
749
- past_key_values=past_key_values,
750
- )
751
- else:
752
- return CausalLMOutput(
753
- logits=logits,
754
- )
755
-
756
- def prepare_inputs_for_generation(
757
- self,
758
- input_ids: torch.LongTensor,
759
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
760
- attention_mask: Optional[torch.LongTensor] = None,
761
- **kwargs,
762
- ) -> Dict[str, Any]:
763
- """Prepare inputs for generation."""
764
- # If we have past_key_values, we only need the last token
765
- if past_key_values is not None:
766
- input_ids = input_ids[:, -1:]
767
-
768
- return {
769
- "input_ids": input_ids,
770
- "past_key_values": past_key_values,
771
- "use_cache": True,
772
- }
773
-
774
- def get_input_embeddings(self):
775
- """Get the input embeddings layer."""
776
- return self.pico_decoder.embedding_proj
777
-
778
- def set_input_embeddings(self, value):
779
- """Set the input embeddings layer."""
780
- self.pico_decoder.embedding_proj = value
781
-
782
- def get_output_embeddings(self):
783
- """Get the output embeddings layer."""
784
- return self.pico_decoder.de_embedding_proj
785
-
786
- def set_output_embeddings(self, value):
787
- """Set the output embeddings layer."""
788
- self.pico_decoder.de_embedding_proj = value
789
-
790
- def get_lm_head(self):
791
- """Get the language model head."""
792
- return self.pico_decoder.de_embedding_proj
793
-
794
- def can_generate(self) -> bool:
795
- """Check if the model can generate text."""
796
- return True
797
-
798
- @property
799
- def is_encoder_decoder(self) -> bool:
800
- """Check if the model is an encoder-decoder model."""
801
- return False
802
-
803
- @property
804
- def can_use_cache(self) -> bool:
805
- """Check if the model can use KV cache."""
806
- return True
807
-
808
- def resize_token_embeddings(
809
- self, new_num_tokens: Optional[int] = None
810
- ) -> torch.nn.Embedding:
811
- """Resize token embeddings."""
812
- old_embeddings = self.get_input_embeddings()
813
- if new_num_tokens is None:
814
- new_num_tokens = old_embeddings.num_embeddings
815
-
816
- new_embeddings = torch.nn.Embedding(
817
- new_num_tokens, old_embeddings.embedding_dim
818
- )
819
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
820
- old_embeddings.weight.data
821
- )
822
-
823
- self.pico_decoder.embedding_proj = new_embeddings
824
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
825
- old_embeddings.embedding_dim, new_num_tokens, bias=False
826
- )
827
-
828
- return new_embeddings
829
-
830
- @classmethod
831
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
832
- """
833
- Load a pretrained model from a checkpoint.
834
-
835
- This method handles loading from both the old PicoDecoderHF format and the new format.
836
- """
837
- # First try to load with the new class
838
- try:
839
- return super().from_pretrained(
840
- pretrained_model_name_or_path, *model_args, **kwargs
841
- )
842
- except Exception as e:
843
- print(f"Failed to load with new class: {e}")
844
- print("Attempting to load with legacy class and convert...")
845
-
846
- # Try to load with the old class and convert
847
- try:
848
- from transformers import AutoModel
849
-
850
- old_model = AutoModel.from_pretrained(
851
- pretrained_model_name_or_path,
852
- trust_remote_code=True,
853
- *model_args,
854
- **kwargs,
855
- )
856
-
857
- # Create new model instance
858
- new_model = cls(old_model.config)
859
-
860
- # Copy state dict
861
- new_model.load_state_dict(old_model.state_dict(), strict=False)
862
-
863
- return new_model
864
-
865
- except Exception as e2:
866
- print(f"Failed to convert from legacy format: {e2}")
867
- raise e
868
-
869
-
870
- # Register the new class
871
- PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
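For context, the module deleted above registers PicoDecoderHFConfig and PicoDecoderHF/PicoDecoderForCausalLM with the HuggingFace auto classes (see the auto_map entries in the checkpoint config.json files further down). A minimal loading sketch follows, assuming a local copy of one checkpoint directory still exists outside this repository; the path is illustrative only (this commit removes the files), and trust_remote_code is needed so the auto classes resolve to pico_decoder.py:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical local path; the checkpoint is deleted from the repo by this commit.
    checkpoint_dir = "pico-decoder-tiny-dolma29k/checkpoints/step_1000"

    # trust_remote_code=True lets AutoModelForCausalLM resolve the custom
    # PicoDecoderHF class through the auto_map declared in config.json.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)

    inputs = tokenizer("The tiny decoder", return_tensors="pt")
    outputs = model(input_ids=inputs["input_ids"])  # CausalLMOutput (use_cache=False)
    print(outputs.logits.shape)                     # (1, seq_len, vocab_size)

Per the wrapper's GenerationMixin inheritance and prepare_inputs_for_generation, autoregressive decoding would go through model.generate on the same object, though that path is not exercised here.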
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/special_tokens_map.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "eos_token": {
3
- "content": "<|endoftext|>",
4
- "lstrip": false,
5
- "normalized": false,
6
- "rstrip": false,
7
- "single_word": false
8
- },
9
- "pad_token": {
10
- "content": "<|padding|>",
11
- "lstrip": false,
12
- "normalized": false,
13
- "rstrip": false,
14
- "single_word": false
15
- }
16
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma29k/checkpoints/step_1000/tokenizer_config.json DELETED
@@ -1,239 +0,0 @@
1
- {
2
- "add_bos_token": false,
3
- "add_eos_token": false,
4
- "add_prefix_space": false,
5
- "added_tokens_decoder": {
6
- "0": {
7
- "content": "|||IP_ADDRESS|||",
8
- "lstrip": false,
9
- "normalized": true,
10
- "rstrip": false,
11
- "single_word": false,
12
- "special": false
13
- },
14
- "1": {
15
- "content": "<|padding|>",
16
- "lstrip": false,
17
- "normalized": false,
18
- "rstrip": false,
19
- "single_word": false,
20
- "special": true
21
- },
22
- "50254": {
23
- "content": " ",
24
- "lstrip": false,
25
- "normalized": true,
26
- "rstrip": false,
27
- "single_word": false,
28
- "special": false
29
- },
30
- "50255": {
31
- "content": " ",
32
- "lstrip": false,
33
- "normalized": true,
34
- "rstrip": false,
35
- "single_word": false,
36
- "special": false
37
- },
38
- "50256": {
39
- "content": " ",
40
- "lstrip": false,
41
- "normalized": true,
42
- "rstrip": false,
43
- "single_word": false,
44
- "special": false
45
- },
46
- "50257": {
47
- "content": " ",
48
- "lstrip": false,
49
- "normalized": true,
50
- "rstrip": false,
51
- "single_word": false,
52
- "special": false
53
- },
54
- "50258": {
55
- "content": " ",
56
- "lstrip": false,
57
- "normalized": true,
58
- "rstrip": false,
59
- "single_word": false,
60
- "special": false
61
- },
62
- "50259": {
63
- "content": " ",
64
- "lstrip": false,
65
- "normalized": true,
66
- "rstrip": false,
67
- "single_word": false,
68
- "special": false
69
- },
70
- "50260": {
71
- "content": " ",
72
- "lstrip": false,
73
- "normalized": true,
74
- "rstrip": false,
75
- "single_word": false,
76
- "special": false
77
- },
78
- "50261": {
79
- "content": " ",
80
- "lstrip": false,
81
- "normalized": true,
82
- "rstrip": false,
83
- "single_word": false,
84
- "special": false
85
- },
86
- "50262": {
87
- "content": " ",
88
- "lstrip": false,
89
- "normalized": true,
90
- "rstrip": false,
91
- "single_word": false,
92
- "special": false
93
- },
94
- "50263": {
95
- "content": " ",
96
- "lstrip": false,
97
- "normalized": true,
98
- "rstrip": false,
99
- "single_word": false,
100
- "special": false
101
- },
102
- "50264": {
103
- "content": " ",
104
- "lstrip": false,
105
- "normalized": true,
106
- "rstrip": false,
107
- "single_word": false,
108
- "special": false
109
- },
110
- "50265": {
111
- "content": " ",
112
- "lstrip": false,
113
- "normalized": true,
114
- "rstrip": false,
115
- "single_word": false,
116
- "special": false
117
- },
118
- "50266": {
119
- "content": " ",
120
- "lstrip": false,
121
- "normalized": true,
122
- "rstrip": false,
123
- "single_word": false,
124
- "special": false
125
- },
126
- "50267": {
127
- "content": " ",
128
- "lstrip": false,
129
- "normalized": true,
130
- "rstrip": false,
131
- "single_word": false,
132
- "special": false
133
- },
134
- "50268": {
135
- "content": " ",
136
- "lstrip": false,
137
- "normalized": true,
138
- "rstrip": false,
139
- "single_word": false,
140
- "special": false
141
- },
142
- "50269": {
143
- "content": " ",
144
- "lstrip": false,
145
- "normalized": true,
146
- "rstrip": false,
147
- "single_word": false,
148
- "special": false
149
- },
150
- "50270": {
151
- "content": " ",
152
- "lstrip": false,
153
- "normalized": true,
154
- "rstrip": false,
155
- "single_word": false,
156
- "special": false
157
- },
158
- "50271": {
159
- "content": " ",
160
- "lstrip": false,
161
- "normalized": true,
162
- "rstrip": false,
163
- "single_word": false,
164
- "special": false
165
- },
166
- "50272": {
167
- "content": " ",
168
- "lstrip": false,
169
- "normalized": true,
170
- "rstrip": false,
171
- "single_word": false,
172
- "special": false
173
- },
174
- "50273": {
175
- "content": " ",
176
- "lstrip": false,
177
- "normalized": true,
178
- "rstrip": false,
179
- "single_word": false,
180
- "special": false
181
- },
182
- "50274": {
183
- "content": " ",
184
- "lstrip": false,
185
- "normalized": true,
186
- "rstrip": false,
187
- "single_word": false,
188
- "special": false
189
- },
190
- "50275": {
191
- "content": " ",
192
- "lstrip": false,
193
- "normalized": true,
194
- "rstrip": false,
195
- "single_word": false,
196
- "special": false
197
- },
198
- "50276": {
199
- "content": " ",
200
- "lstrip": false,
201
- "normalized": true,
202
- "rstrip": false,
203
- "single_word": false,
204
- "special": false
205
- },
206
- "50277": {
207
- "content": "|||EMAIL_ADDRESS|||",
208
- "lstrip": false,
209
- "normalized": true,
210
- "rstrip": false,
211
- "single_word": false,
212
- "special": false
213
- },
214
- "50278": {
215
- "content": "|||PHONE_NUMBER|||",
216
- "lstrip": false,
217
- "normalized": true,
218
- "rstrip": false,
219
- "single_word": false,
220
- "special": false
221
- },
222
- "50279": {
223
- "content": "<|endoftext|>",
224
- "lstrip": false,
225
- "normalized": false,
226
- "rstrip": false,
227
- "single_word": false,
228
- "special": true
229
- }
230
- },
231
- "bos_token": null,
232
- "clean_up_tokenization_spaces": true,
233
- "eos_token": "<|endoftext|>",
234
- "extra_special_tokens": {},
235
- "model_max_length": 1000000000000000019884624838656,
236
- "pad_token": "<|padding|>",
237
- "tokenizer_class": "GPTNeoXTokenizer",
238
- "unk_token": null
239
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/config.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "activation_hidden_dim": 384,
3
- "architectures": [
4
- "PicoDecoderHF"
5
- ],
6
- "attention_n_heads": 12,
7
- "attention_n_kv_heads": 4,
8
- "auto_map": {
9
- "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
- "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
- },
12
- "batch_size": 1024,
13
- "d_model": 96,
14
- "max_seq_len": 2048,
15
- "model_type": "pico_decoder",
16
- "n_layers": 12,
17
- "norm_eps": 1e-06,
18
- "position_emb_theta": 10000.0,
19
- "torch_dtype": "float32",
20
- "transformers_version": "4.48.3",
21
- "vocab_size": 50304
22
- }
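These hyperparameters, together with the layer shapes defined in pico_decoder.py (untied input/output embeddings, grouped-query attention projections, and a three-matrix SwiGLU), account for the roughly 45 MB model.safetensors recorded further down. A back-of-the-envelope sketch, illustrative only (the file on disk also carries a small safetensors header):

    # Approximate parameter count for the step_2000 config above.
    d_model, n_layers, vocab_size = 96, 12, 50304
    n_heads, n_kv_heads, ffn_dim = 12, 4, 384

    head_dim = d_model // n_heads            # 8
    kv_dim = n_kv_heads * head_dim           # 32

    attn = 2 * d_model * d_model + 2 * d_model * kv_dim  # q_proj/o_proj + k_proj/v_proj
    swiglu = 3 * d_model * ffn_dim                        # w_0, w_1, w_2
    norms = 2 * d_model                                   # attention_norm + swiglu_norm
    per_layer = attn + swiglu + norms

    embeddings = 2 * vocab_size * d_model                 # embedding_proj + de_embedding_proj (untied)
    total = n_layers * per_layer + embeddings + d_model   # + final output_norm

    print(f"{total:,} params, ~{total * 4 / 1e6:.1f} MB in float32")  # ~11.3M params, ~45 MB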
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/fabric_state/checkpoint.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:440c5892855af3a38d67aeaf2293293c14de741146ad01f1e4856dd59d2750fc
3
- size 135543171
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/generation_config.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "transformers_version": "4.48.3",
3
- "vocab_size": 50304
4
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_activations.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ea7b999b22c58569b9e2f66325bba2781549049bbe4ae6b39e5258394145eda9
3
- size 33819
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_data/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8f9fd2479b7954b3eee4b6373b2f00eba1a7a656a510c35ddb322de4599b8602
3
- size 66592
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_data/dataset_info.json DELETED
@@ -1,19 +0,0 @@
1
- {
2
- "citation": "",
3
- "description": "",
4
- "features": {
5
- "input_ids": {
6
- "feature": {
7
- "dtype": "int32",
8
- "_type": "Value"
9
- },
10
- "_type": "Sequence"
11
- },
12
- "text": {
13
- "dtype": "string",
14
- "_type": "Value"
15
- }
16
- },
17
- "homepage": "",
18
- "license": ""
19
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_data/state.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "_data_files": [
3
- {
4
- "filename": "data-00000-of-00001.arrow"
5
- }
6
- ],
7
- "_fingerprint": "1e8504573fba12c8",
8
- "_format_columns": null,
9
- "_format_kwargs": {},
10
- "_format_type": null,
11
- "_output_all_columns": false,
12
- "_split": null
13
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_gradients.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:826d1b3b68147d4cfc9a67110850cb086ac93bedfaa90243747f94f79d212392
3
- size 2371527
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/learning_dynamics/train_weights.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6215aff1505b169aaef56f9bdfa7deffb3158a30fe3464e89611a89773d9b89c
3
- size 2371443
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cb92b1e7e65e8578ca33addec28caa9afecd658b8c37ead9e59e8446df236aad
3
- size 45143592
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/pico_decoder.py DELETED
@@ -1,871 +0,0 @@
1
- """
2
- Pico Decoder: A Lightweight Causal Transformer Language Model
3
-
4
- Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
-
6
- Everything is written with a modular design for easy modification and experimentation.
7
-
8
- Key features:
9
- - RMSNorm for layer normalization
10
- - Rotary Positional Embeddings (RoPE)
11
- - Multi-head attention with KV-cache support
12
- - SwiGLU activation function
13
- - Residual connections throughout
14
-
15
- - KV-cache for faster autoregressive generation
16
-
17
- References:
18
- - RoPE: https://arxiv.org/abs/2104.09864
19
- - SwiGLU: https://arxiv.org/abs/2002.05202
20
- - LLAMA: https://arxiv.org/abs/2302.13971
21
-
22
- Adapted from:
23
- - OLMO: https://github.com/allenai/OLMo
24
- - LLAMA: https://github.com/meta/llama
25
- """
26
-
27
- from dataclasses import asdict
28
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
-
30
- import torch
31
- import torch.nn as nn
32
- import torch.nn.functional as F
33
- from torch.nn.attention import SDPBackend, sdpa_kernel
34
- from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
35
- from transformers.generation import GenerationConfig
36
- from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
37
-
38
- try:
39
- if TYPE_CHECKING:
40
- # We need to do this to avoid importing these when creating the HF-compatible models
41
- from src.config import ModelConfig
42
- except ImportError:
43
- pass
44
-
45
- ########################################################
46
- #
47
- # Layer Normalization
48
- #
49
- ########################################################
50
-
51
-
52
- class RMSNorm(torch.nn.Module):
53
- """Root Mean Square Layer Normalization.
54
-
55
- A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
56
- resulting in improved stability and performance.
57
-
58
- Args:
59
- config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
60
- - config.norm_eps: Small constant for numerical stability
61
- - config.d_model: Model dimension for the weight parameter
62
-
63
- References:
64
- https://arxiv.org/abs/1910.07467
65
- """
66
-
67
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
68
- super().__init__()
69
- self.eps = config.norm_eps
70
- self.weight = nn.Parameter(torch.ones(config.d_model))
71
-
72
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
73
- """
74
- Normalizes the input tensor by its RMS value.
75
- """
76
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
77
-
78
- def forward(self, x: torch.Tensor) -> torch.Tensor:
79
- """
80
- Applies RMS normalization to the input tensor and scales it by the weight parameter.
81
- """
82
- output = self._norm(x.float()).type_as(x)
83
- return output * self.weight
84
-
85
-
86
- ########################################################
87
- #
88
- # Positional Embedding
89
- #
90
- ########################################################
91
-
92
-
93
- class RoPE(nn.Module):
94
- """Rotary Positional Embeddings (RoPE).
95
-
96
- Implements position-dependent rotation of keys and queries in attention mechanism,
97
- allowing better modeling of relative positions in sequences. Uses complex number
98
- operations for efficient rotation.
99
-
100
- Args:
101
- config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
102
- - config.position_emb_theta: Base for frequency computation
103
- - config.d_model: Model dimension
104
- - config.attention_n_heads: Number of attention heads
105
- - config.max_seq_len: Maximum sequence length
106
-
107
- References:
108
- https://arxiv.org/abs/2104.09864
109
- """
110
-
111
- _freqs_cis_tensor: torch.Tensor | None = None
112
-
113
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
114
- super().__init__()
115
-
116
- self.theta = config.position_emb_theta
117
- self.dim = config.d_model // config.attention_n_heads
118
-
119
- max_seq_len = config.max_seq_len
120
-
121
- # only gets set once, and then reused for all RoPE instances
122
- if RoPE._freqs_cis_tensor is None:
123
- RoPE._freqs_cis_tensor = self._setup_freqs_cis(
124
- max_seq_len, self.theta, self.dim
125
- )
126
-
127
- # register _freqs_cis buffer
128
- # can be easily recomputed so persistent=False
129
- self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
130
-
131
- @classmethod
132
- def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
133
- """Setup Frequency Tensor for RoPE Embeddings
134
-
135
- Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
136
-
137
- Note other implementations will use cos and sin directly, but using the complex
138
- number representation is (probably) more efficient:
139
-
140
- e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
141
- """
142
- _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
143
- positions = torch.arange(seq_len)
144
- freqs = torch.outer(positions, _freqs)
145
- return torch.polar(torch.ones_like(freqs), freqs) # complex64
146
-
147
- def get_freqs_cis(
148
- self, input_shape: torch.Size, start_pos: int, end_pos: int
149
- ) -> torch.Tensor:
150
- """Reshape Frequency Tensor for RoPE Embeddings
151
-
152
- Makes the frequency tensor broadcastable with the input tensor.
153
- """
154
- _freqs_cis = self._freqs_cis[start_pos:end_pos]
155
- ndim = len(input_shape)
156
- assert 0 <= 1 < ndim
157
- assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
158
-
159
- # TODO: Check whether this is correct (might be able to remove this)
160
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
161
- return _freqs_cis.view(*shape)
162
-
163
- def forward(
164
- self,
165
- queries: torch.Tensor,
166
- keys: torch.Tensor,
167
- start_pos: int = 0,
168
- ) -> Tuple[torch.Tensor, torch.Tensor]:
169
- """Apply RoPE Embeddings to Queries and Keys
170
-
171
- Applies the rotary positional embeddings to the input tensors via complex num multiplication
172
-
173
- NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
174
- """
175
- queries_ = torch.view_as_complex(
176
- queries.float().reshape(*queries.shape[:-1], -1, 2)
177
- )
178
- keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
179
-
180
- input_shape = (
181
- queries_.shape
182
- ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
183
- freqs_start_pos = start_pos
184
- freqs_end_pos = freqs_start_pos + queries_.shape[1]
185
-
186
- freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
187
-
188
- queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
189
- keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
190
- return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
191
-
192
-
193
- ########################################################
194
- #
195
- # Attention
196
- #
197
- ########################################################
198
-
199
-
200
- class Attention(nn.Module):
201
- """Multi-head Attention with Group Query Attention support.
202
-
203
- Implements scaled dot-product attention and supports:
204
- - Grouped Query Attention (GQA)
205
- - Key-Value caching for efficient inference
206
- - RoPE integration
207
-
208
- Args:
209
- config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
210
- - config.attention_n_heads: Number of attention heads
211
- - config.attention_n_kv_heads: Number of key/value heads
212
- - config.d_model: Model dimension
213
- - config.batch_size: Maximum batch size
214
- - config.max_seq_len: Maximum sequence length
215
-
216
- Shape:
217
- - Input: (batch_size, seq_len, d_model)
218
- - Output: (batch_size, seq_len, d_model)
219
- """
220
-
221
- def __init__(
222
- self,
223
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
224
- ):
225
- super().__init__()
226
-
227
- self.n_heads = config.attention_n_heads
228
- self.n_kv_heads = config.attention_n_kv_heads
229
-
230
- self.batch_size = config.batch_size
231
- self.max_seq_len = config.max_seq_len
232
-
233
- d_model = config.d_model
234
- self.head_dim = d_model // self.n_heads
235
-
236
- self.n_rep = self.n_heads // self.n_kv_heads
237
-
238
- self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
239
- self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
240
- self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
241
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
242
-
243
- self.rope = RoPE(config)
244
-
245
- def forward(
246
- self,
247
- input: torch.Tensor,
248
- mask: Optional[torch.Tensor] = None,
249
- past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
250
- use_cache: bool = False,
251
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
252
- """Forward pass for the attention mechanism.
253
-
254
- Computes queries, keys, and values for the attention mechanism. Applies rotary positional
255
- embeddings to the queries and keys, and then computes attention scores and outputs.
256
-
257
- For an introduction to the attention mechanism, see:
258
- https://arxiv.org/abs/1706.03762
259
-
260
- A few things to note:
261
- - The past_key_values is used to implement the KV cache, which is used to speed up
262
- generation by caching the KV pairs from previous forward passes. This is useful when doing
263
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
264
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
265
- its own KV cache - this KV cache is implemented as a tuple.
266
- """
267
- bsz, seq_len, _ = input.shape
268
- _queries, _keys, _values = (
269
- self.q_proj(input),
270
- self.k_proj(input),
271
- self.v_proj(input),
272
- )
273
-
274
- # Reshaping for multi-head attention
275
- queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
276
- keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
277
- values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
278
-
279
- # The start position is used to apply the RoPE embeddings to only the new tokens
280
- # when using the kv_cache in the attention mechanism.
281
- # We want to start from the last position in the cache.
282
- start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0
283
-
284
- # apply rotary positional embeddings
285
- queries, keys = self.rope(queries, keys, start_pos)
286
-
287
- if past_key_values is not None:
288
- keys = torch.cat([past_key_values[0], keys], dim=1)
289
- values = torch.cat([past_key_values[1], values], dim=1)
290
-
291
- if use_cache:
292
- cached_keys = keys
293
- cached_values = values
294
- else:
295
- cached_keys = None
296
- cached_values = None
297
-
298
- queries = queries.transpose(1, 2)
299
- keys = keys.transpose(1, 2)
300
- values = values.transpose(1, 2)
301
-
302
- apply_gqa = self.n_rep > 1
303
- if apply_gqa and queries.device.type == "mps":
304
- # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
305
- # outside of the kernel to get the same effect.
306
- # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
307
- keys = keys.repeat_interleave(self.n_rep, dim=-3)
308
- values = values.repeat_interleave(self.n_rep, dim=-3)
309
- apply_gqa = False
310
-
311
- backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
312
-
313
- with sdpa_kernel(backends=backends):
314
- attn_output = F.scaled_dot_product_attention(
315
- queries.contiguous(),
316
- keys.contiguous(),
317
- values.contiguous(),
318
- attn_mask=mask.to(queries.dtype) if mask is not None else None,
319
- enable_gqa=apply_gqa,
320
- )
321
-
322
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
323
- output = self.o_proj(attn_output)
324
-
325
- return output, (cached_keys, cached_values)
326
-
327
-
328
- ########################################################
329
- #
330
- # SwiGLU (Combines MLP and Activation)
331
- #
332
- ########################################################
333
-
334
-
335
- class SwiGLU(nn.Module):
336
- """SwiGLU Activation Function with Linear Projections.
337
-
338
- Implements the SwiGLU activation function combined with linear transformations,
339
- serving as the feed-forward network in transformer blocks.
340
-
341
- Args:
342
- config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
343
- - config.d_model: Model dimension
344
- - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
345
-
346
- References:
347
- https://arxiv.org/abs/2002.05202
348
- """
349
-
350
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
351
- super().__init__()
352
-
353
- model_dim = config.d_model
354
- act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
355
-
356
- self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
357
- self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
358
- self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
359
-
360
- def forward(self, x: torch.Tensor) -> torch.Tensor:
361
- return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
362
-
363
-
364
- ########################################################
365
- #
366
- # PicoDecoderBlock
367
- #
368
- ########################################################
369
-
370
-
371
- class PicoDecoderBlock(nn.Module):
372
- """Single Transformer Block with Attention and Feed-forward layers.
373
-
374
- Implements a standard transformer block with:
375
- - Multi-head attention with normalization and residual connection
376
- - SwiGLU feed-forward network with normalization and residual connection
377
-
378
- Args:
379
- config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
380
- a HuggingFace PicoDecoderHFConfig
381
- """
382
-
383
- def __init__(
384
- self,
385
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
386
- ):
387
- super().__init__()
388
-
389
- self.attention = Attention(config)
390
- self.swiglu = SwiGLU(config)
391
- self.attention_norm = RMSNorm(config)
392
- self.swiglu_norm = RMSNorm(config)
393
-
394
- def forward(
395
- self,
396
- input: torch.Tensor,
397
- mask: Optional[torch.Tensor] = None,
398
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
399
- use_cache: bool = False,
400
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
401
- attention_output, cached_key_values = self.attention(
402
- self.attention_norm(input),
403
- mask=mask,
404
- past_key_values=past_key_values,
405
- use_cache=use_cache,
406
- )
407
- # NOTE: cached_key_values is None if use_cache is False
408
-
409
- h = input + attention_output
410
- out = h + self.swiglu(self.swiglu_norm(h))
411
- return out, cached_key_values
412
-
413
-
414
- ########################################################
415
- #
416
- # Pico Decoder (Causal Transformer Model)
417
- #
418
- ########################################################
419
-
420
-
421
- class PicoDecoder(nn.Module):
422
- """
423
- Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
424
- single autoregressive model.
425
-
426
- For more information on the model, see the classes for the modules that make up the model.
427
- """
428
-
429
- def __init__(
430
- self,
431
- model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
432
- ):
433
- super().__init__()
434
- self.config = model_config
435
-
436
- self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
437
- self.layers = nn.ModuleList(
438
- [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
439
- )
440
- self.output_norm = RMSNorm(self.config)
441
- self.de_embedding_proj = nn.Linear(
442
- self.config.d_model, self.config.vocab_size, bias=False
443
- )
444
-
445
- def convert_to_hf_model(self) -> "PicoDecoderHF":
446
- """Convert the Lightning model to a HuggingFace model."""
447
- # Create HF config without fabric-specific settings
448
- hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
449
-
450
- # Create new HF model
451
- hf_model = PicoDecoderHF(hf_config)
452
-
453
- # Copy state dict, excluding fabric-specific keys
454
-         hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
-
-         return hf_model
-
-     def forward(
-         self,
-         input_ids: torch.Tensor,
-         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-         use_cache: bool = False,
-     ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
-         """
-         This is the forward pass for the entire Pico model. It boils down to:
-         - Embedding the input ids
-         - Creating a causal mask
-         - Processing through the pico layers
-         - Projecting the output to logits
-
-         NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
-         generation by caching the KV pairs from previous forward passes. This is useful when doing
-         tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
-         modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
-         its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
-         KV caches (so a tuple of tuples).
-         """
-
-         seq_len = input_ids.shape[-1]
-         h = self.embedding_proj(input_ids)
-
-         # Calculate start position from past cached KV pairs. Remember that each layer has its
-         # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
-         # correct layer and then for either the keys or values.
-         start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]
-
-         # Create causal mask for current sequence
-         mask = None
-         if seq_len > 1:
-             mask = torch.full((seq_len, seq_len), float("-inf"))
-             mask = torch.triu(mask, diagonal=1)
-
-             # If using KV cache, extend mask to cover cached sequence length
-             if past_key_values is not None:
-                 # Add zeros for cached tokens (we can attend to all of them)
-                 mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
-
-             mask = mask.to(h.device)
-
-         # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
-         # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
-         cached_key_values = () if use_cache else None
-
-         # Process through transformer blocks
-         for idx, layer in enumerate(self.layers):
-             layer_past_key_values = (
-                 past_key_values[idx] if past_key_values is not None else None
-             )
-
-             h, layer_cached_key_values = layer(
-                 h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
-             )
-
-             if use_cache:
-                 cached_key_values += (layer_cached_key_values,)
-
-         # Final norm and projection
-         h = self.output_norm(h)
-         logits = self.de_embedding_proj(h).float()
-
-         return logits, cached_key_values
-
-
- ########################################################
- #
- # HuggingFace Wrapper for the Pico Decoder model.
- #
- ########################################################
-
-
- class PicoDecoderHFConfig(PretrainedConfig):
-     """Config class for the Pico Decoder HuggingFace wrapper."""
-
-     model_type = "pico_decoder"
-
-     @classmethod
-     def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
-         """
-         Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
-         this is because with some kwargs special handling is required and can make this class
-         brittle.
-         """
-         pico_config = cls(**config_dict)
-
-         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-         unused_kwargs = {
-             key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
-         }
-
-         if return_unused_kwargs:
-             return pico_config, unused_kwargs
-         return pico_config
-
-     @classmethod
-     def from_dataclass(cls, model_config: "ModelConfig"):
-         """Initialise from our custom config dataclass."""
-         return cls.from_dict(asdict(model_config))
-
-
- class PicoDecoderHF(PreTrainedModel, GenerationMixin):
-     """
-     HuggingFace wrapper for the Pico model with generation support.
-
-     Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
-     wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
-     Pico model as well as the model wrapped in this HuggingFace class.
-
-     This also lets you do cool things like:
-
-     `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
-     """
-
-     config_class = PicoDecoderHFConfig
-     _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
-     main_input_name = "input_ids"
-
-     def __init__(self, config: PicoDecoderHFConfig):
-         super().__init__(config)
-         self.pico_decoder = PicoDecoder(config)
-         # Initialize generation config with defaults
-         self.generation_config = GenerationConfig()
-         # Set some reasonable defaults for the model
-         if hasattr(config, "max_position_embeddings"):
-             self.generation_config.max_length = config.max_position_embeddings
-         if hasattr(config, "vocab_size"):
-             self.generation_config.vocab_size = config.vocab_size
-
-     def forward(
-         self,
-         input_ids: torch.Tensor,
-         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-         use_cache: bool = False,
-         **kwargs,
-     ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
-         """HuggingFace forward pass wrapper.
-
-         Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
-         Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
-         """
-         logits, past_key_values = self.pico_decoder(
-             input_ids, past_key_values, use_cache
-         )
-         if use_cache:
-             return CausalLMOutputWithPast(
-                 logits=logits,
-                 past_key_values=past_key_values,
-             )
-         else:
-             return CausalLMOutput(
-                 logits=logits,
-             )
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-         attention_mask: Optional[torch.LongTensor] = None,
-         **kwargs,
-     ) -> Dict[str, Any]:
-         """
-         Prepare inputs for generation.
-
-         Args:
-             input_ids: Input token IDs
-             past_key_values: Cached key-value pairs from previous forward passes
-             attention_mask: Attention mask for the input
-             **kwargs: Additional arguments
-
-         Returns:
-             Dictionary containing prepared inputs
-         """
-         # If we have past_key_values, we only need the last token
-         if past_key_values is not None:
-             input_ids = input_ids[:, -1:]
-
-         return {
-             "input_ids": input_ids,
-             "past_key_values": past_key_values,
-             "use_cache": True,
-         }
-
-     def get_input_embeddings(self):
-         """Get the input embeddings layer."""
-         return self.pico_decoder.embedding_proj
-
-     def set_input_embeddings(self, value):
-         """Set the input embeddings layer."""
-         self.pico_decoder.embedding_proj = value
-
-     def get_output_embeddings(self):
-         """Get the output embeddings layer."""
-         return self.pico_decoder.de_embedding_proj
-
-     def set_output_embeddings(self, value):
-         """Set the output embeddings layer."""
-         self.pico_decoder.de_embedding_proj = value
-
-     def get_lm_head(self):
-         """Get the language model head."""
-         return self.pico_decoder.de_embedding_proj
-
-     def can_generate(self) -> bool:
-         """Check if the model can generate text."""
-         return True
-
-     @property
-     def is_encoder_decoder(self) -> bool:
-         """Check if the model is an encoder-decoder model."""
-         return False
-
-     @property
-     def can_use_cache(self) -> bool:
-         """Check if the model can use KV cache."""
-         return True
-
-     def resize_token_embeddings(
-         self, new_num_tokens: Optional[int] = None
-     ) -> torch.nn.Embedding:
-         """Resize token embeddings."""
-         old_embeddings = self.get_input_embeddings()
-         if new_num_tokens is None:
-             new_num_tokens = old_embeddings.num_embeddings
-
-         new_embeddings = torch.nn.Embedding(
-             new_num_tokens, old_embeddings.embedding_dim
-         )
-         new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
-             old_embeddings.weight.data
-         )
-
-         self.pico_decoder.embedding_proj = new_embeddings
-         self.pico_decoder.de_embedding_proj = torch.nn.Linear(
-             old_embeddings.embedding_dim, new_num_tokens, bias=False
-         )
-
-         return new_embeddings
-
-
- # Register for auto classes
- PicoDecoderHFConfig.register_for_auto_class()
- PicoDecoderHF.register_for_auto_class("AutoModel")
- PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
-
-
- ########################################################
- #
- # New PicoDecoderForCausalLM class for generation support
- #
- ########################################################
-
-
- class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
-     """
-     PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
-
-     This class is designed to work with existing checkpoints and provides full generation support.
-     It inherits from the right base classes that HuggingFace expects for text generation.
-     """
-
-     config_class = PicoDecoderHFConfig
-     _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
-     main_input_name = "input_ids"
-
-     def __init__(self, config: PicoDecoderHFConfig):
-         super().__init__(config)
-         self.pico_decoder = PicoDecoder(config)
-         # Initialize generation config with defaults
-         self.generation_config = GenerationConfig()
-         # Set some reasonable defaults for the model
-         if hasattr(config, "max_position_embeddings"):
-             self.generation_config.max_length = config.max_position_embeddings
-         if hasattr(config, "vocab_size"):
-             self.generation_config.vocab_size = config.vocab_size
-
-     def forward(
-         self,
-         input_ids: torch.Tensor,
-         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-         use_cache: bool = False,
-         **kwargs,
-     ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
-         """Forward pass for text generation."""
-         logits, past_key_values = self.pico_decoder(
-             input_ids, past_key_values, use_cache
-         )
-         if use_cache:
-             return CausalLMOutputWithPast(
-                 logits=logits,
-                 past_key_values=past_key_values,
-             )
-         else:
-             return CausalLMOutput(
-                 logits=logits,
-             )
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-         attention_mask: Optional[torch.LongTensor] = None,
-         **kwargs,
-     ) -> Dict[str, Any]:
-         """Prepare inputs for generation."""
-         # If we have past_key_values, we only need the last token
-         if past_key_values is not None:
-             input_ids = input_ids[:, -1:]
-
-         return {
-             "input_ids": input_ids,
-             "past_key_values": past_key_values,
-             "use_cache": True,
-         }
-
-     def get_input_embeddings(self):
-         """Get the input embeddings layer."""
-         return self.pico_decoder.embedding_proj
-
-     def set_input_embeddings(self, value):
-         """Set the input embeddings layer."""
-         self.pico_decoder.embedding_proj = value
-
-     def get_output_embeddings(self):
-         """Get the output embeddings layer."""
-         return self.pico_decoder.de_embedding_proj
-
-     def set_output_embeddings(self, value):
-         """Set the output embeddings layer."""
-         self.pico_decoder.de_embedding_proj = value
-
-     def get_lm_head(self):
-         """Get the language model head."""
-         return self.pico_decoder.de_embedding_proj
-
-     def can_generate(self) -> bool:
-         """Check if the model can generate text."""
-         return True
-
-     @property
-     def is_encoder_decoder(self) -> bool:
-         """Check if the model is an encoder-decoder model."""
-         return False
-
-     @property
-     def can_use_cache(self) -> bool:
-         """Check if the model can use KV cache."""
-         return True
-
-     def resize_token_embeddings(
-         self, new_num_tokens: Optional[int] = None
-     ) -> torch.nn.Embedding:
-         """Resize token embeddings."""
-         old_embeddings = self.get_input_embeddings()
-         if new_num_tokens is None:
-             new_num_tokens = old_embeddings.num_embeddings
-
-         new_embeddings = torch.nn.Embedding(
-             new_num_tokens, old_embeddings.embedding_dim
-         )
-         new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
-             old_embeddings.weight.data
-         )
-
-         self.pico_decoder.embedding_proj = new_embeddings
-         self.pico_decoder.de_embedding_proj = torch.nn.Linear(
-             old_embeddings.embedding_dim, new_num_tokens, bias=False
-         )
-
-         return new_embeddings
-
-     @classmethod
-     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-         """
-         Load a pretrained model from a checkpoint.
-
-         This method handles loading from both the old PicoDecoderHF format and the new format.
-         """
-         # First try to load with the new class
-         try:
-             return super().from_pretrained(
-                 pretrained_model_name_or_path, *model_args, **kwargs
-             )
-         except Exception as e:
-             print(f"Failed to load with new class: {e}")
-             print("Attempting to load with legacy class and convert...")
-
-             # Try to load with the old class and convert
-             try:
-                 from transformers import AutoModel
-
-                 old_model = AutoModel.from_pretrained(
-                     pretrained_model_name_or_path,
-                     trust_remote_code=True,
-                     *model_args,
-                     **kwargs,
-                 )
-
-                 # Create new model instance
-                 new_model = cls(old_model.config)
-
-                 # Copy state dict
-                 new_model.load_state_dict(old_model.state_dict(), strict=False)
-
-                 return new_model
-
-             except Exception as e2:
-                 print(f"Failed to convert from legacy format: {e2}")
-                 raise e
-
-
- # Register the new class
- PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
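Note: the deleted pico_decoder.py above is the trust_remote_code module that the checkpoint configs point to through their auto_map, and its docstring shows the intended entry point (`AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`). A minimal usage sketch is below; the checkpoint directory is a placeholder for any surviving local copy of one of the deleted step folders, since this commit removes them from the repo.

    # Sketch only: assumes a local copy of a checkpoint dir that still contains
    # pico_decoder.py, config.json, model.safetensors and the tokenizer files.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint_dir = "pico-decoder-tiny-dolma29k/checkpoints/step_3000"  # placeholder path

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)
    model.eval()

    inputs = tokenizer("The tiny decoder was trained on", return_tensors="pt")
    with torch.no_grad():
        # use_cache=True exercises the per-layer "tuple of tuples" KV cache
        # described in the PicoDecoder.forward docstring above.
        output_ids = model.generate(
            inputs["input_ids"],
            max_new_tokens=32,
            use_cache=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))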
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/special_tokens_map.json DELETED
@@ -1,16 +0,0 @@
- {
-   "eos_token": {
-     "content": "<|endoftext|>",
-     "lstrip": false,
-     "normalized": false,
-     "rstrip": false,
-     "single_word": false
-   },
-   "pad_token": {
-     "content": "<|padding|>",
-     "lstrip": false,
-     "normalized": false,
-     "rstrip": false,
-     "single_word": false
-   }
- }
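This file pins <|endoftext|> as the end-of-sequence token and <|padding|> as the padding token; in the tokenizer_config.json below they correspond to ids 50279 and 1. A minimal sketch of how evaluation code would typically pass them to generation (the path is a placeholder, not something this commit keeps):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("path/to/checkpoint")  # placeholder path
    assert tok.eos_token == "<|endoftext|>" and tok.pad_token == "<|padding|>"

    # generate() needs explicit ids when decoding batched, padded prompts
    gen_kwargs = {"eos_token_id": tok.eos_token_id, "pad_token_id": tok.pad_token_id}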
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma29k/checkpoints/step_2000/tokenizer_config.json DELETED
@@ -1,239 +0,0 @@
- {
-   "add_bos_token": false,
-   "add_eos_token": false,
-   "add_prefix_space": false,
-   "added_tokens_decoder": {
-     "0": {
-       "content": "|||IP_ADDRESS|||",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "1": {
-       "content": "<|padding|>",
-       "lstrip": false,
-       "normalized": false,
-       "rstrip": false,
-       "single_word": false,
-       "special": true
-     },
-     "50254": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50255": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50256": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50257": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50258": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50259": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50260": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50261": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50262": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50263": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50264": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50265": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50266": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50267": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50268": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50269": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50270": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50271": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50272": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50273": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50274": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50275": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50276": {
-       "content": " ",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50277": {
-       "content": "|||EMAIL_ADDRESS|||",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50278": {
-       "content": "|||PHONE_NUMBER|||",
-       "lstrip": false,
-       "normalized": true,
-       "rstrip": false,
-       "single_word": false,
-       "special": false
-     },
-     "50279": {
-       "content": "<|endoftext|>",
-       "lstrip": false,
-       "normalized": false,
-       "rstrip": false,
-       "single_word": false,
-       "special": true
-     }
-   },
-   "bos_token": null,
-   "clean_up_tokenization_spaces": true,
-   "eos_token": "<|endoftext|>",
-   "extra_special_tokens": {},
-   "model_max_length": 1000000000000000019884624838656,
-   "pad_token": "<|padding|>",
-   "tokenizer_class": "GPTNeoXTokenizer",
-   "unk_token": null
- }
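The tokenizer_config describes a GPT-NeoX-style tokenizer with added tokens for PII placeholders (|||IP_ADDRESS|||, |||EMAIL_ADDRESS|||, |||PHONE_NUMBER|||) and whitespace runs, as used for Dolma-preprocessed text. A quick round-trip sketch, assuming a local copy of the checkpoint directory (placeholder path) and that the remaining ids come from the tokenizer.json whose diff is not rendered above:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("pico-decoder-tiny-dolma29k/checkpoints/step_2000")  # placeholder

    ids = tok("Contact: |||EMAIL_ADDRESS|||")["input_ids"]
    print(ids)              # the placeholder should surface as the single added token 50277
    print(tok.decode(ids))  # round-trips back to the original string
    print(tok.convert_tokens_to_ids("<|endoftext|>"))  # 50279 per added_tokens_decoder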
 
pico-decoder-tiny-dolma29k/checkpoints/step_3000/config.json DELETED
@@ -1,22 +0,0 @@
- {
-   "activation_hidden_dim": 384,
-   "architectures": [
-     "PicoDecoderHF"
-   ],
-   "attention_n_heads": 12,
-   "attention_n_kv_heads": 4,
-   "auto_map": {
-     "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
-     "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
-   },
-   "batch_size": 1024,
-   "d_model": 96,
-   "max_seq_len": 2048,
-   "model_type": "pico_decoder",
-   "n_layers": 12,
-   "norm_eps": 1e-06,
-   "position_emb_theta": 10000.0,
-   "torch_dtype": "float32",
-   "transformers_version": "4.48.3",
-   "vocab_size": 50304
- }
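For scale, the config above implies a very small decoder. The back-of-the-envelope estimate below is a sketch only: it assumes the layer layout suggested by the config and by the module names in the deleted pico_decoder.py (grouped-query attention, SwiGLU feed-forward, untied output head), and it ignores norms and biases.

    # Rough parameter count from the deleted config.json (estimate, not authoritative).
    cfg = {
        "d_model": 96, "n_layers": 12, "vocab_size": 50304,
        "attention_n_heads": 12, "attention_n_kv_heads": 4,
        "activation_hidden_dim": 384,
    }

    d = cfg["d_model"]
    head_dim = d // cfg["attention_n_heads"]          # 8
    kv_dim = cfg["attention_n_kv_heads"] * head_dim   # 32 (grouped-query attention)

    attn = d * d + 2 * d * kv_dim + d * d             # Q, K, V, O projections
    mlp = 3 * d * cfg["activation_hidden_dim"]        # SwiGLU: gate, up, down
    per_layer = attn + mlp

    embeddings = cfg["vocab_size"] * d                # input embedding
    total = cfg["n_layers"] * per_layer + 2 * embeddings  # assumes a separate output head

    print(f"~{total / 1e6:.1f}M parameters (excluding norms and biases)")  # ~11.3M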
 
pico-decoder-tiny-dolma29k/checkpoints/step_3000/fabric_state/checkpoint.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:0668b8c42202253274a9483e30e3819cf2f434dd4f7ce8a48b873c6411572ac5
- size 135543171
 
pico-decoder-tiny-dolma29k/checkpoints/step_3000/generation_config.json DELETED
@@ -1,4 +0,0 @@
- {
-   "transformers_version": "4.48.3",
-   "vocab_size": 50304
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_activations.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:4c2ef66518e08d73be5b8c1accfbf88aa5332f0082919bc5e490cafcfdca7a05
- size 33819
 
pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_data/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:15b5c7823a80674e87e5116cd5f9c48e87478c45a391f5b730c9a5cc2886f0da
- size 66384
 
pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_data/dataset_info.json DELETED
@@ -1,19 +0,0 @@
- {
-   "citation": "",
-   "description": "",
-   "features": {
-     "input_ids": {
-       "feature": {
-         "dtype": "int32",
-         "_type": "Value"
-       },
-       "_type": "Sequence"
-     },
-     "text": {
-       "dtype": "string",
-       "_type": "Value"
-     }
-   },
-   "homepage": "",
-   "license": ""
- }
 
pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_data/state.json DELETED
@@ -1,13 +0,0 @@
- {
-   "_data_files": [
-     {
-       "filename": "data-00000-of-00001.arrow"
-     }
-   ],
-   "_fingerprint": "1931fa1b9cbde22e",
-   "_format_columns": null,
-   "_format_kwargs": {},
-   "_format_type": null,
-   "_output_all_columns": false,
-   "_split": null
- }
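The train_data directory (data-00000-of-00001.arrow plus dataset_info.json and this state.json) is a Hugging Face `datasets` dataset saved to disk, holding the batch captured for the learning-dynamics snapshot, with the input_ids and text columns declared above. A sketch of inspecting it, assuming a surviving local copy (the path is a placeholder):

    from datasets import load_from_disk

    ds = load_from_disk(
        "pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_data"  # placeholder
    )
    print(ds)  # features: input_ids (sequence of int32), text (string)
    print(len(ds[0]["input_ids"]), ds[0]["text"][:80])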
 
pico-decoder-tiny-dolma29k/checkpoints/step_3000/learning_dynamics/train_gradients.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:afbcb41dd16e3fb023b33225165ce4c7690b80fbece730d1a254cf0d87b366c8
- size 2371527
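The train_activations.pt, train_gradients.pt and train_weights.pt files are Git LFS-tracked torch files; the three-line stubs in this diff are LFS pointers, not the tensors themselves. Their internal layout is not documented in this diff, so the sketch below only assumes a downloaded local copy that torch can load, and prints whatever structure it finds:

    import torch

    state = torch.load("learning_dynamics/train_gradients.pt", map_location="cpu")  # placeholder path
    if isinstance(state, dict):
        for name, value in state.items():
            print(name, getattr(value, "shape", type(value)))
    else:
        print(type(state))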