# router_backend.py
"""
Plug your real model routing function here.

This module implements:

    get_expert_routing(model_id: str, hf_token: str,
                       prompt: Union[str, List[Dict[str, str]]],
                       ablations: Optional[List[str]] = None)
        -> Tuple[Dict[str, float], Optional[str]]

The returned dict maps the four experts
["Language", "Logic", "Social", "World"] to percentages that sum to 100, e.g.

    {"Language": 12.5, "Logic": 45.0, "Social": 22.5, "World": 20.0}

The second return value is the generated continuation when `prompt` is a plain
string, or None when `prompt` is an already-completed chat (a list of
{"role": ..., "content": ...} dicts).
"""

import pathlib
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from models.micro_olmo import MiCRoOLMo
from models.micro_llama import MiCRoLlama
from models.micro_moe_llama import MiCRoLlamaMoE

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def get_expert_routing(
    model_id: str,
    hf_token: str,
    prompt: Union[str, List[Dict[str, str]]],
    ablations: Optional[List[str]] = None,
) -> Tuple[Dict[str, float], Optional[str]]:
    model, tokenizer = build_model(model_id, hf_token, ablations=ablations)

    if isinstance(prompt, str):
        # Generate a continuation and collect routing weights over the generated tokens.
        generation, routing_weights = generate_continuation(model, tokenizer, prompt)
        generation = generation[0] if isinstance(generation, list) else generation
    elif isinstance(prompt, list):
        # Chat-format prompt: no generation; routing weights are computed over the
        # final (assistant) message of the conversation.
        generation = None
        routing_weights = get_routing_weights(model, tokenizer, [prompt])
    else:
        raise TypeError("prompt must be a string or a list of chat messages")

    # aggregate_routing_weights returns per-expert argmax counts; convert them to
    # percentages so the values match the documented contract.
    expert_counts = aggregate_routing_weights(routing_weights)[0]
    total = expert_counts.sum()
    percentages = (100.0 * expert_counts / total) if total > 0 else expert_counts.astype(float)
    print(percentages)

    if generation is not None:
        print(f"Generation:\n{generation}")

    # aggregate_routing_weights uses the expert order ["Logic", "Social", "World", "Language"].
    return {
        "Language": float(percentages[3]),
        "Logic": float(percentages[0]),
        "Social": float(percentages[1]),
        "World": float(percentages[2]),
    }, generation


def get_model_path(model_name: str) -> Tuple[str, str, type]:
    """Map a short model id to (HF checkpoint, base model for config/tokenizer, model class)."""
    return {
        # MiCRo-Llama
        "micro-llama-1b": ("bkhmsi/micro-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
        "micro-llama-3b": ("bkhmsi/micro-llama-3b", "meta-llama/Llama-3.2-3B-Instruct", MiCRoLlama),
        "micro-llama-1b-dpo": ("bkhmsi/micro-llama-1b-dpo", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
        # MiCRo-MoE-Llama
        "micro-moe-llama-1b": ("bkhmsi/micro-moe-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlamaMoE),
        # MiCRo-OLMo
        "micro-olmo": ("bkhmsi/micro-olmo-1b", "allenai/OLMo-2-0425-1B-Instruct", MiCRoOLMo),
        # MiCRo-SmolLM2
        "micro-smollm2-135m": ("bkhmsi/micro-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlama),
        "micro-smollm2-360m": ("bkhmsi/micro-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlama),
        # MiCRo-MoE-SmolLM2
        "micro-moe-smollm2-135m": ("bkhmsi/micro-moe-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlamaMoE),
        "micro-moe-smollm2-360m": ("bkhmsi/micro-moe-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlamaMoE),
    }.get(model_name, (model_name, model_name, AutoModelForCausalLM))


def aggregate_routing_weights(routing_weights):
    """Count, per expert, how many (layer, token) positions route to it via argmax."""
    experts = ["Logic", "Social", "World", "Language"]
    expert_token_model = np.zeros(len(experts), dtype=int)
    expert_layer_token = np.zeros((routing_weights.shape[0], len(experts)), dtype=int)

    num_layers = routing_weights.shape[0]
    for layer_idx in range(num_layers):
        for token_idx in range(len(routing_weights[layer_idx])):
            expert_idx = routing_weights[layer_idx][token_idx].argmax()
            # The model-level tally skips the first two and last two layers.
            if 2 <= layer_idx < num_layers - 2:
                expert_token_model[expert_idx] += 1
            expert_layer_token[layer_idx][expert_idx] += 1

    return expert_token_model, expert_layer_token
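
# --- Illustrative sketch (not part of the original module) ----------------------------
# Shows how aggregate_routing_weights turns a (num_layers, num_tokens, num_experts)
# array into per-expert counts. The toy array and the helper name are made up purely
# for illustration and are never called by the rest of the module.
def _example_aggregate_routing_weights():
    # 6 layers, 3 tokens, 4 experts; ties resolve to expert 0, token 0 prefers expert 1.
    toy = np.full((6, 3, 4), 0.25)
    toy[:, 0, 1] = 0.7
    model_counts, layer_counts = aggregate_routing_weights(toy)
    # Only layers 2 and 3 contribute to model_counts here -> [4, 2, 0, 0].
    print("per-expert counts (middle layers only):", model_counts)
    print("per-layer argmax counts:\n", layer_counts)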
def generate_continuation(model, tokenizer, prompts, max_tokens=128, use_cache=True,
                          return_routing_weights=True):
    """Greedily generate a continuation and, optionally, the routing weights of the generated tokens."""
    if isinstance(prompts, str):
        # Wrap a plain string as a single-turn chat.
        prompts = [{"role": "user", "content": prompts}]

    tokenizer.padding_side = "left"
    inputs = tokenizer.apply_chat_template(
        prompts,
        return_tensors="pt",
        padding=True,
        add_generation_prompt=True,
    ).to(DEVICE)

    attention_mask = torch.ones_like(inputs)
    attention_mask[inputs == tokenizer.pad_token_id] = 0

    outputs = model.generate(
        input_ids=inputs,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        use_cache=use_cache,
        stop_strings=["<|eot_id|>", "<|im_start|>user", "user"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,  # greedy decoding
    )

    if return_routing_weights:
        # Re-run the full sequence (prompt + generation) to obtain routing weights,
        # then keep only the positions belonging to the generated tokens.
        attention_mask = torch.ones_like(outputs)
        attention_mask[outputs == tokenizer.pad_token_id] = 0
        model_output = model(input_ids=outputs, attention_mask=attention_mask)
        torch.cuda.empty_cache()

        routing_weights = model_output.routing_weights
        routing_weights = np.concatenate([
            F.softmax(rw, dim=-1)[:, inputs.shape[1]:].detach().float().cpu().numpy()
            for rw in routing_weights
        ])
    else:
        routing_weights = None

    inputs_text = tokenizer.batch_decode(inputs, skip_special_tokens=False)

    generations = []
    for i, output in enumerate(outputs):
        # Strip the prompt and any special/end-of-turn tokens from the decoded output.
        decoded_output = tokenizer.decode(output, skip_special_tokens=False)
        decoded_output = decoded_output.replace(inputs_text[i], "")
        decoded_output = decoded_output.replace(tokenizer.pad_token, "").strip()
        decoded_output = decoded_output.replace("<|end_of_text|>", "").strip()
        decoded_output = decoded_output.replace("<|endoftext|>", "").strip()
        decoded_output = decoded_output.replace("<|eot_id|>", "").strip()
        decoded_output = decoded_output.replace("\n<|im_start|>user", "").strip()
        generations.append(decoded_output)

    return (generations, routing_weights) if return_routing_weights else generations


def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
    """
    Get routing weights for the given prompts using the model.

    Args:
        model: The MiCRoLlama, MiCRoLlamaMoE, or MiCRoOLMo model.
        tokenizer: The tokenizer for the model.
        prompts: With apply_chat_template=True, a list of chat-format conversations
            (each a list of {"role", "content"} dicts) whose final message is the
            response to analyze. Otherwise, a (context, continuation) pair of strings.

    Returns:
        routing_weights: Array of shape (num_layers, num_response_tokens, num_experts)
            restricted to the tokens of the final response.
    """
    tokenizer.padding_side = "left"

    if apply_chat_template:
        if isinstance(prompts, str):
            # Wrap a plain string as a batch containing one single-turn conversation.
            prompts = [[{"role": "user", "content": prompts}]]
        inputs = tokenizer.apply_chat_template(
            [prompt for prompt in prompts],
            return_tensors="pt",
            padding=True,
        ).to(DEVICE)
        # The same conversations without the final message, used to locate where the
        # response starts.
        input_without_response = tokenizer.apply_chat_template(
            [prompt[:-1] for prompt in prompts],
            return_tensors="pt",
            padding=True,
        ).to(DEVICE)
    else:
        inputs = tokenizer(prompts[0] + prompts[1], return_tensors="pt", padding=True).input_ids.to(DEVICE)
        input_without_response = tokenizer(prompts[0], return_tensors="pt", padding=True).input_ids.to(DEVICE)

    attention_mask = torch.ones_like(inputs)
    attention_mask[inputs == tokenizer.pad_token_id] = 0

    model_output = model(input_ids=inputs, attention_mask=attention_mask)
    routing_weights = model_output.routing_weights
    routing_weights = np.stack([
        F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights
    ], axis=0).squeeze()

    # Keep only the positions of the final response.
    offset = len(input_without_response[0]) - 1
    routing_weights = routing_weights[:, offset:-1]
    return routing_weights


def build_model(model_id: str, hf_token: str, ablations: Optional[List[str]], use_cache: bool = True):
    """Load a MiCRo checkpoint together with the tokenizer and config of its base model."""
    model_path, base_model, model_class = get_model_path(model_id)

    model_config = AutoConfig.from_pretrained(base_model, token=hf_token)
    parent_path = pathlib.Path(__file__).parent
    model_config.config_path = f"{parent_path}/configs/{model_id.replace('-', '_')}.yml"
    model_config.torch_dtype = torch.bfloat16
    model_config.use_bfloat16 = True
    model_config._attn_implementation = "eager"  # {sdpa, flash_attention_2, eager}
    model_config.use_cache = use_cache
    model_config.ablate = ablations

    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
    tokenizer.padding_side = "left"

    # Model-family-specific pad token ids.
    if "olmo" in model_id:
        tokenizer.pad_token_id = 100277
        tokenizer.add_special_tokens({"additional_special_tokens": ["<|assistant|>"]})
    elif "smollm2" in model_id:
        tokenizer.pad_token_id = 2
    else:
        # Llama-3.x models and the default case.
        tokenizer.pad_token_id = 128004

    if "olmo" in model_id:
        # Account for the added <|assistant|> special token.
        model_config.vocab_size = len(tokenizer)

    model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)
    model.to(DEVICE)
    model = model.bfloat16()
    model.eval()
    return model, tokenizer