bkhmsi committed on
Commit
e63a1d1
1 Parent(s): ae072a3

added lineplot

Files changed (2)
  1. app.py +47 -23
  2. router_backend.py +28 -11
app.py CHANGED
@@ -21,20 +21,6 @@ import plotly.express as px
 import pandas as pd
 from router_backend import get_expert_routing
 
-# ---- Expected backend adapter ------------------------------------------------
-# Implement your real function in router_backend.py with the following signature:
-# def get_expert_routing(model_id: str, prompt: str) -> Union[List[float], Dict[str, float], Tuple[float, float, float, float]]
-# It MUST return 4 values that sum to ~100 (percentages) in the fixed order:
-# ["Language", "Logic", "Social", "World"]
-# or a mapping with those keys.
-# try:
-#     from router_backend import get_expert_routing  # your real backend
-#     BACKEND_AVAILABLE = True
-# except Exception as e:  # keep error for display if needed
-#     BACKEND_AVAILABLE = False
-#     _backend_import_error = e
-
-
 EXPERTS = ["Language", "Logic", "Social", "World"]
 
 DEFAULT_MODELS = [
@@ -83,6 +69,42 @@ def _compose_prompt(user_prompt: str, assistant_prompt: str) -> str:
         return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
     return user_prompt
 
+def plot_lines(arrays):
+    names = EXPERTS
+
+    LINE_COLORS = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"]
+
+    LINE_COLORS = {
+        name: color for name, color in zip(names, LINE_COLORS)
+    }
+
+    # Build a tidy DataFrame: columns = index, value, series
+    records = []
+
+    for i, array in enumerate(arrays):
+        for name, v in zip(names, array):
+            records.append({"index": i+1, "value": v, "series": name})
+
+    df = pd.DataFrame.from_records(records)
+
+    fig = px.line(
+        df,
+        x="index",
+        y="value",
+        color="series",
+        color_discrete_map=LINE_COLORS,
+        title="",
+        markers=True,
+    )
+
+    fig.update_layout(
+        xaxis_title="Layer Index",
+        yaxis_title="Percentage (%)",
+        legend_title="Layer-wise Expert Routing",
+    )
+
+    return fig
+
 def route_and_plot(
     model_choice: str,
     user_prompt: str,
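The new `plot_lines` helper expects one row of percentages per layer, ordered like `EXPERTS`. A minimal standalone sketch of calling it outside the app, with `app.py`'s imports in scope (the three rows below are invented illustration values, not model output):

```python
# Hypothetical smoke test for plot_lines; each row is one layer's routing
# percentages in ["Language", "Logic", "Social", "World"] order.
layer_routing = [
    [40.0, 30.0, 20.0, 10.0],  # layer 1
    [25.0, 25.0, 25.0, 25.0],  # layer 2
    [10.0, 50.0, 15.0, 25.0],  # layer 3
]
fig = plot_lines(layer_routing)
fig.write_html("layer_routing.html")  # or fig.show() in an interactive session
```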
@@ -129,16 +151,19 @@ def route_and_plot(
         msg = "Using mock data."
         vals = _mock_routing(model_id, prompt, seed=seed)
         generation = None
+        layer_routing = None
     else:
         try:
-            raw, generation = get_expert_routing(model_id, hf_token, prompt, ablations)  # <-- your real function
+            raw, layer_routing, generation = get_expert_routing(model_id, hf_token, prompt, ablations)  # <-- your real function
             vals = _normalize_output(raw)
             msg = "Routed with real backend."
         except Exception as e:
             # fallback to mock on error, but surface message
+            print(f"Backend error: {e}")
             msg = f"Backend error: {e}\nFalling back to mock data."
             vals = _mock_routing(model_id, prompt, seed=seed)
             generation = None
+            layer_routing = None
 
     df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
     colors = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"]
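With this change `route_and_plot` unpacks three values from `get_expert_routing`: the aggregate percentages, the per-layer routing matrix, and the generated text. A hedged stand-in with the matching return shape, handy for exercising the UI without loading a model (the name `fake_get_expert_routing` and all numbers are placeholders, not part of the repo):

```python
# Placeholder with the same return shape as router_backend.get_expert_routing:
# (aggregate percentages, per-layer percentages, generated text). Values are invented.
def fake_get_expert_routing(model_id, hf_token, prompt, ablations):
    raw = {"Language": 40.0, "Logic": 30.0, "Social": 20.0, "World": 10.0}
    layer_routing = [[40.0, 30.0, 20.0, 10.0] for _ in range(16)]  # e.g. 16 layers x 4 experts
    generation = "(stub) backend not loaded"
    return raw, layer_routing, generation
```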
@@ -147,18 +172,19 @@ def route_and_plot(
     fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside")
     fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)
 
+    line_fig = plot_lines(layer_routing) if layer_routing is not None else None
+
     status = f"Model: {model_id}<br>{msg}"
     if generation is None:
         generation = assistant_prompt
 
-    return generation, df, fig, status
+    return generation, df, fig, line_fig, status
 
 with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     gr.Markdown(
         """
         # 🧠 Mixture of Cognitive Reasoner (MiCRo) Expert Routing Visualizer
         ## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.
-        Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)
         ----
         This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
         Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
@@ -187,17 +213,15 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
             user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
             assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")
 
-            # with gr.Row():
-            #     use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
-            #     seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")
-
            run = gr.Button("Run Routing", variant="primary")
 
            generation_output = gr.Textbox(lines=4, label="Generated Response", placeholder="Generated text will appear here...", interactive=False)
 
            with gr.Row():
                table = gr.Dataframe(label="Routing Percentages", interactive=False)
+
                plot = gr.Plot(label="Bar Plot")
+                line_plot = gr.Plot(label="Layer-wise Routing Percentages")
 
 
            status = gr.Markdown("", label="System Message")
@@ -205,7 +229,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     run.click(
         route_and_plot,
         inputs=[model_choice, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world],
-        outputs=[generation_output, table, plot, status],
+        outputs=[generation_output, table, plot, line_plot, status],
     )
 
     # example prompts
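`run.click` maps `route_and_plot`'s return values onto the `outputs` components positionally, so the new fourth return value (`line_fig`) lands in `line_plot`. A tiny self-contained sketch of that positional mapping with a toy function (not part of the demo):

```python
import gradio as gr

# Toy illustration of Gradio's positional output mapping: the two returned
# values fill the two components listed in outputs=, in order.
def describe(text):
    return text.upper(), len(text)

with gr.Blocks() as toy_demo:
    inp = gr.Textbox(label="Input")
    upper = gr.Textbox(label="Uppercased")
    length = gr.Number(label="Length")
    btn = gr.Button("Run")
    btn.click(describe, inputs=[inp], outputs=[upper, length])

# toy_demo.launch()  # uncomment to try it locally
```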
@@ -242,7 +266,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
         label="Try these examples:",
         cache_examples=True,
         fn=route_and_plot,
-        outputs=[generation_output, table, plot, status],
+        outputs=[generation_output, table, plot, line_plot, status],
     )
 
 if __name__ == "__main__":
 
router_backend.py CHANGED
@@ -37,9 +37,26 @@ def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dic
     generation = None
     routing_weights = get_routing_weights(model, tokenizer, [prompt])
 
-    model_routing_percentages = aggregate_routing_weights(routing_weights)[0]
+    model_routing_percentages, layer_token_routing = aggregate_routing_weights(routing_weights)
+
+    layer_token_routing = np.array(layer_token_routing)
+    num_experts, num_layers = layer_token_routing.shape
+
     print(model_routing_percentages)
 
+    layer_token_routing = np.roll(layer_token_routing, shift=1, axis=0)
+
+    all_layer_routing_percentages = []
+    for layer_idx in range(num_layers):
+        layer_token_percentages = []
+        for expert_idx in range(num_experts):
+            percentage = (layer_token_routing[expert_idx][layer_idx] / sum(layer_token_routing[:, layer_idx])) * 100
+            layer_token_percentages.append(percentage)
+        all_layer_routing_percentages.append(layer_token_percentages)
+
+    layer_routing_percentages = np.array(all_layer_routing_percentages)
+
+
     if generation is not None:
         print(f"Generation:\n{generation}")
 
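The added block converts the per-expert, per-layer argmax counts into percentages, normalizing each layer's column so the four experts sum to 100; the `np.roll(..., shift=1, axis=0)` appears to move the trailing "Language" row to the front so rows line up with the `["Language", "Logic", "Social", "World"]` order used by the app. A vectorized sketch of the same normalization on toy counts:

```python
import numpy as np

# Toy counts with shape (num_experts, num_layers); each column is one layer.
counts = np.array([[3, 0, 1],
                   [1, 2, 1],
                   [0, 1, 1],
                   [0, 1, 1]])
# Normalize each layer (column) to percentages, then transpose to
# (num_layers, num_experts), matching layer_routing_percentages above.
percentages = counts / counts.sum(axis=0, keepdims=True) * 100
print(percentages.T)
print(percentages.sum(axis=0))  # every layer sums to 100
```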
 
@@ -48,7 +65,7 @@ def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dic
         "Logic": float(model_routing_percentages[0]),
         "Social": float(model_routing_percentages[1]),
         "World": float(model_routing_percentages[2]),
-    }, generation
+    }, layer_routing_percentages, generation
 
 def get_model_path(model_name: str) -> Tuple[str, str, AutoModelForCausalLM]:
     return {
@@ -75,15 +92,14 @@ def get_model_path(model_name: str) -> Tuple[str, str, AutoModelForCausalLM]:
 def aggregate_routing_weights(routing_weights):
     experts = ["Logic", "Social", "World", "Language"]
     expert_token_model = np.zeros((len(experts)), dtype=int)
-    expert_layer_token = np.zeros((routing_weights.shape[0], len(experts)), dtype=int)
+    expert_layer_token = np.zeros((len(experts), routing_weights.shape[0]), dtype=int)
     num_layers = routing_weights.shape[0]
 
     for layer_idx in range(num_layers):
         for token_idx in range(len(routing_weights[layer_idx])):
             expert_idx = routing_weights[layer_idx][token_idx].argmax()
-            if layer_idx >= 2 and layer_idx < num_layers - 2:
-                expert_token_model[expert_idx] += 1
-                expert_layer_token[layer_idx][expert_idx] += 1
+            expert_token_model[expert_idx] += 1
+            expert_layer_token[expert_idx][layer_idx] += 1
     return expert_token_model, expert_layer_token
 
 def generate_continuation(model,
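`aggregate_routing_weights` now counts every layer's argmax decisions (the previous version skipped the first and last two layers) and indexes the per-layer table as (expert, layer). A minimal sketch of that counting pattern on random toy weights:

```python
import numpy as np

# Toy routing weights: 2 layers x 3 tokens x 4 experts. Each token's argmax
# increments one (expert, layer) cell, mirroring the loop above.
rng = np.random.default_rng(0)
rw = rng.random((2, 3, 4))
counts = np.zeros((4, 2), dtype=int)  # (num_experts, num_layers)
for layer_idx in range(rw.shape[0]):
    for token_idx in range(rw.shape[1]):
        counts[rw[layer_idx, token_idx].argmax(), layer_idx] += 1
print(counts)
print(counts.sum(axis=0))  # one count per token per layer: [3 3]
```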
@@ -179,7 +195,8 @@ def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
     attention_mask = torch.ones_like(inputs)
     attention_mask[inputs == tokenizer.pad_token_id] = 0
 
-    model_output = model(input_ids=inputs, attention_mask=attention_mask)
+    with torch.no_grad():
+        model_output = model(input_ids=inputs, attention_mask=attention_mask)
 
     routing_weights = model_output.routing_weights
     routing_weights = np.stack([F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights], axis=0).squeeze()
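Wrapping the forward pass in `torch.no_grad()` stops autograd from recording a graph, which trims activation memory during the routing-weight extraction. A minimal illustration with a stand-in module (not the MiCRo model):

```python
import torch
import torch.nn as nn

# Stand-in module; the point is only that no autograd graph is recorded.
layer = nn.Linear(8, 4).eval()
x = torch.randn(2, 8)
with torch.no_grad():
    out = layer(x)
print(out.requires_grad)  # False
```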
@@ -199,8 +216,8 @@ def build_model(model_id: str, hf_token: str, ablations: List[str], use_cache: b
 
     model_config.config_path = f"{parent_path}/configs/{model_id.replace('-', '_')}.yml"
 
-    model_config.torch_dtype = torch.bfloat16
-    model_config.use_bfloat16 = True
+    model_config.torch_dtype = torch.float16
+    model_config.use_bfloat16 = False
     model_config._attn_implementation = "eager"  # {sdpa, flash_attention_2, eager}
     model_config.use_cache = use_cache
     model_config.ablate = ablations
@@ -221,9 +238,9 @@ def build_model(model_id: str, hf_token: str, ablations: List[str], use_cache: b
     if "olmo" in model_id:
         model_config.vocab_size = len(tokenizer)
 
-    model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)
+    model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True, dtype=torch.float16)
 
     model.to(DEVICE)
-    model = model.bfloat16()
+    model = model.half()
     model.eval()
     return model, tokenizer
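Both `build_model` hunks move the model from bfloat16 to float16: the config dtype, the `from_pretrained` dtype, and the final `.half()` cast. A hedged sketch of the same pattern on a generic Hugging Face causal LM ("gpt2" is only a placeholder model id, not the MiCRo checkpoint):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# torch_dtype=torch.float16 loads the weights directly in half precision;
# the explicit .half() mirrors the diff and is redundant after that.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to(device)
model = model.half()
model.eval()
```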
 