# /// script
# requires-python = "==3.10.*"
# dependencies = [
#     "numpy",
#     "kernels",
#     "torch"
# ]
# ///
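# The block above is PEP 723 inline script metadata; runners such as `uv run`
# resolve these dependencies automatically.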
import torch
from collections import namedtuple
from pathlib import Path
from kernels import get_kernel, get_local_kernel
from kernels.utils import build_variant
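# build_variant() reports the local build string (torch/ROCm/arch), which is
# assumed here to name the staged build directory produced by build.py.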
# Make reproducible
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
repo_root = Path(__file__).resolve().parent
megablocks = None
try:
    # Prefer the published kernel
    megablocks = get_kernel("shisa-ai/megablocks-hip")
    print("MegaBlocks kernel downloaded successfully.")
except FileNotFoundError:
    # Fall back to the locally staged ROCm build (produced by build.py)
    variant = build_variant()
    local_pkg = repo_root / "build" / variant
    print(f"Hub build missing for {variant}; falling back to {local_pkg}")
    megablocks = get_local_kernel(local_pkg, "megablocks")
    print("MegaBlocks kernel loaded from local build.")
model = megablocks.layers.MegaBlocksMoeMLP()
model.experts = namedtuple(
    "Experts",
    ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"],
)
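# NOTE: the expert tensors below are assigned as class attributes on this
# namedtuple; the MoE layer reads them by attribute access, so a plain
# namespace would work here as well.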
print("MegaBlocksMoeMLP instance created successfully.")
# Config: number of experts, hidden size, and intermediate (gate/up) size
ne, hs, isz = 128, 1152, 3072
# Router with proper initialization
device = "cuda" if torch.cuda.is_available() else "cpu"
model.router = torch.nn.Linear(hs, ne, device=device)
torch.nn.init.kaiming_uniform_(model.router.weight)
# Expert layers with realistic weights
e = model.experts
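# The 0.02 std below mirrors the common transformer init scale, keeping the
# random expert weights in a realistic range for a smoke test.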
e.gate_up_proj = torch.nn.Parameter(torch.randn(ne, hs, isz, device=device) * 0.02)
e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device=device))
# gate_up_proj packs gate and up halves, so down_proj takes isz // 2 (= 1536) features
e.down_proj = torch.nn.Parameter(torch.randn(ne, isz // 2, hs, device=device) * 0.02)
e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device=device))
e.hidden_size = hs
print("Expert layers initialized successfully.")
# Test with normalized input
x = torch.randn(1, 1, hs, device=device) * 0.1
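# The forward pass returns the combined expert output and the routing weights.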
output, expert_weights = model(x)
print("Model forward pass completed successfully.")
print(f"Output shape: {output.shape}")
print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
print(f"Output: {output.flatten()[:10]}")
print(f"Expert weights sum: {expert_weights.sum():.3f}")
