ThomasTheMaker committed
Commit a1a7208 · verified · 1 Parent(s): 20387e9

Upload folder using huggingface_hub

Files changed (32)
  1. .gitattributes +10 -0
  2. pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/fabric_state/checkpoint.pt +1 -1
  3. pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_activations.pt +0 -0
  4. pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  5. pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_data/dataset_info.json +19 -0
  6. pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_data/state.json +13 -0
  7. pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_gradients.pt +3 -0
  8. pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_weights.pt +3 -0
  9. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/config.json +22 -0
  10. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/fabric_state/checkpoint.pt +3 -0
  11. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/generation_config.json +4 -0
  12. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_activations.pt +0 -0
  13. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  14. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_data/dataset_info.json +19 -0
  15. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_data/state.json +13 -0
  16. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_gradients.pt +3 -0
  17. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_weights.pt +3 -0
  18. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/model.safetensors +3 -0
  19. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/pico_decoder.py +911 -0
  20. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/special_tokens_map.json +16 -0
  21. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/tokenizer.json +0 -0
  22. pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/tokenizer_config.json +239 -0
  23. pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/config.json +22 -0
  24. pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/fabric_state/checkpoint.pt +3 -0
  25. pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/generation_config.json +4 -0
  26. pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/model.safetensors +3 -0
  27. pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/pico_decoder.py +911 -0
  28. pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/special_tokens_map.json +16 -0
  29. pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/tokenizer.json +0 -0
  30. pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/tokenizer_config.json +239 -0
  31. pico-decoder-tiny-dolma250M-v1/eval_results/step_102000.json +1 -0
  32. pico-decoder-tiny-dolma250M-v1/logs/log_20250831_162326.log +269 -0
.gitattributes CHANGED
@@ -982,3 +982,13 @@ pico-decoder-tiny-dolma250M-v1/checkpoints/step_98000/learning_dynamics/train_da
  pico-decoder-tiny-dolma250M-v1/checkpoints/step_98000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
  pico-decoder-tiny-dolma250M-v1/checkpoints/step_98000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
  pico-decoder-tiny-dolma250M-v1/checkpoints/step_98000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_data/data-00000-of-00001.arrow filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_gradients.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_weights.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/model.safetensors filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/fabric_state/checkpoint.pt filter=lfs diff=lfs merge=lfs -text
+ pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/model.safetensors filter=lfs diff=lfs merge=lfs -text
pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/fabric_state/checkpoint.pt CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9a8e38fba8b39bdce89550461657b3d2c12715e0529740b50487e5b382b7e31b
+ oid sha256:7821832b1e1fcd6692940396e6edc50fce443f493f232c38e245da9561676d9a
  size 135543171
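
An LFS pointer like the one above records only the object's SHA-256 digest and byte size; the checkpoint itself lives in LFS storage. A minimal sketch (not part of the commit; the local path is hypothetical) for checking a downloaded file against the new pointer:

import hashlib
from pathlib import Path

# Hypothetical local copy of the LFS object referenced by the pointer above.
local_file = Path("pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/fabric_state/checkpoint.pt")

expected_oid = "7821832b1e1fcd6692940396e6edc50fce443f493f232c38e245da9561676d9a"  # oid from the new pointer
expected_size = 135543171  # size from the pointer

data = local_file.read_bytes()
assert len(data) == expected_size, "size does not match the LFS pointer"
assert hashlib.sha256(data).hexdigest() == expected_oid, "sha256 does not match the LFS pointer"
print("local file matches the LFS pointer")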
pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_activations.pt ADDED
Binary file (98.3 kB).
 
pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c563e1b5b21a23ec6c9e50ea1a3ff547984bf10de5016d87b310deb6c2d7b333
+ size 276480
pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids": {
+       "feature": {
+         "dtype": "int32",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     },
+     "text": {
+       "dtype": "string",
+       "_type": "Value"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "0f66378a2401a0b7",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_gradients.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:852bbfeb8c97f09cadb3d7da918a6e72dd2810785847b156b295731b5233a900
+ size 2371527
pico-decoder-tiny-dolma250M-v1/checkpoints/step_100000/learning_dynamics/train_weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca1c52d716f2f6090ed01b61b7d81a6b890637239c04ca310e52d23845963cbc
+ size 2371443
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "activation_hidden_dim": 384,
+   "architectures": [
+     "PicoDecoderHF"
+   ],
+   "attention_n_heads": 12,
+   "attention_n_kv_heads": 4,
+   "auto_map": {
+     "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
+     "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
+   },
+   "batch_size": 1024,
+   "d_model": 96,
+   "max_seq_len": 2048,
+   "model_type": "pico_decoder",
+   "n_layers": 12,
+   "norm_eps": 1e-06,
+   "position_emb_theta": 10000.0,
+   "torch_dtype": "float32",
+   "transformers_version": "4.48.3",
+   "vocab_size": 50304
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e03faf348f9fe044370ed46fe8fb6d144ea32db9df87b30d97884ad379bd07db
+ size 135543171
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "transformers_version": "4.48.3",
+   "vocab_size": 50304
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_activations.pt ADDED
Binary file (98.3 kB).
 
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:257002d56c53426ecb933f4a65f1b66cb0f13a179a31844da13d1792ffab80c9
+ size 278184
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids": {
+       "feature": {
+         "dtype": "int32",
+         "_type": "Value"
+       },
+       "_type": "Sequence"
+     },
+     "text": {
+       "dtype": "string",
+       "_type": "Value"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "f1628a2d831f3cda",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_gradients.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7266fedcfa4f2a3054badf1b4378fbbef9dac915dd694f27b2f6458ea363ac0c
+ size 2371527
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/learning_dynamics/train_weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e3dab45ffb82880e9473f5b219a8fd7a1fbe79bcce195839f43072a00807e98b
+ size 2371443
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99bb8d1781b101d5c6271107090b8d643ebf4c7cfc56d2b92e8a7f50902f916a
+ size 45143592
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex num multiplication
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy state dict, excluding fabric-specific keys
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
562
+
563
+
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because with some kwargs special handling is required and can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
638
+ Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
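
The config.json files in this commit map AutoConfig and AutoModelForCausalLM to the custom classes in pico_decoder.py above, and the file's docstring notes that checkpoints can be loaded through AutoModelForCausalLM. A minimal usage sketch (not part of the commit; the local path is hypothetical, and trust_remote_code is assumed to be acceptable for this repository):

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint_dir = "pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000"  # hypothetical local path

# trust_remote_code lets transformers import the PicoDecoderHF classes named in auto_map.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))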
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|padding|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma250M-v1/checkpoints/step_102000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "activation_hidden_dim": 384,
+   "architectures": [
+     "PicoDecoderHF"
+   ],
+   "attention_n_heads": 12,
+   "attention_n_kv_heads": 4,
+   "auto_map": {
+     "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
+     "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
+   },
+   "batch_size": 1024,
+   "d_model": 96,
+   "max_seq_len": 2048,
+   "model_type": "pico_decoder",
+   "n_layers": 12,
+   "norm_eps": 1e-06,
+   "position_emb_theta": 10000.0,
+   "torch_dtype": "float32",
+   "transformers_version": "4.48.3",
+   "vocab_size": 50304
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f39d8040ef9e3d9b47f2a819749ffc0c39ad0a6aa2c965cd8d6106726b2e55d6
+ size 135543171
pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "transformers_version": "4.48.3",
+   "vocab_size": 50304
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d2c1947251972d2ba272b673761ef16d72a8f30e3967f821245685a51c8347c
+ size 45143592
pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex num multiplication
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy state dict, excluding fabric-specific keys
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
562
+
563
+
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because with some kwargs special handling is required and can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
638
+ Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
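Because pico_decoder.py registers its config and model classes for the HuggingFace auto classes, a checkpoint directory from this repo can be loaded directly with AutoModelForCausalLM, as the module docstring suggests. The following is a minimal usage sketch, not part of the uploaded files: the checkpoint path is illustrative, and trust_remote_code=True is assumed to be required because the model code ships inside the checkpoint rather than in transformers.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative path: any step_* checkpoint directory from this repo should work.
    ckpt = "pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000"

    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    # generate() drives prepare_inputs_for_generation, which enables the
    # per-layer KV cache described in the module docstring above.
    outputs = model.generate(inputs["input_ids"], max_new_tokens=20)
    print(tokenizer.decode(outputs[0]))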
pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma250M-v1/checkpoints/step_104000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
pico-decoder-tiny-dolma250M-v1/eval_results/step_102000.json ADDED
@@ -0,0 +1 @@
 
+ {"paloma": Infinity}
pico-decoder-tiny-dolma250M-v1/logs/log_20250831_162326.log ADDED
@@ -0,0 +1,269 @@
1
+ 2025-08-31 17:03:52 - pico-train - INFO - Step 100000 -- 📊 Evaluation Results
2
+ 2025-08-31 17:03:52 - pico-train - INFO - └── paloma: inf
3
+ 2025-08-31 17:03:52 - pico-train - INFO - ==================================================
4
+ 2025-08-31 17:03:52 - pico-train - INFO - ✨ Training Configuration
5
+ 2025-08-31 17:03:52 - pico-train - INFO - ==================================================
6
+ 2025-08-31 17:03:52 - pico-train - INFO - ╭─────────────────────────────────────────────────────╮
7
+ 2025-08-31 17:03:52 - pico-train - INFO - │ checkpointing: │
8
+ 2025-08-31 17:03:52 - pico-train - INFO - │ checkpoints_dir: checkpoints │
9
+ 2025-08-31 17:03:52 - pico-train - INFO - │ evaluation: │
10
+ 2025-08-31 17:03:52 - pico-train - INFO - │ eval_results_dir: eval_results │
11
+ 2025-08-31 17:03:52 - pico-train - INFO - │ fabric_checkpoint_dir: fabric_state │
12
+ 2025-08-31 17:03:52 - pico-train - INFO - │ fabric_checkpoint_filename: checkpoint.pt │
13
+ 2025-08-31 17:03:52 - pico-train - INFO - │ hf_checkpoint: │
14
+ 2025-08-31 17:03:52 - pico-train - INFO - │ collection_slug: null │
15
+ 2025-08-31 17:03:52 - pico-train - INFO - │ repo_id: ThomasTheMaker/pico-decoder-tiny │
16
+ 2025-08-31 17:03:52 - pico-train - INFO - │ learning_dynamics: │
17
+ 2025-08-31 17:03:52 - pico-train - INFO - │ batch_size: 1 │
18
+ 2025-08-31 17:03:52 - pico-train - INFO - │ eval_data: null │
19
+ 2025-08-31 17:03:52 - pico-train - INFO - │ layer_suffixes: │
20
+ 2025-08-31 17:03:52 - pico-train - INFO - │ - attention.v_proj │
21
+ 2025-08-31 17:03:52 - pico-train - INFO - │ - attention.o_proj │
22
+ 2025-08-31 17:03:52 - pico-train - INFO - │ - swiglu.w_2 │
23
+ 2025-08-31 17:03:52 - pico-train - INFO - │ sequence_idx: -1 │
24
+ 2025-08-31 17:03:52 - pico-train - INFO - │ learning_dynamics_dir: learning_dynamics │
25
+ 2025-08-31 17:03:52 - pico-train - INFO - │ logs_dir: logs │
26
+ 2025-08-31 17:03:52 - pico-train - INFO - │ run_name: pico-decoder-tiny-dolma250M-v1 │
27
+ 2025-08-31 17:03:52 - pico-train - INFO - │ runs_dir: runs │
28
+ 2025-08-31 17:03:52 - pico-train - INFO - │ save_every_n_steps: 2000 │
29
+ 2025-08-31 17:03:52 - pico-train - INFO - │ save_to_hf: false │
30
+ 2025-08-31 17:03:52 - pico-train - INFO - │ training: │
31
+ 2025-08-31 17:03:52 - pico-train - INFO - │ auto_resume: true │
32
+ 2025-08-31 17:03:52 - pico-train - INFO - │ data: │
33
+ 2025-08-31 17:03:52 - pico-train - INFO - │ dataloader: │
34
+ 2025-08-31 17:03:52 - pico-train - INFO - │ batch_size: 16 │
35
+ 2025-08-31 17:03:52 - pico-train - INFO - │ dataset: │
36
+ 2025-08-31 17:03:52 - pico-train - INFO - │ name: pico-lm/pretokenized-dolma │
37
+ 2025-08-31 17:03:52 - pico-train - INFO - │ tokenizer: │
38
+ 2025-08-31 17:03:52 - pico-train - INFO - │ name: allenai/OLMo-7B-0724-hf │
39
+ 2025-08-31 17:03:52 - pico-train - INFO - │ vocab_size: 50304 │
40
+ 2025-08-31 17:03:52 - pico-train - INFO - │ evaluation: │
41
+ 2025-08-31 17:03:52 - pico-train - INFO - │ metrics: │
42
+ 2025-08-31 17:03:52 - pico-train - INFO - │ - paloma │
43
+ 2025-08-31 17:03:52 - pico-train - INFO - │ paloma: │
44
+ 2025-08-31 17:03:52 - pico-train - INFO - │ batch_size: 1 │
45
+ 2025-08-31 17:03:52 - pico-train - INFO - │ dataset_name: pico-lm/pretokenized-paloma-tinsy │
46
+ 2025-08-31 17:03:52 - pico-train - INFO - │ dataset_split: val │
47
+ 2025-08-31 17:03:52 - pico-train - INFO - │ max_length: 2048 │
48
+ 2025-08-31 17:03:52 - pico-train - INFO - │ model: │
49
+ 2025-08-31 17:03:52 - pico-train - INFO - │ activation_hidden_dim: 384 │
50
+ 2025-08-31 17:03:52 - pico-train - INFO - │ attention_n_heads: 12 │
51
+ 2025-08-31 17:03:52 - pico-train - INFO - │ attention_n_kv_heads: 4 │
52
+ 2025-08-31 17:03:52 - pico-train - INFO - │ batch_size: 1024 │
53
+ 2025-08-31 17:03:52 - pico-train - INFO - │ d_model: 96 │
54
+ 2025-08-31 17:03:52 - pico-train - INFO - │ max_seq_len: 2048 │
55
+ 2025-08-31 17:03:52 - pico-train - INFO - │ model_type: pico_decoder │
56
+ 2025-08-31 17:03:52 - pico-train - INFO - │ n_layers: 12 │
57
+ 2025-08-31 17:03:52 - pico-train - INFO - │ norm_eps: 1.0e-06 │
58
+ 2025-08-31 17:03:52 - pico-train - INFO - │ position_emb_theta: 10000.0 │
59
+ 2025-08-31 17:03:52 - pico-train - INFO - │ vocab_size: 50304 │
60
+ 2025-08-31 17:03:52 - pico-train - INFO - │ monitoring: │
61
+ 2025-08-31 17:03:52 - pico-train - INFO - │ logging: │
62
+ 2025-08-31 17:03:52 - pico-train - INFO - │ log_every_n_steps: 100 │
63
+ 2025-08-31 17:03:52 - pico-train - INFO - │ log_level: INFO │
64
+ 2025-08-31 17:03:52 - pico-train - INFO - │ save_to_wandb: false │
65
+ 2025-08-31 17:03:52 - pico-train - INFO - │ wandb: │
66
+ 2025-08-31 17:03:52 - pico-train - INFO - │ entity: boymyc │
67
+ 2025-08-31 17:03:52 - pico-train - INFO - │ project: pico-decoder-tiny │
68
+ 2025-08-31 17:03:52 - pico-train - INFO - │ training: │
69
+ 2025-08-31 17:03:52 - pico-train - INFO - │ fabric: │
70
+ 2025-08-31 17:03:52 - pico-train - INFO - │ accelerator: cuda │
71
+ 2025-08-31 17:03:52 - pico-train - INFO - │ num_devices: 1 │
72
+ 2025-08-31 17:03:52 - pico-train - INFO - │ num_nodes: 1 │
73
+ 2025-08-31 17:03:52 - pico-train - INFO - │ precision: bf16-mixed │
74
+ 2025-08-31 17:03:52 - pico-train - INFO - │ max_steps: 100000 │
75
+ 2025-08-31 17:03:52 - pico-train - INFO - │ optimization: │
76
+ 2025-08-31 17:03:52 - pico-train - INFO - │ gradient_accumulation_steps: 1 │
77
+ 2025-08-31 17:03:52 - pico-train - INFO - │ lr: 0.0002 │
78
+ 2025-08-31 17:03:52 - pico-train - INFO - │ lr_scheduler: cosine │
79
+ 2025-08-31 17:03:52 - pico-train - INFO - │ lr_warmup_steps: 2000 │
80
+ 2025-08-31 17:03:52 - pico-train - INFO - │ optimizer: adamw │
81
+ 2025-08-31 17:03:52 - pico-train - INFO - │ │
82
+ 2025-08-31 17:03:52 - pico-train - INFO - ╰─────────────────────────────────────────────────────╯
83
+ 2025-08-31 17:03:52 - pico-train - INFO - ==================================================
84
+ 2025-08-31 17:03:52 - pico-train - INFO - ⛭ Runtime Summary:
85
+ 2025-08-31 17:03:52 - pico-train - INFO - ==================================================
86
+ 2025-08-31 17:03:52 - pico-train - INFO - Starting from step: 100000
87
+ 2025-08-31 17:03:52 - pico-train - INFO - Model Setup:
88
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Total Parameters: 11,282,784
89
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Trainable Parameters: 11,282,784
90
+ 2025-08-31 17:03:52 - pico-train - INFO - Distributed Setup:
91
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Number of Devices: 1
92
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Device Type: NVIDIA H100 80GB HBM3
93
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Available Memory: 85.03 GB
94
+ 2025-08-31 17:03:52 - pico-train - INFO - Software Setup:
95
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Python Version: 3.12.3
96
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ PyTorch Version: 2.8.0+cu128
97
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ CUDA Version: 12.8
98
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Operating System: Linux 6.8.0-71-generic
99
+ 2025-08-31 17:03:52 - pico-train - INFO - Batch Size Configuration:
100
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Global Batch Size: 16
101
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Per Device Batch Size: 16
102
+ 2025-08-31 17:03:52 - pico-train - INFO - └─ Gradient Accumulation Steps: 1
103
+ 2025-08-31 17:03:52 - pico-train - INFO - ==================================================
104
+ 2025-08-31 17:03:52 - pico-train - INFO - Step 100000 -- 🔄 Training Metrics
105
+ 2025-08-31 17:03:52 - pico-train - INFO - ├── Loss: 4.9432
106
+ 2025-08-31 17:03:52 - pico-train - INFO - ├── Learning Rate: 2.00e-05
107
+ 2025-08-31 17:03:52 - pico-train - INFO - └── Inf/NaN count: 0
108
+ 2025-08-31 17:03:52 - pico-train - INFO - Step 100000 -- 📈 Saving Learning Dynamics
109
+ 2025-08-31 17:04:49 - pico-train - INFO - Step 100100 -- 🔄 Training Metrics
110
+ 2025-08-31 17:04:49 - pico-train - INFO - ├── Loss: 4.7703
111
+ 2025-08-31 17:04:49 - pico-train - INFO - ├── Learning Rate: 1.01e-04
112
+ 2025-08-31 17:04:49 - pico-train - INFO - └── Inf/NaN count: 0
113
+ 2025-08-31 17:05:43 - pico-train - INFO - Step 100200 -- 🔄 Training Metrics
114
+ 2025-08-31 17:05:43 - pico-train - INFO - ├── Loss: 4.8047
115
+ 2025-08-31 17:05:43 - pico-train - INFO - ├── Learning Rate: 1.01e-04
116
+ 2025-08-31 17:05:43 - pico-train - INFO - └── Inf/NaN count: 0
117
+ 2025-08-31 17:06:37 - pico-train - INFO - Step 100300 -- 🔄 Training Metrics
118
+ 2025-08-31 17:06:37 - pico-train - INFO - ├── Loss: 4.8076
119
+ 2025-08-31 17:06:37 - pico-train - INFO - ├── Learning Rate: 1.01e-04
120
+ 2025-08-31 17:06:37 - pico-train - INFO - └── Inf/NaN count: 0
121
+ 2025-08-31 17:07:31 - pico-train - INFO - Step 100400 -- 🔄 Training Metrics
122
+ 2025-08-31 17:07:31 - pico-train - INFO - ├── Loss: 4.7926
123
+ 2025-08-31 17:07:31 - pico-train - INFO - ├── Learning Rate: 1.01e-04
124
+ 2025-08-31 17:07:31 - pico-train - INFO - └── Inf/NaN count: 0
125
+ 2025-08-31 17:08:25 - pico-train - INFO - Step 100500 -- 🔄 Training Metrics
126
+ 2025-08-31 17:08:25 - pico-train - INFO - ├── Loss: 4.8059
127
+ 2025-08-31 17:08:25 - pico-train - INFO - ├── Learning Rate: 1.01e-04
128
+ 2025-08-31 17:08:25 - pico-train - INFO - └── Inf/NaN count: 0
129
+ 2025-08-31 17:09:19 - pico-train - INFO - Step 100600 -- 🔄 Training Metrics
130
+ 2025-08-31 17:09:19 - pico-train - INFO - ├── Loss: 4.7896
131
+ 2025-08-31 17:09:19 - pico-train - INFO - ├── Learning Rate: 1.01e-04
132
+ 2025-08-31 17:09:19 - pico-train - INFO - └── Inf/NaN count: 0
133
+ 2025-08-31 17:10:12 - pico-train - INFO - Step 100700 -- 🔄 Training Metrics
134
+ 2025-08-31 17:10:12 - pico-train - INFO - ├── Loss: 4.8066
135
+ 2025-08-31 17:10:12 - pico-train - INFO - ├── Learning Rate: 1.00e-04
136
+ 2025-08-31 17:10:12 - pico-train - INFO - └── Inf/NaN count: 0
137
+ 2025-08-31 17:11:07 - pico-train - INFO - Step 100800 -- 🔄 Training Metrics
138
+ 2025-08-31 17:11:07 - pico-train - INFO - ├── Loss: 4.7870
139
+ 2025-08-31 17:11:07 - pico-train - INFO - ├── Learning Rate: 1.00e-04
140
+ 2025-08-31 17:11:07 - pico-train - INFO - └── Inf/NaN count: 0
141
+ 2025-08-31 17:12:01 - pico-train - INFO - Step 100900 -- 🔄 Training Metrics
142
+ 2025-08-31 17:12:01 - pico-train - INFO - ├── Loss: 4.7958
143
+ 2025-08-31 17:12:01 - pico-train - INFO - ├── Learning Rate: 1.00e-04
144
+ 2025-08-31 17:12:01 - pico-train - INFO - └── Inf/NaN count: 0
145
+ 2025-08-31 17:12:55 - pico-train - INFO - Step 101000 -- 🔄 Training Metrics
146
+ 2025-08-31 17:12:55 - pico-train - INFO - ├── Loss: 4.8081
147
+ 2025-08-31 17:12:55 - pico-train - INFO - ├── Learning Rate: 1.00e-04
148
+ 2025-08-31 17:12:55 - pico-train - INFO - └── Inf/NaN count: 0
149
+ 2025-08-31 17:13:48 - pico-train - INFO - Step 101100 -- 🔄 Training Metrics
150
+ 2025-08-31 17:13:48 - pico-train - INFO - ├── Loss: 4.8023
151
+ 2025-08-31 17:13:48 - pico-train - INFO - ├── Learning Rate: 9.98e-05
152
+ 2025-08-31 17:13:48 - pico-train - INFO - └── Inf/NaN count: 0
153
+ 2025-08-31 17:14:43 - pico-train - INFO - Step 101200 -- 🔄 Training Metrics
154
+ 2025-08-31 17:14:43 - pico-train - INFO - ├── Loss: 4.7830
155
+ 2025-08-31 17:14:43 - pico-train - INFO - ├── Learning Rate: 9.97e-05
156
+ 2025-08-31 17:14:43 - pico-train - INFO - └── Inf/NaN count: 0
157
+ 2025-08-31 17:15:38 - pico-train - INFO - Step 101300 -- 🔄 Training Metrics
158
+ 2025-08-31 17:15:38 - pico-train - INFO - ├── Loss: 4.8071
159
+ 2025-08-31 17:15:38 - pico-train - INFO - ├── Learning Rate: 9.95e-05
160
+ 2025-08-31 17:15:38 - pico-train - INFO - └── Inf/NaN count: 0
161
+ 2025-08-31 17:16:32 - pico-train - INFO - Step 101400 -- 🔄 Training Metrics
162
+ 2025-08-31 17:16:32 - pico-train - INFO - ├── Loss: 4.8072
163
+ 2025-08-31 17:16:32 - pico-train - INFO - ├── Learning Rate: 9.94e-05
164
+ 2025-08-31 17:16:32 - pico-train - INFO - └── Inf/NaN count: 0
165
+ 2025-08-31 17:17:27 - pico-train - INFO - Step 101500 -- 🔄 Training Metrics
166
+ 2025-08-31 17:17:27 - pico-train - INFO - ├── Loss: 4.8027
167
+ 2025-08-31 17:17:27 - pico-train - INFO - ├── Learning Rate: 9.92e-05
168
+ 2025-08-31 17:17:27 - pico-train - INFO - └── Inf/NaN count: 0
169
+ 2025-08-31 17:18:20 - pico-train - INFO - Step 101600 -- 🔄 Training Metrics
170
+ 2025-08-31 17:18:20 - pico-train - INFO - ├── Loss: 4.7874
171
+ 2025-08-31 17:18:20 - pico-train - INFO - ├── Learning Rate: 9.90e-05
172
+ 2025-08-31 17:18:20 - pico-train - INFO - └── Inf/NaN count: 0
173
+ 2025-08-31 17:19:15 - pico-train - INFO - Step 101700 -- 🔄 Training Metrics
174
+ 2025-08-31 17:19:15 - pico-train - INFO - ├── Loss: 4.7817
175
+ 2025-08-31 17:19:15 - pico-train - INFO - ├── Learning Rate: 9.89e-05
176
+ 2025-08-31 17:19:15 - pico-train - INFO - └── Inf/NaN count: 0
177
+ 2025-08-31 17:20:09 - pico-train - INFO - Step 101800 -- 🔄 Training Metrics
178
+ 2025-08-31 17:20:09 - pico-train - INFO - ├── Loss: 4.8188
179
+ 2025-08-31 17:20:09 - pico-train - INFO - ├── Learning Rate: 9.87e-05
180
+ 2025-08-31 17:20:09 - pico-train - INFO - └── Inf/NaN count: 0
181
+ 2025-08-31 17:21:04 - pico-train - INFO - Step 101900 -- 🔄 Training Metrics
182
+ 2025-08-31 17:21:04 - pico-train - INFO - ├── Loss: 4.7880
183
+ 2025-08-31 17:21:04 - pico-train - INFO - ├── Learning Rate: 9.86e-05
184
+ 2025-08-31 17:21:04 - pico-train - INFO - └── Inf/NaN count: 0
185
+ 2025-08-31 17:21:58 - pico-train - INFO - Step 102000 -- 💾 Saving Checkpoint
186
+ 2025-08-31 18:00:17 - pico-train - INFO - Step 102000 -- 📊 Evaluation Results
187
+ 2025-08-31 18:00:17 - pico-train - INFO - └── paloma: inf
188
+ 2025-08-31 18:00:17 - pico-train - INFO - Step 102000 -- 🔄 Training Metrics
189
+ 2025-08-31 18:00:17 - pico-train - INFO - ├── Loss: 4.8055
190
+ 2025-08-31 18:00:17 - pico-train - INFO - ├── Learning Rate: 9.84e-05
191
+ 2025-08-31 18:00:17 - pico-train - INFO - └── Inf/NaN count: 0
192
+ 2025-08-31 18:00:17 - pico-train - INFO - Step 102000 -- 📈 Saving Learning Dynamics
193
+ 2025-08-31 18:01:13 - pico-train - INFO - Step 102100 -- 🔄 Training Metrics
194
+ 2025-08-31 18:01:13 - pico-train - INFO - ├── Loss: 4.7742
195
+ 2025-08-31 18:01:13 - pico-train - INFO - ├── Learning Rate: 9.83e-05
196
+ 2025-08-31 18:01:13 - pico-train - INFO - └── Inf/NaN count: 0
197
+ 2025-08-31 18:02:07 - pico-train - INFO - Step 102200 -- 🔄 Training Metrics
198
+ 2025-08-31 18:02:07 - pico-train - INFO - ├── Loss: 4.8050
199
+ 2025-08-31 18:02:07 - pico-train - INFO - ├── Learning Rate: 9.81e-05
200
+ 2025-08-31 18:02:07 - pico-train - INFO - └── Inf/NaN count: 0
201
+ 2025-08-31 18:03:01 - pico-train - INFO - Step 102300 -- 🔄 Training Metrics
202
+ 2025-08-31 18:03:01 - pico-train - INFO - ├── Loss: 4.8066
203
+ 2025-08-31 18:03:01 - pico-train - INFO - ├── Learning Rate: 9.79e-05
204
+ 2025-08-31 18:03:01 - pico-train - INFO - └── Inf/NaN count: 0
205
+ 2025-08-31 18:03:57 - pico-train - INFO - Step 102400 -- 🔄 Training Metrics
206
+ 2025-08-31 18:03:57 - pico-train - INFO - ├── Loss: 4.7865
207
+ 2025-08-31 18:03:57 - pico-train - INFO - ├── Learning Rate: 9.78e-05
208
+ 2025-08-31 18:03:57 - pico-train - INFO - └── Inf/NaN count: 0
209
+ 2025-08-31 18:04:50 - pico-train - INFO - Step 102500 -- 🔄 Training Metrics
210
+ 2025-08-31 18:04:50 - pico-train - INFO - ├── Loss: 4.8019
211
+ 2025-08-31 18:04:50 - pico-train - INFO - ├── Learning Rate: 9.76e-05
212
+ 2025-08-31 18:04:50 - pico-train - INFO - └── Inf/NaN count: 0
213
+ 2025-08-31 18:05:45 - pico-train - INFO - Step 102600 -- 🔄 Training Metrics
214
+ 2025-08-31 18:05:45 - pico-train - INFO - ├── Loss: 4.7948
215
+ 2025-08-31 18:05:45 - pico-train - INFO - ├── Learning Rate: 9.75e-05
216
+ 2025-08-31 18:05:45 - pico-train - INFO - └── Inf/NaN count: 0
217
+ 2025-08-31 18:06:39 - pico-train - INFO - Step 102700 -- 🔄 Training Metrics
218
+ 2025-08-31 18:06:39 - pico-train - INFO - ├── Loss: 4.8006
219
+ 2025-08-31 18:06:39 - pico-train - INFO - ├── Learning Rate: 9.73e-05
220
+ 2025-08-31 18:06:39 - pico-train - INFO - └── Inf/NaN count: 0
221
+ 2025-08-31 18:07:33 - pico-train - INFO - Step 102800 -- 🔄 Training Metrics
222
+ 2025-08-31 18:07:33 - pico-train - INFO - ├── Loss: 4.8049
223
+ 2025-08-31 18:07:33 - pico-train - INFO - ├── Learning Rate: 9.71e-05
224
+ 2025-08-31 18:07:33 - pico-train - INFO - └── Inf/NaN count: 0
225
+ 2025-08-31 18:08:27 - pico-train - INFO - Step 102900 -- 🔄 Training Metrics
226
+ 2025-08-31 18:08:27 - pico-train - INFO - ├── Loss: 4.8086
227
+ 2025-08-31 18:08:27 - pico-train - INFO - ├── Learning Rate: 9.70e-05
228
+ 2025-08-31 18:08:27 - pico-train - INFO - └── Inf/NaN count: 0
229
+ 2025-08-31 18:09:21 - pico-train - INFO - Step 103000 -- 🔄 Training Metrics
230
+ 2025-08-31 18:09:21 - pico-train - INFO - ├── Loss: 4.8154
231
+ 2025-08-31 18:09:21 - pico-train - INFO - ├── Learning Rate: 9.68e-05
232
+ 2025-08-31 18:09:21 - pico-train - INFO - └── Inf/NaN count: 0
233
+ 2025-08-31 18:10:15 - pico-train - INFO - Step 103100 -- 🔄 Training Metrics
234
+ 2025-08-31 18:10:15 - pico-train - INFO - ├── Loss: 4.8232
235
+ 2025-08-31 18:10:15 - pico-train - INFO - ├── Learning Rate: 9.67e-05
236
+ 2025-08-31 18:10:15 - pico-train - INFO - └── Inf/NaN count: 0
237
+ 2025-08-31 18:11:10 - pico-train - INFO - Step 103200 -- 🔄 Training Metrics
238
+ 2025-08-31 18:11:10 - pico-train - INFO - ├── Loss: 4.8032
239
+ 2025-08-31 18:11:10 - pico-train - INFO - ├── Learning Rate: 9.65e-05
240
+ 2025-08-31 18:11:10 - pico-train - INFO - └── Inf/NaN count: 0
241
+ 2025-08-31 18:12:05 - pico-train - INFO - Step 103300 -- 🔄 Training Metrics
242
+ 2025-08-31 18:12:05 - pico-train - INFO - ├── Loss: 4.8157
243
+ 2025-08-31 18:12:05 - pico-train - INFO - ├── Learning Rate: 9.64e-05
244
+ 2025-08-31 18:12:05 - pico-train - INFO - └── Inf/NaN count: 0
245
+ 2025-08-31 18:13:00 - pico-train - INFO - Step 103400 -- 🔄 Training Metrics
246
+ 2025-08-31 18:13:00 - pico-train - INFO - ├── Loss: 4.7903
247
+ 2025-08-31 18:13:00 - pico-train - INFO - ├── Learning Rate: 9.62e-05
248
+ 2025-08-31 18:13:00 - pico-train - INFO - └── Inf/NaN count: 0
249
+ 2025-08-31 18:13:54 - pico-train - INFO - Step 103500 -- 🔄 Training Metrics
250
+ 2025-08-31 18:13:54 - pico-train - INFO - ├── Loss: 4.7786
251
+ 2025-08-31 18:13:54 - pico-train - INFO - ├── Learning Rate: 9.60e-05
252
+ 2025-08-31 18:13:54 - pico-train - INFO - └── Inf/NaN count: 0
253
+ 2025-08-31 18:14:48 - pico-train - INFO - Step 103600 -- 🔄 Training Metrics
254
+ 2025-08-31 18:14:48 - pico-train - INFO - ├── Loss: 4.7962
255
+ 2025-08-31 18:14:48 - pico-train - INFO - ├── Learning Rate: 9.59e-05
256
+ 2025-08-31 18:14:48 - pico-train - INFO - └── Inf/NaN count: 0
257
+ 2025-08-31 18:15:43 - pico-train - INFO - Step 103700 -- 🔄 Training Metrics
258
+ 2025-08-31 18:15:43 - pico-train - INFO - ├── Loss: 4.8097
259
+ 2025-08-31 18:15:43 - pico-train - INFO - ├── Learning Rate: 9.57e-05
260
+ 2025-08-31 18:15:43 - pico-train - INFO - └── Inf/NaN count: 0
261
+ 2025-08-31 18:16:37 - pico-train - INFO - Step 103800 -- 🔄 Training Metrics
262
+ 2025-08-31 18:16:37 - pico-train - INFO - ├── Loss: 4.7613
263
+ 2025-08-31 18:16:37 - pico-train - INFO - ├── Learning Rate: 9.56e-05
264
+ 2025-08-31 18:16:37 - pico-train - INFO - └── Inf/NaN count: 0
265
+ 2025-08-31 18:17:31 - pico-train - INFO - Step 103900 -- 🔄 Training Metrics
266
+ 2025-08-31 18:17:31 - pico-train - INFO - ├── Loss: 4.7992
267
+ 2025-08-31 18:17:31 - pico-train - INFO - ├── Learning Rate: 9.54e-05
268
+ 2025-08-31 18:17:31 - pico-train - INFO - └── Inf/NaN count: 0
269
+ 2025-08-31 18:18:25 - pico-train - INFO - Step 104000 -- 💾 Saving Checkpoint