bkhmsi committed on
Commit
582ea12
·
1 Parent(s): 4cafde7

created micro hf space

.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ .env
2
+ *.pyc
3
+ .DS_Store
4
+ __pycache__/
Instructions.md ADDED
@@ -0,0 +1,20 @@
1
+ # MiCRo Expert Routing Visualizer (Gradio)
2
+
3
+ This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
4
+ Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model’s internal reasoning structure.
5
+
6
+ ## How it works
7
+ - Choose a model (dropdown) or type a custom model id.
8
+ - Enter a *User prompt*. Optionally add an *Assistant prompt*; if provided, the app combines them into a chat-style message list:
9
+
10
+ ```
11
+ [{"role": "user", "content": "<user text>"},
12
+  {"role": "assistant", "content": "<assistant text>"}]
13
+ ```
14
+
15
+ - If the backend call fails, the demo falls back to "mock data": deterministic, pseudo-random percentages derived from the prompt.
16
+
17
+ ### Backend contract
18
+ `get_expert_routing(model_id: str, hf_token: str, prompt)` must return a tuple `(routing, generation)`. `routing` holds 4 values (percentages) for the experts in this fixed order:
19
+ `["Language", "Logic", "Social", "World"]`
20
+ (or a dict with those exact keys); `generation` is the model's generated continuation.
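
For reference, a minimal `router_backend.py` sketch that satisfies this contract. It is mock-only: the percentages are hashed deterministically from the prompt (mirroring `_mock_routing` in `app.py`), and the returned generation is an empty placeholder rather than a real model call.

```python
# router_backend.py -- minimal mock-only sketch of the backend contract.
# No model is loaded: routing percentages are derived deterministically from
# the prompt, and the returned "generation" is a placeholder string.
import hashlib
from typing import Dict, List, Tuple, Union

EXPERTS = ["Language", "Logic", "Social", "World"]

def get_expert_routing(
    model_id: str,
    hf_token: str,
    prompt: Union[str, List[dict]],
) -> Tuple[Dict[str, float], str]:
    """Return (routing, generation); routing maps each expert name to a percentage."""
    digest = hashlib.sha256(f"{model_id}||{prompt}".encode()).digest()
    vals = [int.from_bytes(digest[i * 8:(i + 1) * 8], "little") % 10_000 + 1 for i in range(4)]
    total = sum(vals)
    routing = {name: 100.0 * v / total for name, v in zip(EXPERTS, vals)}
    generation = ""  # a real backend would return the model's continuation here
    return routing, generation
```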
README.md CHANGED
@@ -1,6 +1,5 @@
- ---
  title: MiCRo Routing Visualizer
- emoji: 💻
+ emoji: 🧠
  colorFrom: purple
  colorTo: red
  sdk: gradio
@@ -8,7 +7,4 @@ sdk_version: 5.49.1
  app_file: app.py
  pinned: false
  license: mit
- short_description: Mixture of Cognitive Reasoners Computation Allocation
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ short_description: Mixture of Cognitive Reasoners Computation Allocation
app.py ADDED
@@ -0,0 +1,166 @@
1
+ # app.py
2
+ """
3
+ Hugging Face Space: MoE Expert Routing Visualizer (Gradio)
4
+ ----------------------------------------------------------
5
+ This Space lets a user:
6
+ - Choose a model (from a dropdown or a free-text box)
7
+ - Enter a user prompt, and optionally an assistant prompt
8
+ - Call a backend function that returns 4 routing percentages (Language, Logic, Social, World)
9
+ - See a bar plot + table of the percentages
10
+
11
+ 🧩 Plug your real routing function in router_backend.py -> get_expert_routing().
12
+ By default, a deterministic "mock mode" produces stable pseudo-random percentages from the prompt.
13
+ """
14
+
15
+ import hashlib
16
+ from typing import Dict, List, Tuple, Union
17
+ import gradio as gr
18
+ import plotly
19
+ import plotly.express as px
20
+ import pandas as pd
21
+ from router_backend import get_expert_routing
22
+
23
+ # ---- Expected backend adapter ------------------------------------------------
24
+ # Implement your real function in router_backend.py with the following signature:
25
+ # def get_expert_routing(model_id: str, hf_token: str, prompt) -> Tuple[Union[List[float], Dict[str, float]], str]
26
+ # It MUST return (routing, generation); routing holds 4 values that sum to ~100 (percentages) in the fixed order:
27
+ # ["Language", "Logic", "Social", "World"]
28
+ # or a mapping with those keys.
29
+ # try:
30
+ # from router_backend import get_expert_routing # your real backend
31
+ # BACKEND_AVAILABLE = True
32
+ # except Exception as e: # keep error for display if needed
33
+ # BACKEND_AVAILABLE = False
34
+ # _backend_import_error = e
35
+
36
+ EXPERTS = ["Language", "Logic", "Social", "World"]
37
+
38
+ DEFAULT_MODELS = [
39
+ "micro-llama-1b",
40
+ "micro-llama-3b",
41
+ "micro-llama-1b-dpo",
42
+ "micro-moe-llama-1b",
43
+ "micro-smollm2-135m",
44
+ "micro-smollm2-360m",
45
+ "micro-moe-smollm2-135m",
46
+ "micro-moe-smollm2-360m",
47
+ ]
48
+
49
+ def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
50
+ """
51
+ Deterministic mock routing percentages based on model_id + prompt + seed.
52
+ Returns a list of 4 percentages summing to 100.0
53
+ """
54
+ h = hashlib.sha256(f"{model_id}||{prompt}||{seed}".encode()).digest()
55
+ # split into 4 positive numbers
56
+ vals = [int.from_bytes(h[i*8:(i+1)*8], "little") % 10_000 + 1 for i in range(4)]
57
+ s = sum(vals)
58
+ return [100.0 * v / s for v in vals]
59
+
60
+ def _normalize_output(r: Union[List[float], Tuple[float, float, float, float], Dict[str, float]]) -> List[float]:
61
+ """
62
+ Normalize different return types into a 4-length list ordered as EXPERTS.
63
+ """
64
+ if isinstance(r, dict):
65
+ vals = [float(r.get(k, 0.0)) for k in EXPERTS]
66
+ else:
67
+ vals = [float(x) for x in list(r)]
68
+ if len(vals) != 4:
69
+ raise ValueError(f"Expected 4 values, got {len(vals)}.")
70
+ # renormalize to 100 if needed
71
+ s = sum(vals)
72
+ if s <= 0:
73
+ raise ValueError("Sum of routing percentages is non-positive.")
74
+ vals = [100.0 * v / s for v in vals]
75
+ return vals
76
+
77
+ def _compose_prompt(user_prompt: str, assistant_prompt: str) -> Union[str, List[dict]]:
78
+ user_prompt = (user_prompt or "").strip()
79
+ assistant_prompt = (assistant_prompt or "").strip()
80
+ if assistant_prompt:
81
+ return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
82
+ return user_prompt
83
+
84
+ def route_and_plot(model_choice: str, hf_token: str, user_prompt: str, assistant_prompt: str) -> Tuple[str, pd.DataFrame, "plotly.graph_objs._figure.Figure", str]:
85
+ """
86
+ Main pipeline:
87
+ - Compose prompt (user + optional assistant)
88
+ - Call backend (real or mock)
89
+ - Return the generated continuation, a table, a bar plot, and a status message
90
+ """
91
+ model_id = model_choice.strip()
92
+ if not model_id:
93
+ raise gr.Error("Please select a model or enter a custom model id.")
94
+ prompt = _compose_prompt(user_prompt, assistant_prompt)
95
+ if not prompt:
96
+ raise gr.Error("Please enter a prompt.")
97
+
98
+ seed = 42
99
+ use_mock = False
100
+ if use_mock:
101
+ msg = "Using mock data."
102
+ vals = _mock_routing(model_id, prompt, seed=seed)
103
+ generation = None
104
+ else:
105
+ try:
106
+ raw, generation = get_expert_routing(model_id, hf_token, prompt) # <-- your real function
107
+ vals = _normalize_output(raw)
108
+ msg = "Routed with real backend."
109
+ except Exception as e:
110
+ # fallback to mock on error, but surface message
111
+ msg = f"Backend error: {e}\nFalling back to mock data."
112
+ vals = _mock_routing(model_id, prompt, seed=seed)
113
+ generation = None
114
+
115
+ df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
116
+ fig = px.bar(df, x="Expert", y="Percent", title="Token Routing by Expert (%)", text="Percent")
117
+ fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside")
118
+ fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)
119
+
120
+ status = f"Model: {model_id}<br>{msg}"
121
+ if generation is None:
122
+ generation = assistant_prompt
123
+
124
+ return generation, df, fig, status
125
+
126
+ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
127
+ gr.Markdown(
128
+ """
129
+ # 🧠 Mixture of Cognitive Reasoners (MiCRo) Expert Routing Visualizer
130
+ ## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.
131
+ Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)
132
+ ----
133
+ This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
134
+ Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
135
+ """.strip()
136
+ )
137
+
138
+ with gr.Row():
139
+ model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
140
+ hf_token = gr.Textbox(label="Hugging Face token for authentication", placeholder="hf token", lines=1)
141
+
142
+ with gr.Row():
143
+ user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
144
+ assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")
145
+
146
+ # with gr.Row():
147
+ # use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
148
+ # seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")
149
+
150
+ run = gr.Button("Run Routing", variant="primary")
151
+
152
+ generation_output = gr.Textbox(lines=4, label="Generated continuation", placeholder="Generated text will appear here...", interactive=False)
153
+
154
+ with gr.Row():
155
+ table = gr.Dataframe(label="Routing Percentages", interactive=False)
156
+ plot = gr.Plot(label="Bar Plot")
157
+ status = gr.Markdown("")
158
+
159
+ run.click(
160
+ route_and_plot,
161
+ inputs=[model_choice, hf_token, user_prompt, assistant_prompt],
162
+ outputs=[generation_output, table, plot, status],
163
+ )
164
+
165
+ if __name__ == "__main__":
166
+ demo.launch()
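
As a quick illustration of the return shapes the app accepts from the backend, here is a small standalone check that mirrors `_normalize_output` above (re-implemented so the snippet runs without the Space's other dependencies):

```python
# Mirrors app.py's _normalize_output: accepts a dict keyed by expert name or a
# 4-element sequence, and renormalizes the values so they sum to 100.
from typing import Dict, List, Sequence, Union

EXPERTS = ["Language", "Logic", "Social", "World"]

def normalize_output(r: Union[Sequence[float], Dict[str, float]]) -> List[float]:
    vals = [float(r.get(k, 0.0)) for k in EXPERTS] if isinstance(r, dict) else [float(x) for x in r]
    if len(vals) != 4:
        raise ValueError(f"Expected 4 values, got {len(vals)}.")
    total = sum(vals)
    if total <= 0:
        raise ValueError("Sum of routing percentages is non-positive.")
    return [100.0 * v / total for v in vals]

print(normalize_output({"Language": 2, "Logic": 1, "Social": 1, "World": 0}))  # [50.0, 25.0, 25.0, 0.0]
print(normalize_output([10, 10, 10, 10]))                                      # [25.0, 25.0, 25.0, 25.0]
```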
configs/micro_llama_1b.yml ADDED
@@ -0,0 +1,14 @@
1
+ run-title: micro-llama-1b
2
+ model: micro-llama-1b
3
+
4
+ base-model: meta-llama/Llama-3.2-1B
5
+ tokenizer: meta-llama/Llama-3.2-1B-Instruct
6
+ num-experts: 4
7
+ top-k-experts: 1
8
+ jitter-noise: 0
9
+ use-router: True
10
+ mask-input: True
11
+ max-length: 8192
12
+
13
+ trainable:
14
+ - model
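
For context, this is the run config that `MiCRoLlama.build_model` consumes (loaded via `yaml.load` on `config.config_path` in `models/micro_llama.py`). A minimal loading sketch, assuming the file sits at `configs/micro_llama_1b.yml` relative to the working directory:

```python
# Minimal sketch of how the run config is consumed (mirrors MiCRoLlama.__init__ /
# build_model in models/micro_llama.py); the relative path is an assumption.
import yaml

with open("configs/micro_llama_1b.yml", "r", encoding="utf-8") as f:
    run_config = yaml.load(f.read(), Loader=yaml.FullLoader)

num_experts = run_config["num-experts"]    # 4 cognitive experts per layer
top_k = run_config["top-k-experts"]        # experts activated per token
use_router = run_config["use-router"]      # learned gate vs. externally supplied routing weights
jitter = run_config["jitter-noise"]        # multiplicative input noise used during training
trainable = run_config["trainable"]        # e.g. ["model"]; controls which parameters stay unfrozen
print(num_experts, top_k, use_router, jitter, trainable)
```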
models/micro_llama.py ADDED
@@ -0,0 +1,588 @@
1
+ from typing import Optional, Tuple, Union, List, Callable
2
+ import logging
3
+ import yaml
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.distributed as dist
9
+
10
+ # from transformers.utils import TransformerKwargs
11
+ from transformers import LlamaConfig, AutoConfig, AutoTokenizer, AutoModelForCausalLM
12
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
13
+ from transformers.models.llama.modeling_llama import (
14
+ LlamaRotaryEmbedding,
15
+ LlamaRMSNorm,
16
+ LlamaMLP,
17
+ LlamaDecoderLayer,
18
+ LlamaPreTrainedModel,
19
+ GenerationMixin,
20
+ apply_rotary_pos_emb,
21
+ eager_attention_forward,
22
+
23
+ )
24
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
25
+ from transformers.cache_utils import Cache, StaticCache, DynamicCache
26
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
27
+ from transformers.processing_utils import Unpack
28
+ from transformers.utils import is_torchdynamo_compiling
29
+ from models.modules import CausalLMOutputWithPast
30
+ from transformers.modeling_layers import GradientCheckpointingLayer
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ def _prepare_4d_causal_attention_mask_with_cache_position(
35
+ attention_mask: torch.Tensor,
36
+ sequence_length: int,
37
+ target_length: int,
38
+ dtype: torch.dtype,
39
+ device: torch.device,
40
+ min_dtype: float,
41
+ cache_position: torch.Tensor,
42
+ batch_size: int,
43
+ ):
44
+ """
45
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
46
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
47
+
48
+ Args:
49
+ attention_mask (`torch.Tensor`):
50
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
51
+ sequence_length (`int`):
52
+ The sequence length being processed.
53
+ target_length (`int`):
54
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
55
+ dtype (`torch.dtype`):
56
+ The dtype to use for the 4D attention mask.
57
+ device (`torch.device`):
58
+ The device to place the 4D attention mask on.
59
+ min_dtype (`float`):
60
+ The minimum value representable with the dtype `dtype`.
61
+ cache_position (`torch.Tensor`):
62
+ Indices depicting the position of the input sequence tokens in the sequence.
63
+ batch_size (`torch.Tensor`):
64
+ Batch size.
65
+ """
66
+ if attention_mask is not None and attention_mask.dim() == 4:
67
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
68
+ causal_mask = attention_mask
69
+ else:
70
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
71
+ if sequence_length != 1:
72
+ causal_mask = torch.triu(causal_mask, diagonal=1)
73
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
74
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
75
+ if attention_mask is not None:
76
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
77
+ mask_length = attention_mask.shape[-1]
78
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
79
+ padding_mask = padding_mask == 0
80
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
81
+ padding_mask, min_dtype
82
+ )
83
+
84
+ return causal_mask
85
+
86
+ class MiCRoLlamaConfig(LlamaConfig):
87
+ model_type = "micro_llama"
88
+ def __init__(self, *args, **kwargs):
89
+ super().__init__(*args, **kwargs)
90
+ self.num_experts = kwargs.get("num_experts", 4)
91
+ self.use_router = kwargs.get("use_router", True)
92
+ self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 2)
93
+ self.jitter_noise = kwargs.get("jitter_noise", 0.0)
94
+ self.loss_method = kwargs.get("loss_method", "all")
95
+ self.config_path = kwargs.get("config_path", None)
96
+
97
+ class MiCRoLlamaDecoderLayer(nn.Module):
98
+ def __init__(self, config: MiCRoLlamaConfig, layer_idx: int):
99
+ super().__init__()
100
+ self.hidden_dim = config.hidden_size
101
+ self.ffn_dim = config.intermediate_size
102
+ self.num_experts = config.num_experts
103
+ self.top_k = config.num_experts_per_tok
104
+ self.use_router = config.use_router
105
+ self.ablate = config.ablate
106
+ self.num_key_value_heads = config.num_key_value_heads
107
+ self.head_dim = self.hidden_dim // config.num_attention_heads
108
+ self.gradient_checkpointing = config.gradient_checkpointing
109
+ if isinstance(self.ablate, str):
110
+ self.ablate = [self.ablate]
111
+
112
+ self.gate = nn.Sequential(
113
+ nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
114
+ nn.Linear(self.hidden_dim, self.num_experts, bias=False)
115
+ )
116
+
117
+ self.num_layers = config.backbone_num_layers
118
+ self.layer_idx = layer_idx
119
+
120
+ self.experts = nn.ModuleList([LlamaDecoderLayer(config, layer_idx * self.num_experts + expert_idx) for expert_idx in range(self.num_experts)])
121
+
122
+ self.jitter_noise = config.jitter_noise
123
+
124
+ def forward(
125
+ self,
126
+ hidden_states: torch.Tensor,
127
+ routing_weights: Optional[torch.Tensor] = None,
128
+ attention_mask: Optional[torch.Tensor] = None,
129
+ position_ids: Optional[torch.LongTensor] = None,
130
+ ablate: Optional[List[str]] = None,
131
+ past_key_value: Optional[Cache] = None,
132
+ output_attentions: Optional[bool] = False,
133
+ use_cache: Optional[bool] = False,
134
+ cache_position: Optional[torch.LongTensor] = None,
135
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
136
+ **kwargs,
137
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
138
+
139
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
140
+
141
+ if ablate is not None:
142
+ self.ablate = ablate
143
+
144
+ if self.training and self.jitter_noise > 0:
145
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
146
+
147
+ if self.use_router:
148
+ router_logits = self.gate(hidden_states)
149
+ if "logic" in self.ablate:
150
+ router_logits[..., 0] = -torch.inf
151
+ if "social" in self.ablate:
152
+ router_logits[..., 1] = -torch.inf
153
+ if "world" in self.ablate:
154
+ router_logits[..., 2] = -torch.inf
155
+ if "language" in self.ablate:
156
+ router_logits[..., 3] = -torch.inf
157
+ routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
158
+ else:
159
+ if len(routing_weights.shape) == 2:
160
+ routing_weights = routing_weights.unsqueeze(1).tile((1,sequence_length,1)).float()
161
+ else:
162
+ routing_weights = routing_weights.float()
163
+ router_logits = routing_weights
164
+
165
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
166
+ routing_weights /= (routing_weights.sum(dim=-1, keepdim=True) + 1e-9)
167
+
168
+ # we cast back to the input dtype
169
+ routing_weights = routing_weights.to(hidden_states.dtype)
170
+
171
+ # We'll accumulate outputs here
172
+ final_hidden_states = torch.zeros_like(hidden_states)
173
+
174
+ # Flatten final_hidden_states to [batch_size * seq_len, hidden_dim]
175
+ # so we can do a 2D "index_add_" at the end of each loop.
176
+ final_hidden_states_2d = final_hidden_states.view(-1, hidden_dim)
177
+
178
+ # One hot encode the selected experts to create an expert mask
179
+ # this will be used to easily index which expert is going to be solicited
180
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
181
+ #^ [batch_size, seq_len, top_k, num_experts]
182
+
183
+ # Loop over all available experts in the model and perform the computation on each expert
184
+ for expert_idx in range(self.num_experts):
185
+ expert_layer: LlamaDecoderLayer = self.experts[expert_idx]
186
+ batch_indices, seq_indices, top_k_indices = torch.where(expert_mask[..., expert_idx])
187
+
188
+ if not self.training and sequence_length == 1 and batch_indices.numel() == 0:
189
+ if past_key_value is not None:
190
+
191
+ hidden_state_ln_norm = expert_layer.input_layernorm(hidden_states)
192
+
193
+ input_shape = hidden_state_ln_norm.shape[:-1]
194
+ hidden_shape = (*input_shape, -1, self.head_dim)
195
+
196
+ # query_states = expert_layer.self_attn.q_proj(hidden_state_ln_norm).view(hidden_shape).transpose(1, 2)
197
+ key_states = expert_layer.self_attn.k_proj(hidden_state_ln_norm).view(hidden_shape).transpose(1, 2)
198
+ value_states = expert_layer.self_attn.v_proj(hidden_state_ln_norm).view(hidden_shape).transpose(1, 2)
199
+
200
+ cos, sin = position_embeddings
201
+ _, key_states = apply_rotary_pos_emb(key_states, key_states, cos, sin)  # only the rotated keys are needed; queries are not computed for an expert that received no tokens
202
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
203
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
204
+ past_key_value.update(key_states, value_states, self.layer_idx * self.num_experts + expert_idx, cache_kwargs)
205
+
206
+ continue
207
+
208
+ if self.gradient_checkpointing and self.training:
209
+ current_hidden_states = self._gradient_checkpointing_func(
210
+ expert_layer.__call__,
211
+ hidden_states,
212
+ attention_mask,
213
+ position_ids,
214
+ past_key_value,
215
+ output_attentions,
216
+ use_cache,
217
+ cache_position,
218
+ position_embeddings,
219
+ )[0]
220
+ else:
221
+ current_hidden_states = expert_layer(
222
+ hidden_states=hidden_states,
223
+ attention_mask=attention_mask,
224
+ position_ids=position_ids,
225
+ past_key_value=past_key_value,
226
+ output_attentions=output_attentions,
227
+ use_cache=use_cache,
228
+ cache_position=cache_position,
229
+ position_embeddings=position_embeddings,
230
+ **kwargs,
231
+ )[0]
232
+
233
+
234
+ flat_idx = batch_indices * sequence_length + seq_indices
235
+ expert_weights = routing_weights[batch_indices, seq_indices, top_k_indices].unsqueeze(-1)
236
+ current_hidden_states = current_hidden_states[batch_indices, seq_indices] * expert_weights
237
+
238
+ final_hidden_states_2d.index_add_(0, flat_idx, current_hidden_states.to(hidden_states.dtype))
239
+
240
+ final_hidden_states = final_hidden_states_2d.view(batch_size, sequence_length, hidden_dim)
241
+ return final_hidden_states, router_logits
242
+
243
+ class MiCRoLlama(LlamaPreTrainedModel, GenerationMixin):
244
+ config_class = MiCRoLlamaConfig
245
+ def __init__(self, config: MiCRoLlamaConfig):
246
+ with open(config.config_path, 'r', encoding="utf-8") as file:
247
+ run_config = yaml.load(file.read(), Loader=yaml.FullLoader)
248
+
249
+ self.config: MiCRoLlamaConfig = config
250
+ self.config.torch_dtype = torch.bfloat16
251
+ self.config.use_bfloat16 = True
252
+ self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
253
+ self.config.backbone_num_layers = self.config.num_hidden_layers
254
+ self.config.num_hidden_layers = self.config.num_hidden_layers * run_config["num-experts"]
255
+ self.config.loss_type = "ForCausalLMLoss"
256
+
257
+ super(MiCRoLlama, self).__init__(self.config)
258
+ self.build_model(run_config)
259
+
260
+ def build_model(self, run_config):
261
+
262
+ self.gradient_checkpointing = False
263
+ self.config.num_experts = run_config["num-experts"]
264
+ self.config.use_router = run_config["use-router"]
265
+ self.config.num_experts_per_tok = run_config["top-k-experts"]
266
+ print(f">> Number of Experts per Token: {self.config.num_experts_per_tok}")
267
+ self.config.jitter_noise = run_config["jitter-noise"]
268
+ self.config.loss_method = run_config.get("loss", "all")
269
+ self.config.gradient_checkpointing = run_config.get("gradient-checkpointing", False)
270
+ print(f">> Gradient Checkpointing: {self.config.gradient_checkpointing}")
271
+
272
+ self.run_config = run_config
273
+ self.padding_idx = 2 if "smollm2" in run_config["model"] else 128004
274
+
275
+ # MiCRoLlama model
276
+ self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.padding_idx)
277
+ self.layers = nn.ModuleList([MiCRoLlamaDecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.backbone_num_layers)])
278
+ self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
279
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
280
+ self.final_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
281
+
282
+ if "model" not in run_config["trainable"]:
283
+ print(">> Freezing Model Except Routing Gates")
284
+ for param in self.parameters():
285
+ param.requires_grad = False
286
+
287
+ for layer in self.layers:
288
+ layer: MiCRoLlamaDecoderLayer
289
+ for param in layer.gate.parameters():
290
+ param.requires_grad = True
291
+
292
+ if "experts-router" not in run_config["trainable"]:
293
+ print(">> Freezing Routing Gates")
294
+ for layer in self.layers:
295
+ layer: MiCRoLlamaDecoderLayer
296
+ for param in layer.gate.parameters():
297
+ param.requires_grad = False
298
+
299
+
300
+
301
+ def forward(self,
302
+ input_ids: torch.LongTensor = None,
303
+ attention_mask: Optional[torch.Tensor] = None,
304
+ position_ids: Optional[torch.LongTensor] = None,
305
+ experts_ablate: Optional[List[str]] = None,
306
+ routing_weights: Optional[torch.LongTensor] = None,
307
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
308
+ inputs_embeds: Optional[torch.FloatTensor] = None,
309
+ labels: Optional[torch.LongTensor] = None,
310
+ use_cache: Optional[bool] = None,
311
+ output_attentions: Optional[bool] = None,
312
+ output_hidden_states: Optional[bool] = None,
313
+ return_dict: Optional[bool] = None,
314
+ cache_position: Optional[torch.LongTensor] = None,
315
+ logits_to_keep: Union[int, torch.Tensor] = 0,
316
+ **kwargs: Unpack[FlashAttentionKwargs],
317
+ ):
318
+
319
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
320
+ output_hidden_states = (
321
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
322
+ )
323
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
324
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
325
+
326
+ if (input_ids is None) ^ (inputs_embeds is not None):
327
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
328
+
329
+ if self.gradient_checkpointing and self.training and use_cache:
330
+ logger.warning_once(
331
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
332
+ )
333
+ use_cache = False
334
+
335
+ if inputs_embeds is None:
336
+ inputs_embeds = self.embed_tokens(input_ids)
337
+
338
+ if use_cache and past_key_values is None:
339
+ past_key_values = DynamicCache()
340
+
341
+ if cache_position is None:
342
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
343
+ cache_position = torch.arange(
344
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
345
+ )
346
+
347
+ if position_ids is None:
348
+ position_ids = cache_position.unsqueeze(0)
349
+
350
+ causal_mask = self._update_causal_mask(
351
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
352
+ )
353
+
354
+ hidden_states = inputs_embeds
355
+
356
+ # create position embeddings to be shared across the decoder layers
357
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
358
+
359
+ # decoder layers
360
+ all_hidden_states = () if output_hidden_states else None
361
+ all_self_attns = () if output_attentions else None
362
+
363
+ all_routing_weights = ()
364
+
365
+ for decoder_layer in self.layers:
366
+ if output_hidden_states:
367
+ all_hidden_states += (hidden_states,)
368
+
369
+ if self.gradient_checkpointing and self.training and False:  # branch disabled via `and False`; checkpointing is applied per expert inside MiCRoLlamaDecoderLayer instead
370
+ layer_outputs, router_logits = self._gradient_checkpointing_func(
371
+ decoder_layer.__call__,
372
+ hidden_states,
373
+ routing_weights,
374
+ causal_mask,
375
+ position_ids,
376
+ experts_ablate,
377
+ past_key_values,
378
+ output_attentions,
379
+ use_cache,
380
+ cache_position,
381
+ position_embeddings,
382
+ )
383
+ else:
384
+ layer_outputs, router_logits = decoder_layer(
385
+ hidden_states,
386
+ routing_weights=routing_weights,
387
+ attention_mask=causal_mask,
388
+ position_ids=position_ids,
389
+ ablate=experts_ablate,
390
+ past_key_value=past_key_values,
391
+ output_attentions=output_attentions,
392
+ use_cache=use_cache,
393
+ cache_position=cache_position,
394
+ position_embeddings=position_embeddings,
395
+ **kwargs,
396
+ )
397
+
398
+ hidden_states = layer_outputs
399
+
400
+ if output_attentions:
401
+ all_self_attns += (layer_outputs[1],)
402
+
403
+ all_routing_weights += (router_logits,)
404
+
405
+ hidden_states = self.final_norm(hidden_states)
406
+
407
+ # add hidden states from the last decoder layer
408
+ if output_hidden_states:
409
+ all_hidden_states += (hidden_states,)
410
+
411
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
412
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
413
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
414
+
415
+ loss = None
416
+ if labels is not None:
417
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
418
+
419
+ if not return_dict:
420
+ output = (logits,) + (past_key_values, all_hidden_states, all_self_attns, all_routing_weights) if use_cache else (logits, all_hidden_states, all_self_attns, all_routing_weights)
421
+ return (loss,) + output if loss is not None else output
422
+
423
+ return CausalLMOutputWithPast(
424
+ loss=loss,
425
+ logits=logits,
426
+ past_key_values=past_key_values if use_cache else None,
427
+ hidden_states=all_hidden_states,
428
+ attentions=all_self_attns,
429
+ routing_weights=all_routing_weights,
430
+ )
431
+
432
+ def _update_causal_mask(
433
+ self,
434
+ attention_mask: torch.Tensor,
435
+ input_tensor: torch.Tensor,
436
+ cache_position: torch.Tensor,
437
+ past_key_values: Cache,
438
+ output_attentions: bool,
439
+ ):
440
+ if self.config._attn_implementation == "flash_attention_2":
441
+ if attention_mask is not None and 0.0 in attention_mask:
442
+ return attention_mask
443
+ return None
444
+
445
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
446
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
447
+ # to infer the attention mask.
448
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
449
+ using_static_cache = isinstance(past_key_values, StaticCache)
450
+
451
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
452
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
453
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
454
+ attention_mask,
455
+ inputs_embeds=input_tensor,
456
+ past_key_values_length=past_seen_tokens,
457
+ is_training=self.training,
458
+ ):
459
+ return None
460
+
461
+ dtype, device = input_tensor.dtype, input_tensor.device
462
+ min_dtype = torch.finfo(dtype).min
463
+ sequence_length = input_tensor.shape[1]
464
+ if using_static_cache:
465
+ target_length = past_key_values.get_max_length()
466
+ else:
467
+ target_length = (
468
+ attention_mask.shape[-1]
469
+ if isinstance(attention_mask, torch.Tensor)
470
+ else past_seen_tokens + sequence_length + 1
471
+ )
472
+
473
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
474
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
475
+ attention_mask,
476
+ sequence_length=sequence_length,
477
+ target_length=target_length,
478
+ dtype=dtype,
479
+ device=device,
480
+ min_dtype=min_dtype,
481
+ cache_position=cache_position,
482
+ batch_size=input_tensor.shape[0],
483
+ )
484
+
485
+ if (
486
+ self.config._attn_implementation == "sdpa"
487
+ and attention_mask is not None
488
+ and attention_mask.device.type == "cuda"
489
+ and not output_attentions
490
+ ):
491
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
492
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
493
+ # Details: https://github.com/pytorch/pytorch/issues/110213
494
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
495
+
496
+ return causal_mask
497
+
498
+ def load_pretrained(self, model_name):
499
+ base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
500
+ self.lm_head.load_state_dict(base_model.lm_head.state_dict())
501
+ self.embed_tokens.load_state_dict(base_model.get_input_embeddings().state_dict())
502
+ self.rotary_emb.load_state_dict(base_model.model.rotary_emb.state_dict())
503
+ self.final_norm.load_state_dict(base_model.model.norm.state_dict())
504
+ for layer_idx, layer in enumerate(self.layers):
505
+ base_model_layer = base_model.model.layers[layer_idx].state_dict()
506
+ for expert in layer.experts:
507
+ expert.load_state_dict(base_model_layer)
508
+
509
+ def prepare_inputs_for_generation(
510
+ self,
511
+ input_ids,
512
+ past_key_values=None,
513
+ attention_mask=None,
514
+ inputs_embeds=None,
515
+ cache_position=None,
516
+ position_ids=None,
517
+ experts_ablate=None,
518
+ use_cache=True,
519
+ num_logits_to_keep=None,
520
+ **kwargs,
521
+ ):
522
+
523
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
524
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
525
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
526
+ if past_key_values is not None:
527
+ if inputs_embeds is not None: # Exception 1
528
+ input_ids = input_ids[:, -cache_position.shape[0] :]
529
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
530
+ input_ids = input_ids[:, cache_position]
531
+
532
+ if attention_mask is not None and position_ids is None:
533
+ # create position_ids on the fly for batch generation
534
+ position_ids = attention_mask.long().cumsum(-1) - 1
535
+ position_ids.masked_fill_(attention_mask == 0, 1)
536
+ if past_key_values:
537
+ position_ids = position_ids[:, -input_ids.shape[1] :]
538
+
539
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
540
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
541
+
542
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
543
+ if inputs_embeds is not None and cache_position[0] == 0:
544
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
545
+ else:
546
+ # The clone here is for the same reason as for `position_ids`.
547
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
548
+
549
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
550
+ if model_inputs["inputs_embeds"] is not None:
551
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
552
+ device = model_inputs["inputs_embeds"].device
553
+ else:
554
+ batch_size, sequence_length = model_inputs["input_ids"].shape
555
+ device = model_inputs["input_ids"].device
556
+
557
+ dtype = self.lm_head.weight.dtype
558
+ min_dtype = torch.finfo(dtype).min
559
+
560
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
561
+ attention_mask,
562
+ sequence_length=sequence_length,
563
+ target_length=past_key_values.get_max_length(),
564
+ dtype=dtype,
565
+ device=device,
566
+ min_dtype=min_dtype,
567
+ cache_position=cache_position,
568
+ batch_size=batch_size,
569
+ )
570
+
571
+ if num_logits_to_keep is not None:
572
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
573
+
574
+ model_inputs.update(
575
+ {
576
+ "experts_ablate": experts_ablate,
577
+ "position_ids": position_ids,
578
+ "cache_position": cache_position,
579
+ "past_key_values": past_key_values,
580
+ "use_cache": use_cache,
581
+ "attention_mask": attention_mask,
582
+ }
583
+ )
584
+ return model_inputs
585
+
586
+
587
+ AutoConfig.register("micro_llama", MiCRoLlamaConfig)
588
+ AutoModelForCausalLM.register(MiCRoLlamaConfig, MiCRoLlama)
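
To connect this model to the visualizer, a backend could average the per-layer router distributions returned in `routing_weights`. The sketch below rests on several assumptions not fixed by this commit: a hypothetical checkpoint id whose config carries a valid `config_path`, an importable `models.micro_llama` so the `Auto*` registration above runs, a GPU with flash-attention available (the config pins `_attn_implementation` to `flash_attention_2`), and an expert-index-to-name mapping.

```python
# Rough sketch: turn MiCRoLlama's per-layer router logits into overall expert
# percentages. The checkpoint id and the index->name mapping are assumptions.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

import models.micro_llama  # noqa: F401  (runs the AutoConfig/AutoModelForCausalLM registration above)

MODEL_ID = "path/or/hub-id-of-a-micro-llama-checkpoint"  # hypothetical
EXPERT_NAMES = ["Language", "Logic", "Social", "World"]  # assumed index order

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).eval()  # resolves to MiCRoLlama via the registration

inputs = tokenizer("Why do people tell white lies?", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model(**inputs, use_cache=False)

# out.routing_weights holds one [batch, seq, num_experts] tensor per layer
# (router logits when use-router is enabled), so apply softmax, then average
# over layers, batch, and tokens to get one percentage per expert.
probs = torch.stack([F.softmax(r.float(), dim=-1) for r in out.routing_weights])
percent = (100.0 * probs.mean(dim=(0, 1, 2))).tolist()
print(dict(zip(EXPERT_NAMES, percent)))
```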
models/micro_moe_llama.py ADDED
@@ -0,0 +1,725 @@
1
+ from typing import Optional, Tuple, Union, List, Callable
2
+ import logging
3
+ import yaml
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.distributed as dist
9
+
10
+ from transformers import LlamaConfig, AutoModelForCausalLM, AutoConfig
11
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
12
+ from transformers.models.llama.modeling_llama import (
13
+ LlamaRotaryEmbedding,
14
+ LlamaRMSNorm,
15
+ LlamaMLP,
16
+ LlamaAttention,
17
+ LlamaForCausalLM,
18
+ LlamaPreTrainedModel,
19
+ GenerationMixin,
20
+ apply_rotary_pos_emb,
21
+ eager_attention_forward,
22
+
23
+ )
24
+ from transformers.modeling_layers import GradientCheckpointingLayer
25
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
26
+ from transformers.cache_utils import Cache, StaticCache, DynamicCache
27
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
28
+ from transformers.processing_utils import Unpack
29
+ from transformers.utils import is_torchdynamo_compiling
30
+ from transformers.activations import ACT2FN
31
+ from models.modules import CausalLMOutputWithPast
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+ def keep_alive_zero(model):
36
+ z = 0.0
37
+ for p in model.parameters():
38
+ if p.requires_grad:
39
+ # one scalar per param to avoid heavy sums
40
+ z = z + (p.view(-1)[0] * 0.0)
41
+ return z
42
+
43
+ class MiCRoLlamaMoEConfig(LlamaConfig):
44
+ model_type = "micro_llama_moe"
45
+ def __init__(self, *args, **kwargs):
46
+ super().__init__(*args, **kwargs)
47
+ self.num_experts = kwargs.get("num_experts", 4)
48
+ self.use_router = kwargs.get("use_router", True)
49
+ self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 2)
50
+ self.jitter_noise = kwargs.get("jitter_noise", 0.0)
51
+ self.loss_method = kwargs.get("loss_method", "all")
52
+ self.config_path = kwargs.get("config_path", None)
53
+
54
+ def _prepare_4d_causal_attention_mask_with_cache_position(
55
+ attention_mask: torch.Tensor,
56
+ sequence_length: int,
57
+ target_length: int,
58
+ dtype: torch.dtype,
59
+ device: torch.device,
60
+ min_dtype: float,
61
+ cache_position: torch.Tensor,
62
+ batch_size: int,
63
+ ):
64
+ """
65
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
66
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
67
+
68
+ Args:
69
+ attention_mask (`torch.Tensor`):
70
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
71
+ sequence_length (`int`):
72
+ The sequence length being processed.
73
+ target_length (`int`):
74
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
75
+ dtype (`torch.dtype`):
76
+ The dtype to use for the 4D attention mask.
77
+ device (`torch.device`):
78
+ The device to place the 4D attention mask on.
79
+ min_dtype (`float`):
80
+ The minimum value representable with the dtype `dtype`.
81
+ cache_position (`torch.Tensor`):
82
+ Indices depicting the position of the input sequence tokens in the sequence.
83
+ batch_size (`torch.Tensor`):
84
+ Batch size.
85
+ """
86
+ if attention_mask is not None and attention_mask.dim() == 4:
87
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
88
+ causal_mask = attention_mask
89
+ else:
90
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
91
+ if sequence_length != 1:
92
+ causal_mask = torch.triu(causal_mask, diagonal=1)
93
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
94
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
95
+ if attention_mask is not None:
96
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
97
+ mask_length = attention_mask.shape[-1]
98
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
99
+ padding_mask = padding_mask == 0
100
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
101
+ padding_mask, min_dtype
102
+ )
103
+
104
+ return causal_mask
105
+
106
+ class DummyModule(nn.Module):
107
+ def __init__(self):
108
+ super().__init__()
109
+ def forward(self, x):
110
+ return x
111
+
112
+ class LlamaSparseMiCRoMoEBlock(nn.Module):
113
+ """
114
+ This implementation is
115
+ strictly equivalent to standard MoE with full capacity (no
116
+ dropped tokens). It's faster since it formulates MoE operations
117
+ in terms of block-sparse operations to accommodate imbalanced
118
+ assignments of tokens to experts, whereas standard MoE either
119
+ (1) drop tokens at the cost of reduced performance or (2) set
120
+ capacity factor to number of experts and thus waste computation
121
+ and memory on padding.
122
+ """
123
+
124
+ def __init__(self, config):
125
+ super().__init__()
126
+ self.hidden_dim = config.hidden_size
127
+ self.ffn_dim = config.intermediate_size
128
+ self.num_experts = config.num_experts
129
+ self.top_k = config.num_experts_per_tok
130
+ self.use_router = config.use_router
131
+ self.ablate = config.ablate
132
+
133
+ # gating
134
+ self.gate = nn.Sequential(
135
+ nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
136
+ nn.Linear(self.hidden_dim, self.num_experts, bias=False)
137
+ )
138
+
139
+ self.experts = nn.ModuleList([LlamaMLP(config) for _ in range(self.num_experts)])
140
+
141
+ self.dummy = DummyModule()
142
+
143
+ # Jitter parameters
144
+ self.jitter_noise = config.jitter_noise
145
+
146
+ def forward(self, hidden_states: torch.Tensor, routing_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
147
+ """ """
148
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
149
+ if self.training and self.jitter_noise > 0:
150
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
151
+ hidden_states = hidden_states.view(-1, hidden_dim)
152
+
153
+ if self.use_router:
154
+ router_logits = self.gate(hidden_states)
155
+ if "logic" in self.ablate:
156
+ router_logits[..., 0] = -torch.inf
157
+ if "social" in self.ablate:
158
+ router_logits[..., 1] = -torch.inf
159
+ if "world" in self.ablate:
160
+ router_logits[..., 2] = -torch.inf
161
+ if "language" in self.ablate:
162
+ router_logits[..., 3] = -torch.inf
163
+ routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
164
+ else:
165
+ routing_weights = routing_weights.reshape(-1, 4).float()
166
+ router_logits = routing_weights
167
+ # router_logits: (batch * sequence_length, n_experts)
168
+
169
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
170
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
171
+ # we cast back to the input dtype
172
+ routing_weights = routing_weights.to(hidden_states.dtype)
173
+
174
+ final_hidden_states = torch.zeros(
175
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
176
+ )
177
+
178
+ H_up = self.experts[0].up_proj.out_features
179
+ Y_up = hidden_states.new_zeros((batch_size, sequence_length, self.num_experts, H_up))
180
+
181
+
182
+ # One hot encode the selected experts to create an expert mask
183
+ # this will be used to easily index which expert is going to be solicited
184
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
185
+
186
+ expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
187
+ for expert_idx in expert_hitted:
188
+ expert_layer = self.experts[expert_idx]
189
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
190
+ # Index the correct hidden states and compute the expert hidden state for
191
+ # the current expert. We need to make sure to multiply the output hidden
192
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
193
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
194
+
195
+ # --- Hook to capture up-proj output BEFORE nonlinearity ---
196
+ captured_up = []
197
+ def _up_hook(m, inp, out):
198
+ # out shape: [N_e, H_up]
199
+ captured_up.append(out.detach())
200
+
201
+ h = expert_layer.up_proj.register_forward_hook(_up_hook)
202
+
203
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
204
+ h.remove()
205
+
206
+ # Scatter captured up-proj per-token into Y_up[b, t, expert, :]
207
+ if captured_up:
208
+ up = captured_up[-1] # [N_e, H_up]
209
+ b_idx = top_x // sequence_length
210
+ t_idx = top_x % sequence_length
211
+ # Y_up[b,t,e,:] = up[n,:]
212
+ Y_up[b_idx, t_idx, expert_idx, :] = up
213
+
214
+ # However `index_add_` only support torch tensors for indexing so we'll use
215
+ # the `top_x` tensor here.
216
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
217
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
218
+
219
+ self.dummy(Y_up)
220
+ return final_hidden_states, router_logits
221
+
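
As a small numeric illustration of the dispatch used in this block (softmax over the gate, keep the top-k experts, renormalize their weights, and build a one-hot mask of which experts see the token); the values are arbitrary:

```python
# Tiny illustration of the routing math used above: softmax -> top-k ->
# renormalize -> one-hot dispatch mask. The logits are arbitrary.
import torch
import torch.nn.functional as F

router_logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])     # one token, 4 experts
weights = F.softmax(router_logits, dim=-1)                 # full distribution over experts
topk_w, topk_idx = torch.topk(weights, k=2, dim=-1)        # keep the 2 strongest experts
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)         # renormalize so kept weights sum to 1
expert_mask = F.one_hot(topk_idx, num_classes=4)           # which experts process this token
print(topk_idx.tolist(), topk_w.tolist(), expert_mask.tolist())
```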
222
+ class LlamaMiCRoMoEDecoderLayer(GradientCheckpointingLayer):
223
+ def __init__(self, config: MiCRoLlamaMoEConfig, layer_idx: int):
224
+ super().__init__()
225
+ self.hidden_size = config.hidden_size
226
+
227
+ self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
228
+
229
+ self.block_sparse_moe = LlamaSparseMiCRoMoEBlock(config)
230
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
231
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
232
+
233
+ def forward(
234
+ self,
235
+ hidden_states: torch.Tensor,
236
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
237
+ routing_weights: Optional[torch.Tensor] = None,
238
+ attention_mask: Optional[torch.Tensor] = None,
239
+ position_ids: Optional[torch.LongTensor] = None,
240
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
241
+ cache_position: Optional[torch.LongTensor] = None,
242
+ **kwargs: Unpack[FlashAttentionKwargs],
243
+ ) -> torch.FloatTensor:
244
+ residual = hidden_states
245
+
246
+ hidden_states = self.input_layernorm(hidden_states)
247
+
248
+ # Self Attention
249
+ hidden_states, _ = self.self_attn(
250
+ hidden_states=hidden_states,
251
+ position_embeddings=position_embeddings,
252
+ attention_mask=attention_mask,
253
+ position_ids=position_ids,
254
+ past_key_value=past_key_value,
255
+ cache_position=cache_position,
256
+ **kwargs,
257
+ )
258
+ hidden_states = residual + hidden_states
259
+
260
+ # Fully Connected
261
+ residual = hidden_states
262
+ hidden_states = self.post_attention_layernorm(hidden_states)
263
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states, routing_weights)
264
+ hidden_states = residual + hidden_states
265
+
266
+ return hidden_states, router_logits
267
+
268
+
269
+ class MiCRoLlamaMoE(LlamaPreTrainedModel, GenerationMixin):
270
+ config_class = MiCRoLlamaMoEConfig
271
+ def __init__(self, config):
272
+ with open(config.config_path, 'r', encoding="utf-8") as file:
273
+ run_config = yaml.load(file.read(), Loader=yaml.FullLoader)
274
+
275
+ self.config: MiCRoLlamaMoEConfig = config
276
+ self.config.torch_dtype = torch.bfloat16
277
+ self.config.use_bfloat16 = True
278
+ self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
279
+ self.config.use_cache = True
280
+ self.config.backbone_num_layers = self.config.num_hidden_layers
281
+ self.config.num_hidden_layers = self.config.num_hidden_layers
282
+ self.config.loss_type = "ForCausalLMLoss"
283
+
284
+ super(MiCRoLlamaMoE, self).__init__(self.config)
285
+ self.build_model(run_config)
286
+
287
+ def build_model(self, run_config):
288
+
289
+ self.config.num_experts = run_config["num-experts"]
290
+ self.config.use_router = run_config["use-router"]
291
+ self.config.num_experts_per_tok = run_config["top-k-experts"]
292
+ print(f">> Top-K Experts Per Token: {self.config.num_experts_per_tok}")
293
+ self.config.jitter_noise = run_config["jitter-noise"]
294
+ self.config.loss_method = run_config.get("loss", "all")
295
+ self.router_aux_loss_coef = run_config["router-aux-loss-coef"]
296
+ self.use_load_balancing = run_config.get("use-load-balancing", False)
297
+
298
+ self.config.gradient_checkpointing = run_config.get("gradient-checkpointing", False)
299
+ self.gradient_checkpointing = self.config.gradient_checkpointing
300
+
301
+ print(f">> Gradient Checkpointing: {self.config.gradient_checkpointing}")
302
+
303
+ self.run_config = run_config
304
+ self.padding_idx = 2 if "smollm2" in run_config["model"] else 128004
305
+
306
+ # LlamaMoE model
307
+ self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.padding_idx)
308
+ self.layers = nn.ModuleList([LlamaMiCRoMoEDecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.backbone_num_layers)])
309
+ self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
310
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
311
+ self.final_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
312
+
313
+ if "model" not in run_config["trainable"]:
314
+ print(">> Freezing Model Except Experts + Routing Gates")
315
+ for param in self.parameters():
316
+ param.requires_grad = False
317
+
318
+ for layer in self.layers:
319
+ layer: LlamaMiCRoMoEDecoderLayer
320
+ for param in layer.block_sparse_moe.parameters():
321
+ param.requires_grad = True
322
+
323
+ if "experts" not in run_config["trainable"]:
324
+ print(">> Freezing Experts")
325
+ for layer in self.layers:
326
+ layer: LlamaMiCRoMoEDecoderLayer
327
+ for param in layer.block_sparse_moe.experts.parameters():
328
+ param.requires_grad = False
329
+
330
+ if "experts-router" not in run_config["trainable"]:
331
+ print(">> Freezing Routing Gates")
332
+ for layer in self.layers:
333
+ layer: LlamaMiCRoMoEDecoderLayer
334
+ for param in layer.block_sparse_moe.gate.parameters():
335
+ param.requires_grad = False
336
+
337
+
338
+ def forward(self,
339
+ input_ids: torch.LongTensor = None,
340
+ attention_mask: Optional[torch.Tensor] = None,
341
+ position_ids: Optional[torch.LongTensor] = None,
342
+ experts_ablate: Optional[List[str]] = None,
343
+ routing_weights: Optional[torch.LongTensor] = None,
344
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
345
+ inputs_embeds: Optional[torch.FloatTensor] = None,
346
+ labels: Optional[torch.LongTensor] = None,
347
+ use_cache: Optional[bool] = None,
348
+ output_attentions: Optional[bool] = None,
349
+ output_hidden_states: Optional[bool] = None,
350
+ return_dict: Optional[bool] = None,
351
+ cache_position: Optional[torch.LongTensor] = None,
352
+ logits_to_keep: Union[int, torch.Tensor] = 0,
353
+ **kwargs: Unpack[FlashAttentionKwargs],
354
+ ):
355
+
356
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
357
+ output_hidden_states = (
358
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
359
+ )
360
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
361
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
362
+
363
+ if (input_ids is None) ^ (inputs_embeds is not None):
364
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
365
+
366
+ if self.gradient_checkpointing and self.training and use_cache:
367
+ logger.warning_once(
368
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
369
+ )
370
+ use_cache = False
371
+
372
+ if inputs_embeds is None:
373
+ inputs_embeds = self.embed_tokens(input_ids)
374
+
375
+ if use_cache and past_key_values is None:
376
+ past_key_values = DynamicCache()
377
+
378
+ if cache_position is None:
379
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
380
+ cache_position = torch.arange(
381
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
382
+ )
383
+
384
+ if position_ids is None:
385
+ position_ids = cache_position.unsqueeze(0)
386
+
387
+ causal_mask = self._update_causal_mask(
388
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
389
+ )
390
+
391
+ hidden_states = inputs_embeds
392
+
393
+ # create position embeddings to be shared across the decoder layers
394
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
395
+
396
+ # decoder layers
397
+ all_hidden_states = () if output_hidden_states else None
398
+ all_self_attns = () if output_attentions else None
399
+
400
+ all_routing_weights = ()
401
+
402
+ for decoder_layer in self.layers:
403
+ if output_hidden_states:
404
+ all_hidden_states += (hidden_states,)
405
+
406
+ if self.gradient_checkpointing and self.training:
407
+ layer_outputs, router_logits = self._gradient_checkpointing_func(
408
+ decoder_layer.__call__,
409
+ hidden_states,
410
+ position_embeddings,
411
+ routing_weights,
412
+ causal_mask,
413
+ position_ids,
414
+ past_key_values,
415
+ cache_position,
416
+ )
417
+ else:
418
+ layer_outputs, router_logits = decoder_layer(
419
+ hidden_states,
420
+ position_embeddings=position_embeddings,
421
+ routing_weights=routing_weights,
422
+ attention_mask=causal_mask,
423
+ position_ids=position_ids,
424
+ past_key_value=past_key_values,
425
+ output_attentions=output_attentions,
426
+ use_cache=use_cache,
427
+ cache_position=cache_position,
428
+ **kwargs,
429
+ )
430
+
431
+ hidden_states = layer_outputs
432
+
433
+ if output_attentions:
434
+ all_self_attns += (layer_outputs[1],)
435
+
436
+ all_routing_weights += (router_logits,)
437
+
438
+ hidden_states = self.final_norm(hidden_states)
439
+
440
+ # add hidden states from the last decoder layer
441
+ if output_hidden_states:
442
+ all_hidden_states += (hidden_states,)
443
+
444
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
445
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
446
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
447
+
448
+ loss = None
449
+ if labels is not None:
450
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
451
+
452
+ loss += keep_alive_zero(self)
453
+
454
+ aux_loss = None
455
+ if self.use_load_balancing:
456
+ aux_loss = load_balancing_loss_func(
457
+ all_routing_weights,
458
+ self.config.num_experts,
459
+ self.config.num_experts_per_tok,
460
+ attention_mask,
461
+ )
462
+ if loss is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # keep the aux loss on the same device as the main loss
463
+
464
+ if not return_dict:
465
+ output = (logits, past_key_values, all_hidden_states, all_self_attns, all_routing_weights) if use_cache else (logits, all_hidden_states, all_self_attns, all_routing_weights)
466
+ return (loss,) + output if loss is not None else output
467
+
468
+ return CausalLMOutputWithPast(
469
+ loss=loss,
470
+ logits=logits,
471
+ past_key_values=past_key_values if use_cache else None,
472
+ hidden_states=all_hidden_states,
473
+ attentions=all_self_attns,
474
+ routing_weights=all_routing_weights,
475
+ )
476
+
477
+ def _update_causal_mask(
478
+ self,
479
+ attention_mask: torch.Tensor,
480
+ input_tensor: torch.Tensor,
481
+ cache_position: torch.Tensor,
482
+ past_key_values: Cache,
483
+ output_attentions: bool,
484
+ ):
485
+ if self.config._attn_implementation == "flash_attention_2":
486
+ if attention_mask is not None and 0.0 in attention_mask:
487
+ return attention_mask
488
+ return None
489
+
490
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
491
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
492
+ # to infer the attention mask.
493
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
494
+ using_static_cache = isinstance(past_key_values, StaticCache)
495
+
496
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
497
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
498
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
499
+ attention_mask,
500
+ inputs_embeds=input_tensor,
501
+ past_key_values_length=past_seen_tokens,
502
+ is_training=self.training,
503
+ ):
504
+ return None
505
+
506
+ dtype, device = input_tensor.dtype, input_tensor.device
507
+ min_dtype = torch.finfo(dtype).min
508
+ sequence_length = input_tensor.shape[1]
509
+ if using_static_cache:
510
+ target_length = past_key_values.get_max_length()
511
+ else:
512
+ target_length = (
513
+ attention_mask.shape[-1]
514
+ if isinstance(attention_mask, torch.Tensor)
515
+ else past_seen_tokens + sequence_length + 1
516
+ )
517
+
518
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
519
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
520
+ attention_mask,
521
+ sequence_length=sequence_length,
522
+ target_length=target_length,
523
+ dtype=dtype,
524
+ device=device,
525
+ min_dtype=min_dtype,
526
+ cache_position=cache_position,
527
+ batch_size=input_tensor.shape[0],
528
+ )
529
+
530
+ if (
531
+ self.config._attn_implementation == "sdpa"
532
+ and attention_mask is not None
533
+ and attention_mask.device.type == "cuda"
534
+ and not output_attentions
535
+ ):
536
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
537
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
538
+ # Details: https://github.com/pytorch/pytorch/issues/110213
539
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
540
+
541
+ return causal_mask
542
+
543
+ def load_pretrained(self, model_name):
544
+ base_model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
545
+ self.lm_head.load_state_dict(base_model.lm_head.state_dict())
546
+ self.embed_tokens.load_state_dict(base_model.get_input_embeddings().state_dict())
547
+ self.rotary_emb.load_state_dict(base_model.model.rotary_emb.state_dict())
548
+ self.final_norm.load_state_dict(base_model.model.norm.state_dict())
549
+
550
+ for layer_idx, layer in enumerate(self.layers):
551
+
552
+ attn_layer = base_model.model.layers[layer_idx].self_attn.state_dict()
553
+ layer.self_attn.load_state_dict(attn_layer)
554
+
555
+ input_layernorm = base_model.model.layers[layer_idx].input_layernorm.state_dict()
556
+ layer.input_layernorm.load_state_dict(input_layernorm)
557
+
558
+ post_attention_layernorm = base_model.model.layers[layer_idx].post_attention_layernorm.state_dict()
559
+ layer.post_attention_layernorm.load_state_dict(post_attention_layernorm)
560
+
561
+ mlp_model_layer = base_model.model.layers[layer_idx].mlp.state_dict()
562
+ for expert in layer.block_sparse_moe.experts:
563
+ expert.load_state_dict(mlp_model_layer)
564
+
565
+ def prepare_inputs_for_generation(
566
+ self,
567
+ input_ids,
568
+ past_key_values=None,
569
+ attention_mask=None,
570
+ inputs_embeds=None,
571
+ cache_position=None,
572
+ position_ids=None,
573
+ experts_ablate=None,
574
+ use_cache=True,
575
+ num_logits_to_keep=None,
576
+ **kwargs,
577
+ ):
578
+
579
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
580
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
581
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
582
+ if past_key_values is not None:
583
+ if inputs_embeds is not None: # Exception 1
584
+ input_ids = input_ids[:, -cache_position.shape[0] :]
585
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
586
+ input_ids = input_ids[:, cache_position]
587
+
588
+ if attention_mask is not None and position_ids is None:
589
+ # create position_ids on the fly for batch generation
590
+ position_ids = attention_mask.long().cumsum(-1) - 1
591
+ position_ids.masked_fill_(attention_mask == 0, 1)
592
+ if past_key_values:
593
+ position_ids = position_ids[:, -input_ids.shape[1] :]
594
+
595
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
596
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
597
+
598
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
599
+ if inputs_embeds is not None and cache_position[0] == 0:
600
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
601
+ else:
602
+ # The clone here is for the same reason as for `position_ids`.
603
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
604
+
605
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
606
+ if model_inputs["inputs_embeds"] is not None:
607
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
608
+ device = model_inputs["inputs_embeds"].device
609
+ else:
610
+ batch_size, sequence_length = model_inputs["input_ids"].shape
611
+ device = model_inputs["input_ids"].device
612
+
613
+ dtype = self.lm_head.weight.dtype
614
+ min_dtype = torch.finfo(dtype).min
615
+
616
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
617
+ attention_mask,
618
+ sequence_length=sequence_length,
619
+ target_length=past_key_values.get_max_length(),
620
+ dtype=dtype,
621
+ device=device,
622
+ min_dtype=min_dtype,
623
+ cache_position=cache_position,
624
+ batch_size=batch_size,
625
+ )
626
+
627
+ if num_logits_to_keep is not None:
628
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
629
+
630
+ model_inputs.update(
631
+ {
632
+ "experts_ablate": experts_ablate,
633
+ "position_ids": position_ids,
634
+ "cache_position": cache_position,
635
+ "past_key_values": past_key_values,
636
+ "use_cache": use_cache,
637
+ "attention_mask": attention_mask,
638
+ }
639
+ )
640
+ return model_inputs
641
+
642
+ def load_balancing_loss_func(
643
+ gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
644
+ num_experts: Optional[int] = None,
645
+ top_k=2,
646
+ attention_mask: Optional[torch.Tensor] = None,
647
+ ) -> Union[torch.Tensor, int]:
648
+ r"""
649
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
650
+
651
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
652
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
653
+ experts is too unbalanced.
654
+
655
+ Args:
656
+ gate_logits:
657
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
658
+ shape [batch_size X sequence_length, num_experts].
659
+ num_experts:
660
+ Number of experts
661
+ top_k:
662
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
663
+ parameter.
664
+ attention_mask (`torch.Tensor`, *optional*):
665
+ The attention_mask used in forward function
666
+ shape [batch_size X sequence_length] if not None.
667
+
668
+ Returns:
669
+ The auxiliary loss.
670
+ """
671
+ if gate_logits is None or not isinstance(gate_logits, tuple):
672
+ return 0
673
+
674
+ if isinstance(gate_logits, tuple):
675
+ compute_device = gate_logits[0].device
676
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
677
+
678
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
679
+
680
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
681
+
682
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
683
+
684
+ if attention_mask is None:
685
+ # Compute the percentage of tokens routed to each expert
686
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
687
+
688
+ # Compute the average probability of routing to these experts
689
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
690
+ else:
691
+ batch_size, sequence_length = attention_mask.shape
692
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
693
+
694
+ # Compute the mask that zeroes out all padding tokens, with the same shape as expert_mask
695
+ expert_attention_mask = (
696
+ attention_mask[None, :, :, None, None]
697
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
698
+ .reshape(-1, top_k, num_experts)
699
+ .to(compute_device)
700
+ )
701
+
702
+ # Compute the percentage of tokens routed to each expert
703
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
704
+ expert_attention_mask, dim=0
705
+ )
706
+
707
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
708
+ router_per_expert_attention_mask = (
709
+ attention_mask[None, :, :, None]
710
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
711
+ .reshape(-1, num_experts)
712
+ .to(compute_device)
713
+ )
714
+
715
+ # Compute the average probability of routing to these experts
716
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
717
+ router_per_expert_attention_mask, dim=0
718
+ )
719
+
720
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
721
+ return overall_loss * num_experts
722
+
723
+
724
+ AutoConfig.register("micro_llama_moe", MiCRoLlamaMoEConfig)
725
+ AutoModelForCausalLM.register(MiCRoLlamaMoEConfig, MiCRoLlamaMoE)
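For intuition, here is a minimal, self-contained sketch of what `load_balancing_loss_func` above computes in its unmasked branch (toy shapes and random logits chosen for illustration only; this is not part of the Space's code):

```python
import torch
import torch.nn.functional as F

num_experts, top_k = 4, 2
gate_logits = (
    torch.randn(6, num_experts),  # layer 0: [batch*seq, num_experts]
    torch.randn(6, num_experts),  # layer 1
)

probs = F.softmax(torch.cat(gate_logits, dim=0), dim=-1)
_, selected = torch.topk(probs, top_k, dim=-1)
expert_mask = F.one_hot(selected, num_experts)            # [tokens, top_k, num_experts]

tokens_per_expert = expert_mask.float().mean(dim=0)       # [top_k, num_experts]
router_prob_per_expert = probs.mean(dim=0)                # [num_experts]
aux_loss = (tokens_per_expert * router_prob_per_expert.unsqueeze(0)).sum() * num_experts
print(float(aux_loss))  # equals top_k when routing is perfectly balanced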
models/micro_olmo.py ADDED
@@ -0,0 +1,528 @@
1
+ from typing import Callable, Optional, Tuple, Union
2
+
3
+ import yaml
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ from transformers import AutoModelForCausalLM
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
13
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
14
+ # from transformers.modeling_layers import GradientCheckpointingLayer
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast
16
+ from transformers.processing_utils import Unpack
17
+ from transformers.utils import (
18
+ add_start_docstrings,
19
+ add_start_docstrings_to_model_forward,
20
+ is_torch_flex_attn_available,
21
+ logging,
22
+ replace_return_docstrings,
23
+ )
24
+ from transformers.models.olmo2.configuration_olmo2 import Olmo2Config
25
+ from transformers.models.olmo2.modeling_olmo2 import (
26
+ Olmo2RMSNorm,
27
+ Olmo2Attention,
28
+ Olmo2MLP,
29
+ Olmo2DecoderLayer,
30
+ Olmo2RotaryEmbedding,
31
+ Olmo2PreTrainedModel,
32
+ rotate_half,
33
+ apply_rotary_pos_emb,
34
+ repeat_kv,
35
+ eager_attention_forward,
36
+ )
37
+
38
+
39
+ if is_torch_flex_attn_available():
40
+ from torch.nn.attention.flex_attention import BlockMask
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask # needed by the flex_attention branch of _update_causal_mask
41
+
42
+ from models.modules import CausalLMOutputWithPast
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ class MiCRoOLMo2DecoderLayer(nn.Module):
47
+ def __init__(self, config: Olmo2Config, layer_idx: int):
48
+ super().__init__()
49
+ self.hidden_size = config.hidden_size
50
+
51
+ self.num_experts = config.num_experts
52
+ self.top_k = config.num_experts_per_tok
53
+ self.use_router = config.use_router
54
+ self.ablate = config.ablate or []
55
+ self.num_layers = config.backbone_num_layers
56
+ self.layer_idx = layer_idx
57
+ self.jitter_noise = config.jitter_noise
58
+ self.config = config
59
+ self.head_dim = config.hidden_size // config.num_attention_heads
60
+
61
+ if isinstance(self.ablate, str):
62
+ self.ablate = [self.ablate]
63
+
64
+ # gating head
65
+ self.gate = nn.Sequential(
66
+ nn.Linear(self.hidden_size, self.hidden_size, bias=False),
67
+ nn.Linear(self.hidden_size, self.num_experts, bias=False),
68
+ )
69
+
70
+ self.experts = nn.ModuleList([
71
+ Olmo2DecoderLayer(config, layer_idx * self.num_experts + expert_idx)
72
+ for expert_idx in range(self.num_experts)
73
+ ])
74
+
75
+ def forward(
76
+ self,
77
+ hidden_states: torch.Tensor,
78
+ routing_weights: Optional[torch.Tensor] = None,
79
+ attention_mask: Optional[torch.Tensor] = None,
80
+ position_ids: Optional[torch.LongTensor] = None,
81
+ past_key_value: Optional[Cache] = None,
82
+ output_attentions: Optional[bool] = False,
83
+ use_cache: Optional[bool] = False,
84
+ cache_position: Optional[torch.LongTensor] = None,
85
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
86
+ **kwargs,
87
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
88
+
89
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
90
+
91
+ if self.training and self.jitter_noise > 0:
92
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
93
+
94
+ if self.use_router:
95
+ router_logits = self.gate(hidden_states)
96
+ if "logic" in self.ablate:
97
+ router_logits[..., 0] = -torch.inf
98
+ if "social" in self.ablate:
99
+ router_logits[..., 1] = -torch.inf
100
+ if "world" in self.ablate:
101
+ router_logits[..., 2] = -torch.inf
102
+ if "language" in self.ablate:
103
+ router_logits[..., 3] = -torch.inf
104
+ routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
105
+ else:
106
+ if len(routing_weights.shape) == 2:
107
+ routing_weights = routing_weights.unsqueeze(1).tile((1,sequence_length,1)).float()
108
+ else:
109
+ routing_weights = routing_weights.float()
110
+ router_logits = routing_weights
111
+
112
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
113
+ routing_weights /= (routing_weights.sum(dim=-1, keepdim=True) + 1e-9)
114
+
115
+ # we cast back to the input dtype
116
+ routing_weights = routing_weights.to(hidden_states.dtype)
117
+
118
+ # We'll accumulate outputs here
119
+ final_hidden_states = torch.zeros_like(hidden_states)
120
+
121
+ # Flatten final_hidden_states to [batch_size * seq_len, hidden_dim]
122
+ # so we can do a 2D "index_add_" at the end of each loop.
123
+ final_hidden_states_2d = final_hidden_states.view(-1, hidden_dim)
124
+
125
+ # One hot encode the selected experts to create an expert mask
126
+ # this will be used to easily index which expert is going to be solicited
127
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
128
+ #^ [batch_size, seq_len, top_k, num_experts]
129
+
130
+ # Loop over all available experts in the model and perform the computation on each expert
131
+ for expert_idx in range(self.num_experts):
132
+ expert_layer: Olmo2DecoderLayer = self.experts[expert_idx]
133
+ batch_indices, seq_indices, top_k_indices = torch.where(expert_mask[..., expert_idx])
134
+
135
+ if not self.training and sequence_length == 1 and batch_indices.numel() == 0:
136
+ if past_key_value is not None:
137
+
138
+ input_shape = hidden_states.shape[:-1]
139
+ hidden_shape = (*input_shape, -1, self.head_dim)
140
+
141
+ key_states = expert_layer.self_attn.k_proj(hidden_states)
142
+ key_states = expert_layer.self_attn.k_norm(key_states).view(hidden_shape).transpose(1, 2)
143
+ value_states = expert_layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
144
+
145
+
146
+ cos, sin = position_embeddings
147
+ _, key_states = apply_rotary_pos_emb(key_states, key_states, cos, sin)
148
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
149
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
150
+ past_key_value.update(key_states, value_states, self.layer_idx * self.num_experts + expert_idx, cache_kwargs)
151
+
152
+ continue
153
+
154
+ current_hidden_states = expert_layer(
155
+ hidden_states=hidden_states,
156
+ attention_mask=attention_mask,
157
+ position_ids=position_ids,
158
+ past_key_value=past_key_value,
159
+ output_attentions=output_attentions,
160
+ use_cache=use_cache,
161
+ cache_position=cache_position,
162
+ position_embeddings=position_embeddings,
163
+ **kwargs,
164
+ )[0]
165
+
166
+ flat_idx = batch_indices * sequence_length + seq_indices
167
+ expert_weights = routing_weights[batch_indices, seq_indices, top_k_indices].unsqueeze(-1)
168
+ current_hidden_states = current_hidden_states[batch_indices, seq_indices] * expert_weights
169
+
170
+ final_hidden_states_2d.index_add_(0, flat_idx, current_hidden_states.to(hidden_states.dtype))
171
+
172
+ final_hidden_states = final_hidden_states_2d.view(batch_size, sequence_length, hidden_dim)
173
+ return final_hidden_states, router_logits
174
+
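The `forward` above routes whole decoder layers rather than MLPs: each token is processed only by its top-k expert layers, and the per-expert outputs are mixed with the normalized gate weights via `index_add_` on a flattened buffer. A compact sketch of that dispatch pattern with toy tensors (shapes are made up, and a `tanh` stands in for the per-expert `Olmo2DecoderLayer` call):

```python
import torch
import torch.nn.functional as F

batch, seq, hidden, num_experts, top_k = 2, 5, 8, 4, 2
x = torch.randn(batch, seq, hidden)
router_logits = torch.randn(batch, seq, num_experts)

weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
weights, selected = torch.topk(weights, top_k, dim=-1)
weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-9)

out_2d = torch.zeros(batch * seq, hidden)
expert_mask = F.one_hot(selected, num_experts)            # [batch, seq, top_k, num_experts]
for e in range(num_experts):
    b, s, k = torch.where(expert_mask[..., e])
    if b.numel() == 0:
        continue
    expert_out = torch.tanh(x[b, s])                      # stand-in for the expert decoder layer
    out_2d.index_add_(0, b * seq + s, expert_out * weights[b, s, k].unsqueeze(-1).to(x.dtype))
out = out_2d.view(batch, seq, hidden)
print(out.shape)  # torch.Size([2, 5, 8])
```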
175
+ class MiCRoOLMo(Olmo2PreTrainedModel, GenerationMixin):
176
+ """
177
+ Causal LM whose decoder consists of *config.backbone_num_layers* [`MiCRoOLMo2DecoderLayer`] blocks, each wrapping one [`Olmo2DecoderLayer`] per expert.
178
+
179
+ Args:
180
+ config: Olmo2Config
181
+ """
182
+
183
+ _tied_weights_keys = ["lm_head.weight"]
184
+ _tp_plan = {"lm_head": "colwise_rep"}
185
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
186
+
187
+ def __init__(self, config: Olmo2Config):
188
+ with open(config.config_path, 'r', encoding="utf-8") as file:
189
+ run_config = yaml.load(file.read(), Loader=yaml.FullLoader)
190
+
191
+ self.config: Olmo2Config = config
192
+ self.config.torch_dtype = torch.bfloat16
193
+ self.config.use_bfloat16 = True
194
+ self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
195
+ self.config.use_cache = True
196
+ self.config.backbone_num_layers = self.config.num_hidden_layers
197
+ self.config.num_hidden_layers = self.config.num_hidden_layers * run_config["num-experts"]
198
+ self.config.loss_type = "ForCausalLMLoss"
199
+
200
+ self.padding_idx = config.pad_token_id
201
+ self.vocab_size = config.vocab_size
202
+
203
+ self.gradient_checkpointing = False
204
+ super().__init__(config)
205
+ self.padding_idx = config.pad_token_id
206
+ self.vocab_size = config.vocab_size
207
+
208
+ self.build_model(run_config)
209
+
210
+ # Initialize weights and apply final processing
211
+ self.post_init()
212
+
213
+ def get_input_embeddings(self):
214
+ return self.embed_tokens
215
+
216
+ def set_input_embeddings(self, value):
217
+ self.embed_tokens = value
218
+
219
+ def get_output_embeddings(self):
220
+ return self.lm_head
221
+
222
+ def set_output_embeddings(self, value):
223
+ self.lm_head = value
224
+
225
+ def build_model(self, run_config):
226
+ self.gradient_checkpointing = False
227
+ self.config.num_experts = run_config["num-experts"]
228
+ self.config.use_router = run_config["use-router"]
229
+ self.config.num_experts_per_tok = run_config["top-k-experts"]
230
+ self.config.jitter_noise = run_config["jitter-noise"]
231
+ self.config.loss_method = run_config.get("loss", "all")
232
+
233
+ self.run_config = run_config
234
+ # OLMo-2 backbone components (embeddings, MiCRo decoder layers, LM head)
235
+ self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.padding_idx)
236
+ self.layers = nn.ModuleList([MiCRoOLMo2DecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.backbone_num_layers)])
237
+ self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
238
+ self.rotary_emb = Olmo2RotaryEmbedding(config=self.config)
239
+ self.norm = Olmo2RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
240
+
241
+ # Freeze Model
242
+ for param in self.parameters():
243
+ param.requires_grad = False
244
+
245
+ # Unfreeze Modules
246
+ if "reasoners" in run_config["trainable"]:
247
+ print(">> Unfreezing Reasoning Modules")
248
+ for layer in self.layers:
249
+ layer: MiCRoOLMo2DecoderLayer
250
+ for param in layer.experts.parameters():
251
+ param.requires_grad = True
252
+
253
+ if "model" in run_config["trainable"]:
254
+ print(">> Unfreezing Model")
255
+ for param in self.layers.parameters():
256
+ param.requires_grad = True
257
+
258
+ for param in self.lm_head.parameters():
259
+ param.requires_grad = True
260
+
261
+ for param in self.rotary_emb.parameters():
262
+ param.requires_grad = True
263
+
264
+ for param in self.norm.parameters():
265
+ param.requires_grad = True
266
+
267
+ for param in self.embed_tokens.parameters():
268
+ param.requires_grad = True
269
+
270
+ for layer in self.layers:
271
+ for param in layer.gate.parameters():
272
+ param.requires_grad = False
273
+
274
+
275
+ if "experts-router" in run_config["trainable"]:
276
+ print(">> Unfreezing Experts Router")
277
+ for layer in self.layers:
278
+ for param in layer.gate.parameters():
279
+ param.requires_grad = True
280
+
281
+ def forward(
282
+ self,
283
+ input_ids: Optional[torch.LongTensor] = None,
284
+ attention_mask: Optional[torch.Tensor] = None,
285
+ position_ids: Optional[torch.LongTensor] = None,
286
+ routing_weights: Optional[torch.LongTensor] = None,
287
+ past_key_values: Optional[Cache] = None,
288
+ inputs_embeds: Optional[torch.FloatTensor] = None,
289
+ labels: Optional[torch.LongTensor] = None,
290
+ use_cache: Optional[bool] = None,
291
+ output_attentions: Optional[bool] = None,
292
+ output_hidden_states: Optional[bool] = None,
293
+ return_dict: Optional[bool] = None,
294
+ cache_position: Optional[torch.LongTensor] = None,
295
+ logits_to_keep: Union[int, torch.Tensor] = 0,
296
+ **kwargs: Unpack[FlashAttentionKwargs],
297
+ ) -> CausalLMOutputWithPast:
298
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
299
+ output_hidden_states = (
300
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
301
+ )
302
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
303
+
304
+ if (input_ids is None) ^ (inputs_embeds is not None):
305
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
306
+
307
+ if self.gradient_checkpointing and self.training and use_cache:
308
+ logger.warning_once(
309
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
310
+ )
311
+ use_cache = False
312
+
313
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
314
+ if not isinstance(past_key_values, (type(None), Cache)):
315
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
316
+
317
+ if inputs_embeds is None:
318
+ inputs_embeds = self.embed_tokens(input_ids)
319
+
320
+ if use_cache and past_key_values is None:
321
+ past_key_values = DynamicCache()
322
+
323
+ if cache_position is None:
324
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
325
+ cache_position = torch.arange(
326
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
327
+ )
328
+
329
+ if position_ids is None:
330
+ position_ids = cache_position.unsqueeze(0)
331
+
332
+ causal_mask = self._update_causal_mask(
333
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
334
+ )
335
+
336
+ hidden_states = inputs_embeds
337
+
338
+ # create position embeddings to be shared across the decoder layers
339
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
340
+
341
+ # decoder layers
342
+ all_hidden_states = () if output_hidden_states else None
343
+ all_self_attns = () if output_attentions else None
344
+ all_routing_weights = ()
345
+
346
+ for decoder_layer in self.layers:
347
+ if output_hidden_states:
348
+ all_hidden_states += (hidden_states,)
349
+
350
+ layer_outputs, router_logits = decoder_layer(
351
+ hidden_states,
352
+ routing_weights=routing_weights,
353
+ attention_mask=causal_mask,
354
+ position_ids=position_ids,
355
+ past_key_value=past_key_values,
356
+ output_attentions=output_attentions,
357
+ use_cache=use_cache,
358
+ cache_position=cache_position,
359
+ position_embeddings=position_embeddings,
360
+ **kwargs,
361
+ # **flash_attn_kwargs,
362
+ )
363
+
364
+ hidden_states = layer_outputs
365
+
366
+ # if output_attentions:
367
+ # all_self_attns += (layer_outputs[1],)
368
+
369
+ all_routing_weights += (router_logits,)
370
+
371
+ hidden_states = self.norm(hidden_states)
372
+
373
+ # add hidden states from the last decoder layer
374
+ if output_hidden_states:
375
+ all_hidden_states += (hidden_states,)
376
+
377
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
378
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
379
+
380
+ loss = None
381
+ if labels is not None:
382
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
383
+
384
+ return CausalLMOutputWithPast(
385
+ loss=loss,
386
+ logits=logits,
387
+ past_key_values=past_key_values if use_cache else None,
388
+ hidden_states=all_hidden_states,
389
+ attentions=all_self_attns,
390
+ routing_weights=all_routing_weights,
391
+ )
392
+
393
+ def load_pretrained(self, model_name):
394
+ base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
395
+ self.lm_head.load_state_dict(base_model.lm_head.state_dict())
396
+ self.embed_tokens.load_state_dict(base_model.get_input_embeddings().state_dict())
397
+ self.rotary_emb.load_state_dict(base_model.model.rotary_emb.state_dict())
398
+ self.norm.load_state_dict(base_model.model.norm.state_dict())
399
+ for layer_idx, layer in enumerate(self.layers):
400
+ base_model_layer = base_model.model.layers[layer_idx].state_dict()
401
+ for expert in layer.experts:
402
+ expert.load_state_dict(base_model_layer)
403
+
404
+ def _update_causal_mask(
405
+ self,
406
+ attention_mask: Union[torch.Tensor, "BlockMask"],
407
+ input_tensor: torch.Tensor,
408
+ cache_position: torch.Tensor,
409
+ past_key_values: Cache,
410
+ output_attentions: bool = False,
411
+ ):
412
+ if self.config._attn_implementation == "flash_attention_2":
413
+ if attention_mask is not None and (attention_mask == 0.0).any():
414
+ return attention_mask
415
+ return None
416
+ if self.config._attn_implementation == "flex_attention":
417
+ if isinstance(attention_mask, torch.Tensor):
418
+ attention_mask = make_flex_block_causal_mask(attention_mask)
419
+ return attention_mask
420
+
421
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
422
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
423
+ # to infer the attention mask.
424
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
425
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
426
+
427
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
428
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
429
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
430
+ attention_mask,
431
+ inputs_embeds=input_tensor,
432
+ past_key_values_length=past_seen_tokens,
433
+ is_training=self.training,
434
+ ):
435
+ return None
436
+
437
+ dtype = input_tensor.dtype
438
+ sequence_length = input_tensor.shape[1]
439
+ if using_compilable_cache:
440
+ target_length = past_key_values.get_max_cache_shape()
441
+ else:
442
+ target_length = (
443
+ attention_mask.shape[-1]
444
+ if isinstance(attention_mask, torch.Tensor)
445
+ else past_seen_tokens + sequence_length + 1
446
+ )
447
+
448
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
449
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
450
+ attention_mask,
451
+ sequence_length=sequence_length,
452
+ target_length=target_length,
453
+ dtype=dtype,
454
+ cache_position=cache_position,
455
+ batch_size=input_tensor.shape[0],
456
+ )
457
+
458
+ if (
459
+ self.config._attn_implementation == "sdpa"
460
+ and attention_mask is not None
461
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
462
+ and not output_attentions
463
+ ):
464
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
465
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
466
+ # Details: https://github.com/pytorch/pytorch/issues/110213
467
+ min_dtype = torch.finfo(dtype).min
468
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
469
+
470
+ return causal_mask
471
+
472
+ @staticmethod
473
+ def _prepare_4d_causal_attention_mask_with_cache_position(
474
+ attention_mask: torch.Tensor,
475
+ sequence_length: int,
476
+ target_length: int,
477
+ dtype: torch.dtype,
478
+ cache_position: torch.Tensor,
479
+ batch_size: int,
480
+ **kwargs,
481
+ ):
482
+ """
483
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
484
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
485
+
486
+ Args:
487
+ attention_mask (`torch.Tensor`):
488
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
489
+ `(batch_size, 1, query_length, key_value_length)`.
490
+ sequence_length (`int`):
491
+ The sequence length being processed.
492
+ target_length (`int`):
493
+ The target length: when generating with static cache, the mask should be as long as the static cache,
494
+ to account for the 0 padding, the part of the cache that is not filled yet.
495
+ dtype (`torch.dtype`):
496
+ The dtype to use for the 4D attention mask.
497
+ cache_position (`torch.Tensor`):
498
+ Indices depicting the position of the input sequence tokens in the sequence.
499
+ batch_size (`torch.Tensor`):
500
+ Batch size.
501
+ """
502
+ if attention_mask is not None and attention_mask.dim() == 4:
503
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
504
+ causal_mask = attention_mask
505
+ else:
506
+ min_dtype = torch.finfo(dtype).min
507
+ causal_mask = torch.full(
508
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
509
+ )
510
+ if sequence_length != 1:
511
+ causal_mask = torch.triu(causal_mask, diagonal=1)
512
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
513
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
514
+ if attention_mask is not None:
515
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
516
+ mask_length = attention_mask.shape[-1]
517
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
518
+ causal_mask.device
519
+ )
520
+ padding_mask = padding_mask == 0
521
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
522
+ padding_mask, min_dtype
523
+ )
524
+
525
+ return causal_mask
526
+
527
+
528
+ __all__ = ["MiCRoOLMo"]
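`MiCRoOLMo.__init__` and `build_model` above read a YAML run config from `config.config_path` (which `router_backend.build_model` later points at `configs/<model_id>.yml`). The actual config files are not part of this commit; a hypothetical minimal example, written here as the dict `yaml.load` would return and using only the keys the code reads:

```python
# Hypothetical run config; key names are the ones build_model() reads,
# the values are illustrative only.
run_config = {
    "num-experts": 4,          # Language / Logic / Social / World
    "use-router": True,        # learned per-token gate vs. externally supplied routing_weights
    "top-k-experts": 2,
    "jitter-noise": 0.0,
    "loss": "all",             # read via run_config.get("loss", "all")
    "trainable": ["reasoners", "experts-router"],
}
```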
models/modules.py ADDED
@@ -0,0 +1,42 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, List, Union
3
+
4
+ import torch
5
+ from transformers.modeling_outputs import ModelOutput
6
+
7
+
8
+ @dataclass
9
+ class CausalLMOutputWithPast(ModelOutput):
10
+ """
11
+ Base class for causal language model (or autoregressive) outputs.
12
+
13
+ Args:
14
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
15
+ Language modeling loss (for next-token prediction).
16
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
17
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
18
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
19
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
20
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
21
+
22
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
23
+ `past_key_values` input) to speed up sequential decoding.
24
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
25
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
26
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
27
+
28
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
29
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
30
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
31
+ sequence_length)`.
32
+
33
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
34
+ heads.
35
+ """
36
+
37
+ loss: Optional[torch.FloatTensor] = None
38
+ logits: torch.FloatTensor = None
39
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
40
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
41
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
42
+ routing_weights: Optional[Tuple[torch.FloatTensor, ...]] = None
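The extra `routing_weights` field is a tuple with one gate-logits tensor per MiCRo layer; it is what `router_backend.py` below softmaxes and argmaxes to build the per-expert percentages. A small sketch of that post-processing with fake tensors (shapes are illustrative only):

```python
import numpy as np
import torch
import torch.nn.functional as F

# one gate-logits tensor per layer, each of shape [batch, seq, num_experts]
num_layers, seq, num_experts = 16, 10, 4
fake_routing = tuple(torch.randn(1, seq, num_experts) for _ in range(num_layers))

per_layer = np.stack(
    [F.softmax(rw, dim=-1).squeeze(0).numpy() for rw in fake_routing], axis=0
)                                   # [num_layers, seq, num_experts]
winners = per_layer.argmax(-1)      # expert picked per layer and token
print(winners.shape)                # (16, 10)
```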
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ gradio>=4.44.0
2
+ plotly>=5.22.0
3
+ pandas>=2.2.0
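Note that `router_backend.py` and the `models/` package additionally import `torch`, `transformers`, `numpy`, and `yaml`, and `build_model` selects the `flash_attention_2` implementation (which would also need `flash-attn`). If the Space image does not already provide these, the requirements would presumably need extra entries along these lines (left unpinned here, since the commit does not specify versions):

```
torch
transformers
numpy
pyyaml
```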
router_backend.py ADDED
@@ -0,0 +1,223 @@
1
+ # router_backend.py
2
+ """
3
+ Real model routing backend for the MiCRo demo.
+
+ Implemented function:
+ get_expert_routing(model_id: str, hf_token: str, prompt: str | list[dict[str, str]])
+ -> tuple[dict[str, float], list[str] | None]
+
+ It returns a dict of 4 percentages keyed by the experts
+ ["Language", "Logic", "Social", "World"], together with the generated
+ continuation (a list of decoded strings) when `prompt` is a plain string,
+ or None when a chat-style message list is passed in.
+
+ Example routing dict:
+ {"Language": 12.5, "Logic": 45.0, "Social": 22.5, "World": 20.0}
15
+ """
16
+ import torch
17
+ import numpy as np
18
+ import torch.nn.functional as F
19
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
20
+ from typing import Union, Dict, List, Tuple
21
+
22
+ from models.micro_olmo import MiCRoOLMo
23
+ from models.micro_llama import MiCRoLlama
24
+ from models.micro_moe_llama import MiCRoLlamaMoE
25
+
26
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
27
+
28
+ def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dict[str, str]]]) -> Tuple[Dict[str, float], Union[List[str], None]]:
29
+
30
+ model, tokenizer = build_model(model_id, hf_token)
31
+
32
+ if isinstance(prompt, str):
33
+ generation, routing_weights = generate_continuation(model, tokenizer, prompt)
34
+ elif isinstance(prompt, list): # a chat-style list of {"role", "content"} messages
35
+ generation = None
36
+ routing_weights = get_routing_weights(model, tokenizer, [prompt])
37
+
38
+ expert_counts, _ = aggregate_routing_weights(routing_weights)
+ model_routing_percentages = 100.0 * expert_counts / max(expert_counts.sum(), 1)
39
+
40
+ if generation is not None:
41
+ print(f"Generation:\n{generation}")
42
+
43
+ return {
44
+ "Language": float(model_routing_percentages[3]),
45
+ "Logic": float(model_routing_percentages[0]),
46
+ "Social": float(model_routing_percentages[1]),
47
+ "World": float(model_routing_percentages[2]),
48
+ }, generation
49
+
50
+ def get_model_path(model_name: str) -> Tuple[str, str, AutoModelForCausalLM]:
51
+ return {
52
+ # MiCRo-Llama
53
+ "micro-llama-1b": ("bkhmsi/micro-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
54
+ "micro-llama-3b": ("bkhmsi/micro-llama-3b", "meta-llama/Llama-3.2-3B-Instruct", MiCRoLlama),
55
+ "micro-llama-1b-dpo": ("bkhmsi/micro-llama-1b-dpo", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
56
+
57
+ # MiCRo-MoE-Llama
58
+ "micro-moe-llama-1b": ("bkhmsi/micro-moe-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlamaMoE),
59
+
60
+ # MiCRo-OLMo
61
+ "micro-olmo": ("bkhmsi/micro-olmo-1b", "allenai/OLMo-2-0425-1B-Instruct", MiCRoOLMo),
62
+
63
+ # MiCRo-SmolLM2
64
+ "micro-smollm2-135m": ("bkhmsi/micro-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlama),
65
+ "micro-smollm2-360m": ("bkhmsi/micro-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlama),
66
+
67
+ # MiCRo-MoE-SmolLM2
68
+ "micro-moe-smollm2-135m": ("bkhmsi/micro-moe-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlamaMoE),
69
+ "micro-moe-smollm2-360m": ("bkhmsi/micro-moe-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlamaMoE),
70
+ }.get(model_name, (model_name, model_name, AutoModelForCausalLM))
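+ # e.g. get_model_path("micro-olmo") -> ("bkhmsi/micro-olmo-1b", "allenai/OLMo-2-0425-1B-Instruct", MiCRoOLMo);
+ # unknown ids fall back to (model_name, model_name, AutoModelForCausalLM) so custom Hub ids still load.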
71
+
72
+ def aggregate_routing_weights(routing_weights):
73
+ experts = ["Logic", "Social", "World", "Language"]
74
+ expert_token_model = np.zeros((len(experts)), dtype=int)
75
+ expert_layer_token = np.zeros((routing_weights.shape[0], len(experts)), dtype=int)
76
+ num_layers = routing_weights.shape[0]
77
+
78
+ for layer_idx in range(num_layers):
79
+ for token_idx in range(len(routing_weights[layer_idx])):
80
+ expert_idx = routing_weights[layer_idx][token_idx].argmax()
81
+ if layer_idx >= 2 and layer_idx < num_layers - 2:
82
+ expert_token_model[expert_idx] += 1
83
+ expert_layer_token[layer_idx][expert_idx] += 1
84
+ return expert_token_model, expert_layer_token
85
+
86
+ def generate_continuation(model,
87
+ tokenizer,
88
+ prompts,
89
+ max_tokens=1024,
90
+ use_cache=True,
91
+ return_routing_weights=True
92
+ ):
93
+
94
+ if isinstance(prompts, str):
95
+ prompts = [{"role": "user", "content": prompts}]
96
+
97
+ tokenizer.padding_side = "left"
98
+ inputs = tokenizer.apply_chat_template([
99
+ prompt for prompt in prompts
100
+ ], return_tensors="pt", padding=True, add_generation_prompt=True).to(DEVICE)
101
+
102
+ attention_mask = torch.ones_like(inputs)
103
+ attention_mask[inputs == tokenizer.pad_token_id] = 0
104
+
105
+ outputs = model.generate(
106
+ input_ids=inputs,
107
+ attention_mask=attention_mask,
108
+ max_new_tokens=max_tokens,
109
+ use_cache=use_cache,
110
+ stop_strings=["</s>","<|eot_id|>", "<|im_start|>user"],
111
+ tokenizer=tokenizer,
112
+ pad_token_id=tokenizer.pad_token_id,
113
+ temperature=0,
114
+ top_p=1.0,
115
+ do_sample=False,
116
+ )
117
+
118
+ if return_routing_weights:
119
+ attention_mask = torch.ones_like(outputs)
120
+ attention_mask[outputs == tokenizer.pad_token_id] = 0
121
+ model_output = model(input_ids=outputs, attention_mask=attention_mask)
122
+ torch.cuda.empty_cache()
123
+
124
+ routing_weights = model_output.routing_weights
125
+ routing_weights = np.concatenate([
126
+ F.softmax(rw, dim=-1)[:, inputs.shape[1]:].detach().float().cpu().numpy()
127
+ for rw in routing_weights
128
+ ])
129
+
130
+ else:
131
+ routing_weights = None
132
+
133
+ inputs_text = tokenizer.batch_decode(inputs, skip_special_tokens=False)
134
+
135
+ generations = []
136
+ for i, output in enumerate(outputs):
137
+ decoded_output = tokenizer.decode(output, skip_special_tokens=False)
138
+ decoded_output = decoded_output.replace(inputs_text[i], "")
139
+ decoded_output = decoded_output.replace(tokenizer.pad_token, "").strip()
140
+ decoded_output = decoded_output.replace("<|end_of_text|>", "").strip()
141
+ decoded_output = decoded_output.replace("<|endoftext|>", "").strip()
142
+ decoded_output = decoded_output.replace("<|eot_id|>", "").strip()
143
+ decoded_output = decoded_output.replace("\n<|im_start|>user", "").strip()
144
+ generations.append(decoded_output)
145
+
146
+ return (generations, routing_weights) if return_routing_weights else generations
147
+
148
+ def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
149
+ """
150
+ Get routing weights for the given prompts using the model.
151
+ Args:
152
+ model: The MiCRoLlama or MiCRoOLMo model.
153
+ tokenizer: The tokenizer for the model.
154
+ prompts: A string or list of dictionaries containing the prompts.
155
+ Returns:
156
+ routing_weights: A list of routing weights for each layer.
157
+ """
158
+
159
+ tokenizer.padding_side = "left"
160
+ if apply_chat_template:
161
+ if isinstance(prompts, str):
162
+ prompts = [{"role": "user", "content": prompts}]
163
+
164
+ inputs = tokenizer.apply_chat_template([
165
+ prompt for prompt in prompts
166
+ ], return_tensors="pt", padding=True).to(DEVICE)
167
+
168
+ input_without_response = tokenizer.apply_chat_template([
169
+ prompt[:-1] for prompt in prompts
170
+ ], return_tensors="pt", padding=True,
171
+ ).to(DEVICE)
172
+ else:
173
+ inputs = tokenizer(prompts[0] + prompts[1], return_tensors="pt", padding=True).input_ids.to(DEVICE)
174
+ input_without_response = tokenizer(prompts[0], return_tensors="pt", padding=True).input_ids.to(DEVICE)
175
+
176
+ attention_mask = torch.ones_like(inputs)
177
+ attention_mask[inputs == tokenizer.pad_token_id] = 0
178
+
179
+ model_output = model(input_ids=inputs, attention_mask=attention_mask)
180
+
181
+ routing_weights = model_output.routing_weights
182
+ routing_weights = np.stack([F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights], axis=0).squeeze()
183
+
184
+ offset = len(input_without_response[0])-1
185
+ routing_weights = routing_weights[:, offset:-1]
186
+
187
+ return routing_weights
188
+
189
+ def build_model(model_id: str, hf_token: str, use_cache: bool = True):
190
+
191
+ model_path, base_model, model_class = get_model_path(model_id)
192
+
193
+ model_config = AutoConfig.from_pretrained(base_model, token=hf_token)
194
+ model_config.config_path = f"configs/{model_id}.yml"
195
+
196
+ model_config.torch_dtype = torch.bfloat16
197
+ model_config.use_bfloat16 = True
198
+ model_config._attn_implementation = "flash_attention_2"
199
+ model_config.use_cache = use_cache
200
+ model_config.ablate = []
201
+
202
+ tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
203
+ tokenizer.padding_side = "left"
204
+
205
+ if "llama" in model_id:
206
+ tokenizer.pad_token_id = 128004
207
+ if "olmo" in model_id:
208
+ tokenizer.pad_token_id = 100277
209
+ tokenizer.add_special_tokens({'additional_special_tokens': ['<|assistant|>']})
210
+ elif "smollm2" in model_id:
211
+ tokenizer.pad_token_id = 2
212
+ else:
213
+ tokenizer.pad_token_id = 128004
214
+
215
+ if "olmo" in model_id:
216
+ model_config.vocab_size = len(tokenizer)
217
+
218
+ model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)
219
+
220
+ model.to(DEVICE)
221
+ model = model.bfloat16()
222
+ model.eval()
223
+ return model, tokenizer
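Finally, a hypothetical local smoke test of the backend as implemented above (it assumes a CUDA GPU, flash-attn, and an `HF_TOKEN` with access to the base models; the Space itself calls `get_expert_routing` from `app.py` instead):

```python
import os
from router_backend import get_expert_routing

percentages, generation = get_expert_routing(
    model_id="micro-llama-1b",
    hf_token=os.environ.get("HF_TOKEN", ""),
    prompt="Why do people tell white lies?",
)
print(percentages)   # {"Language": ..., "Logic": ..., "Social": ..., "World": ...}
print(generation)    # list with the decoded continuation when prompt is a plain string
```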