# /// script
# requires-python = "==3.10.*"
# dependencies = [
#     "numpy",
#     "kernels",
#     "torch"
# ]
# ///

"""Smoke test for the MegaBlocks MoE MLP kernel.

Loads the ``megablocks`` kernel (published Hub build first, locally staged
build as a fallback), wires a ``MegaBlocksMoeMLP`` layer up with a router and
randomly initialized expert weights, runs a single forward pass and prints a
short summary of the result.

NOTE: ``requires-python`` uses ``==3.10.*`` because a bare ``==3.10`` only
matches 3.10.0 under PEP 440 and would reject every 3.10 patch release.
"""

import torch
from collections import namedtuple
from pathlib import Path

from kernels import get_kernel, get_local_kernel
from kernels.utils import build_variant

# Make reproducible
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

repo_root = Path(__file__).resolve().parent

try:
    # Prefer the published kernel
    megablocks = get_kernel("shisa-ai/megablocks-hip")
    print("MegaBlocks kernel downloaded successfully.")
except FileNotFoundError:
    # Fall back to the locally staged ROCm build (produced by build.py)
    variant = build_variant()
    local_pkg = repo_root / "build" / variant
    print(f"Hub build missing for {variant}; falling back to {local_pkg}")
    megablocks = get_local_kernel(local_pkg, "megablocks")
    print("MegaBlocks kernel loaded from local build.")

model = megablocks.layers.MegaBlocksMoeMLP()
# The namedtuple *class* itself (never instantiated) is used as a plain
# attribute container; the field list mirrors the attributes assigned below.
model.experts = namedtuple(
    "Experts",
    [
        "gate_up_proj",
        "gate_up_proj_bias",
        "down_proj",
        "down_proj_bias",
        "hidden_size",
    ],
)
print("MegaBlocksMoeMLP instance created successfully.")

# Config: number of experts, hidden size, gate/up projection width
ne, hs, isz = 128, 1152, 3072
# Input width of the down projection (== 1536 for the config above).
# Presumably gate_up_proj packs the gate and up halves side by side, so the
# activation fed to down_proj is half as wide -- TODO confirm against kernel.
down_in = isz // 2

# Router with proper initialization
device = "cuda" if torch.cuda.is_available() else "cpu"
model.router = torch.nn.Linear(hs, ne, device=device)
torch.nn.init.kaiming_uniform_(model.router.weight)

# Expert layers with realistic weights
e = model.experts
e.gate_up_proj = torch.nn.Parameter(torch.randn(ne, hs, isz, device=device) * 0.02)
e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device=device))
e.down_proj = torch.nn.Parameter(torch.randn(ne, down_in, hs, device=device) * 0.02)
e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device=device))
e.hidden_size = hs
print("Expert layers initialized successfully.")

# Test with normalized input
x = torch.randn(1, 1, hs, device=device) * 0.1
output, expert_weights = model(x)
print("Model forward pass completed successfully.")

print(f"Output shape: {output.shape}")
print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
print(f"Output: {output.flatten()[:10]}")
print(f"Expert weights sum: {expert_weights.sum():.3f}")