added lineplot
Files changed:
- app.py (+47 -23)
- router_backend.py (+28 -11)
app.py
CHANGED
@@ -21,20 +21,6 @@ import plotly.express as px
 import pandas as pd
 from router_backend import get_expert_routing
 
-# ---- Expected backend adapter ------------------------------------------------
-# Implement your real function in router_backend.py with the following signature:
-# def get_expert_routing(model_id: str, prompt: str) -> Union[List[float], Dict[str, float], Tuple[float, float, float, float]]
-# It MUST return 4 values that sum to ~100 (percentages) in the fixed order:
-# ["Language", "Logic", "Social", "World"]
-# or a mapping with those keys.
-# try:
-#     from router_backend import get_expert_routing  # your real backend
-#     BACKEND_AVAILABLE = True
-# except Exception as e:  # keep error for display if needed
-#     BACKEND_AVAILABLE = False
-#     _backend_import_error = e
-
-
 EXPERTS = ["Language", "Logic", "Social", "World"]
 
 DEFAULT_MODELS = [
@@ -83,6 +69,42 @@ def _compose_prompt(user_prompt: str, assistant_prompt: str) -> str:
         return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
     return user_prompt
 
+def plot_lines(arrays):
+    names = EXPERTS
+
+    LINE_COLORS = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"]
+
+    LINE_COLORS = {
+        name: color for name, color in zip(names, LINE_COLORS)
+    }
+
+    # Build a tidy DataFrame: columns = index, value, series
+    records = []
+
+    for i, array in enumerate(arrays):
+        for name, v in zip(names, array):
+            records.append({"index": i+1, "value": v, "series": name})
+
+    df = pd.DataFrame.from_records(records)
+
+    fig = px.line(
+        df,
+        x="index",
+        y="value",
+        color="series",
+        color_discrete_map=LINE_COLORS,
+        title="",
+        markers=True,
+    )
+
+    fig.update_layout(
+        xaxis_title="Layer Index",
+        yaxis_title="Percentage (%)",
+        legend_title="Layer-wise Expert Routing",
+    )
+
+    return fig
+
 def route_and_plot(
     model_choice: str,
     user_prompt: str,
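For orientation, the new `plot_lines` expects one row of four percentages per layer, in the `EXPERTS` order used by `app.py`. A minimal usage sketch, assuming it runs alongside the definitions above; the numbers are made up:

```python
# Illustrative only: synthetic per-layer routing percentages
# (columns follow EXPERTS = ["Language", "Logic", "Social", "World"]).
synthetic_layer_routing = [
    [40.0, 30.0, 20.0, 10.0],  # layer 1
    [25.0, 35.0, 25.0, 15.0],  # layer 2
    [10.0, 50.0, 20.0, 20.0],  # layer 3
]

fig = plot_lines(synthetic_layer_routing)  # returns a Plotly line figure
fig.show()
```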
@@ -129,16 +151,19 @@ def route_and_plot(
         msg = "Using mock data."
         vals = _mock_routing(model_id, prompt, seed=seed)
         generation = None
+        layer_routing = None
     else:
         try:
-            raw, generation = get_expert_routing(model_id, hf_token, prompt, ablations)  # <-- your real function
+            raw, layer_routing, generation = get_expert_routing(model_id, hf_token, prompt, ablations)  # <-- your real function
             vals = _normalize_output(raw)
             msg = "Routed with real backend."
         except Exception as e:
             # fallback to mock on error, but surface message
+            print(f"Backend error: {e}")
             msg = f"Backend error: {e}\nFalling back to mock data."
             vals = _mock_routing(model_id, prompt, seed=seed)
             generation = None
+            layer_routing = None
 
     df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
     colors = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"]
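With this change, `route_and_plot` unpacks three values from `get_expert_routing`: the overall routing, the layer-wise routing, and the generated text. A mock stub with the new shape (illustrative only, not the real backend; the function name and example values are hypothetical) could look like:

```python
from typing import Dict, List, Optional, Tuple, Union
import numpy as np

def mock_get_expert_routing(
    model_id: str,
    hf_token: str,
    prompt: Union[str, List[Dict[str, str]]],
    ablations: List[str],
) -> Tuple[Dict[str, float], np.ndarray, Optional[str]]:
    # Overall percentages keyed by expert name, layer-wise percentages with
    # shape (num_layers, num_experts), and an optional generated continuation.
    overall = {"Language": 40.0, "Logic": 30.0, "Social": 20.0, "World": 10.0}
    layer_routing = np.tile([40.0, 30.0, 20.0, 10.0], (16, 1))  # 16 fake layers
    return overall, layer_routing, None
```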
@@ -147,18 +172,19 @@ def route_and_plot(
     fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside")
     fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)
 
+    line_fig = plot_lines(layer_routing) if layer_routing is not None else None
+
     status = f"Model: {model_id}<br>{msg}"
     if generation is None:
         generation = assistant_prompt
 
-    return generation, df, fig, status
+    return generation, df, fig, line_fig, status
 
 with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     gr.Markdown(
         """
         # 🧠 Mixture of Cognitive Reasoner (MiCRo) Expert Routing Visualizer
         ## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.
-        Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)
         ----
         This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
         Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
@@ -187,17 +213,15 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
     assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")
 
-    # with gr.Row():
-    #     use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
-    #     seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")
-
     run = gr.Button("Run Routing", variant="primary")
 
     generation_output = gr.Textbox(lines=4, label="Generated Response", placeholder="Generated text will appear here...", interactive=False)
 
     with gr.Row():
         table = gr.Dataframe(label="Routing Percentages", interactive=False)
+
         plot = gr.Plot(label="Bar Plot")
+        line_plot = gr.Plot(label="Layer-wise Routing Percentages")
 
 
     status = gr.Markdown("", label="System Message")
@@ -205,7 +229,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     run.click(
         route_and_plot,
         inputs=[model_choice, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world],
-        outputs=[generation_output, table, plot, status],
+        outputs=[generation_output, table, plot, line_plot, status],
     )
 
     # example prompts
@@ -242,7 +266,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
         label="Try these examples:",
         cache_examples=True,
         fn=route_and_plot,
-        outputs=[generation_output, table, plot, status],
+        outputs=[generation_output, table, plot, line_plot, status],
     )
 
 if __name__ == "__main__":
router_backend.py
CHANGED
@@ -37,9 +37,26 @@ def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dic
     generation = None
     routing_weights = get_routing_weights(model, tokenizer, [prompt])
 
-    model_routing_percentages = aggregate_routing_weights(routing_weights)
+    model_routing_percentages, layer_token_routing = aggregate_routing_weights(routing_weights)
+
+    layer_token_routing = np.array(layer_token_routing)
+    num_experts, num_layers = layer_token_routing.shape
+
     print(model_routing_percentages)
 
+    layer_token_routing = np.roll(layer_token_routing, shift=1, axis=0)
+
+    all_layer_routing_percentages = []
+    for layer_idx in range(num_layers):
+        layer_token_percentages = []
+        for expert_idx in range(num_experts):
+            percentage = (layer_token_routing[expert_idx][layer_idx] / sum(layer_token_routing[:, layer_idx])) * 100
+            layer_token_percentages.append(percentage)
+        all_layer_routing_percentages.append(layer_token_percentages)
+
+    layer_routing_percentages = np.array(all_layer_routing_percentages)
+
+
     if generation is not None:
         print(f"Generation:\n{generation}")
 
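The nested loop added above turns the (num_experts, num_layers) count matrix into per-layer percentages. An equivalent vectorized form, shown only as a sketch and not part of this commit, is:

```python
import numpy as np

# Each column of layer_token_routing sums to the number of tokens routed at
# that layer; dividing by the column sums and transposing reproduces the
# (num_layers, num_experts) array built by the loop.
layer_routing_percentages = (
    layer_token_routing / layer_token_routing.sum(axis=0, keepdims=True) * 100
).T
```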
@@ -48,7 +65,7 @@ def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dic
         "Logic": float(model_routing_percentages[0]),
         "Social": float(model_routing_percentages[1]),
         "World": float(model_routing_percentages[2]),
-    }, generation
+    }, layer_routing_percentages, generation
 
 def get_model_path(model_name: str) -> Tuple[str, str, AutoModelForCausalLM]:
     return {
@@ -75,15 +92,14 @@ def get_model_path(model_name: str) -> Tuple[str, str, AutoModelForCausalLM]:
 def aggregate_routing_weights(routing_weights):
     experts = ["Logic", "Social", "World", "Language"]
     expert_token_model = np.zeros((len(experts)), dtype=int)
-    expert_layer_token = np.zeros((routing_weights.shape[0], len(experts)), dtype=int)
+    expert_layer_token = np.zeros((len(experts), routing_weights.shape[0]), dtype=int)
     num_layers = routing_weights.shape[0]
 
     for layer_idx in range(num_layers):
         for token_idx in range(len(routing_weights[layer_idx])):
             expert_idx = routing_weights[layer_idx][token_idx].argmax()
-
-
-            expert_layer_token[layer_idx][expert_idx] += 1
+            expert_token_model[expert_idx] += 1
+            expert_layer_token[expert_idx][layer_idx] += 1
     return expert_token_model, expert_layer_token
 
 def generate_continuation(model,
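As a quick sanity check of the updated counters (made-up numbers; expert indices follow the function's `experts` list):

```python
import numpy as np

# 2 layers x 3 tokens x 4 experts; argmax per token picks the winning expert.
routing_weights = np.array([
    [[0.7, 0.1, 0.1, 0.1],    # layer 0, token 0 -> expert 0 (Logic)
     [0.1, 0.6, 0.2, 0.1],    # layer 0, token 1 -> expert 1 (Social)
     [0.1, 0.1, 0.1, 0.7]],   # layer 0, token 2 -> expert 3 (Language)
    [[0.2, 0.2, 0.5, 0.1],    # layer 1, token 0 -> expert 2 (World)
     [0.8, 0.1, 0.05, 0.05],  # layer 1, token 1 -> expert 0 (Logic)
     [0.1, 0.2, 0.3, 0.4]],   # layer 1, token 2 -> expert 3 (Language)
])

expert_token_model, expert_layer_token = aggregate_routing_weights(routing_weights)
# expert_token_model -> per-expert token counts over all layers: [2, 1, 1, 2]
# expert_layer_token -> (num_experts, num_layers) counts: [[1, 1], [1, 0], [0, 1], [1, 1]]
```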
@@ -179,7 +195,8 @@ def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
     attention_mask = torch.ones_like(inputs)
     attention_mask[inputs == tokenizer.pad_token_id] = 0
 
-    model_output = model(input_ids=inputs, attention_mask=attention_mask)
+    with torch.no_grad():
+        model_output = model(input_ids=inputs, attention_mask=attention_mask)
 
     routing_weights = model_output.routing_weights
     routing_weights = np.stack([F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights], axis=0).squeeze()
@@ -199,8 +216,8 @@ def build_model(model_id: str, hf_token: str, ablations: List[str], use_cache: b
 
     model_config.config_path = f"{parent_path}/configs/{model_id.replace('-', '_')}.yml"
 
-    model_config.torch_dtype = torch.bfloat16
-    model_config.use_bfloat16 = True
+    model_config.torch_dtype = torch.float16
+    model_config.use_bfloat16 = False
     model_config._attn_implementation = "eager"  # {sdpa, flash_attention_2, eager}
     model_config.use_cache = use_cache
     model_config.ablate = ablations
@@ -221,9 +238,9 @@ def build_model(model_id: str, hf_token: str, ablations: List[str], use_cache: b
     if "olmo" in model_id:
        model_config.vocab_size = len(tokenizer)
 
-    model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)
+    model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True, dtype=torch.float16)
 
     model.to(DEVICE)
-    model = model.bfloat16()
+    model = model.half()
     model.eval()
     return model, tokenizer