import torch

from collections import namedtuple
from pathlib import Path

from kernels import get_kernel, get_local_kernel
from kernels.utils import build_variant


# Seed CPU and CUDA RNGs so the randomly initialized weights below are reproducible.
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

repo_root = Path(__file__).resolve().parent
megablocks = None
try:
    # Prefer the prebuilt kernel hosted on the Hugging Face Hub.
    megablocks = get_kernel("shisa-ai/megablocks-hip")
    print("MegaBlocks kernel downloaded successfully.")
except FileNotFoundError:
    # No Hub build for this environment; fall back to a locally compiled variant.
    variant = build_variant()
    local_pkg = repo_root / "build" / variant
    print(f"Hub build missing for {variant}; falling back to {local_pkg}")
    megablocks = get_local_kernel(local_pkg, "megablocks")
    print("MegaBlocks kernel loaded from local build.")
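
# Optional: report which module actually got loaded. A minimal sketch using
# plain Python introspection (`__file__` is a standard module attribute, not a
# kernels-specific API); handy when debugging the Hub-vs-local fallback.
print(f"megablocks resolved to: {getattr(megablocks, '__file__', '<unknown>')}")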
					
						
model = megablocks.layers.MegaBlocksMoeMLP()
# Attach a bare container for the expert weights. Note that this assigns the
# namedtuple *class* itself, so the parameters below become class attributes;
# the field list is aligned with the attributes actually assigned further down.
model.experts = namedtuple(
    "Experts",
    ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"],
)
print("MegaBlocksMoeMLP instance created successfully.")

# Model dimensions: number of experts, hidden size, and intermediate (MLP) size.
ne, hs, isz = 128, 1152, 3072

device = "cuda" if torch.cuda.is_available() else "cpu"

# Router mapping each token's hidden state to per-expert logits.
model.router = torch.nn.Linear(hs, ne, device=device)
torch.nn.init.kaiming_uniform_(model.router.weight)
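
# Quick illustration of what the router produces (a sketch; the actual top-k
# expert selection happens inside MegaBlocksMoeMLP's forward pass):
with torch.no_grad():
    demo_logits = model.router(torch.randn(2, hs, device=device))
print(f"Router logits shape: {demo_logits.shape}")  # (2, ne)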
					
						
# Expert weights. gate_up_proj packs the gate and up projections of a gated
# MLP into a single tensor, so down_proj consumes only half of the
# intermediate width (isz // 2).
e = model.experts
e.gate_up_proj = torch.nn.Parameter(torch.randn(ne, hs, isz, device=device) * 0.02)
e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device=device))
e.down_proj = torch.nn.Parameter(torch.randn(ne, isz // 2, hs, device=device) * 0.02)
e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device=device))
e.hidden_size = hs
print("Expert layers initialized successfully.")
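
# Shape sanity checks (a sketch under stated assumptions: the (ne, hs, isz)
# and (ne, isz // 2, hs) layouts follow the gated-MLP convention used above,
# not a documented MegaBlocks contract).
assert e.gate_up_proj.shape == (ne, hs, isz)
assert e.down_proj.shape == (ne, isz // 2, hs)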
					
						
# Run a forward pass on a single token (batch 1, sequence length 1).
x = torch.randn(1, 1, hs, device=device) * 0.1
output, expert_weights = model(x)
print("Model forward pass completed successfully.")

print(f"Output shape: {output.shape}")
print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
print(f"Output: {output.flatten()[:10]}")
print(f"Expert weights sum: {expert_weights.sum():.3f}")
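
# Post-run checks (assumptions: the MoE MLP preserves the input shape and a
# healthy pass yields finite values; neither is a documented guarantee of
# this kernel).
assert output.shape == x.shape, f"expected {x.shape}, got {output.shape}"
assert torch.isfinite(output).all(), "output contains NaN or Inf"
print("Sanity checks passed.")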