megablocks-hip / _dev /debug-gg-step-by-step.py
leonardlin's picture
Fix ROCm grouped_gemm accumulation corruption
104fd3c
#!/usr/bin/env python3
"""Step-by-step debugging of the grouped GEMM computation."""
import pathlib
import sys
from typing import Optional
import torch
def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant found under ``root``.

    When ``root/kernels/utils.py`` exists, ``kernels.utils.build_variant()``
    is tried first. If that helper is missing or fails, ``root/build`` is
    scanned instead: the first entry (sorted by name) matching
    ``torch*-rocm64-*`` wins, and only when there is no such entry is the
    first match of ``torch*-cu*`` used.

    Args:
        root: Repository root that contains ``build/`` and, optionally,
            ``kernels/utils.py``.

    Returns:
        The variant name (a directory name under ``root/build`` when it was
        found by scanning).

    Raises:
        SystemExit: If no variant could be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None
    if (root / "kernels" / "utils.py").exists():
        root_entry = str(root)
        sys.path.insert(0, root_entry)
        try:
            from kernels.utils import build_variant as _build_variant  # type: ignore

            variant = _build_variant()
        except Exception:
            # Best effort only: any failure falls through to the directory scan.
            variant = None
        finally:
            # Remove our own entry by value rather than popping index 0, in
            # case the import itself modified sys.path.
            if root_entry in sys.path:
                sys.path.remove(root_entry)
    if variant is None:
        # Path.glob() returns a generator, which is always truthy, so each
        # pattern must be materialised before `or` can fall back to the next.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(
            build_dir.glob("torch*-cu*")
        )
        if candidates:
            variant = candidates[0].name
    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")
    return variant
def manual_gmm_computation(a, b, batch_sizes, trans_b=False):
    """Compute a grouped GEMM one expert at a time, printing every step.

    Rows of ``a`` are split into consecutive groups whose sizes are given by
    ``batch_sizes``. Group ``i`` is multiplied by ``b[i]`` (or by its
    transpose when ``trans_b`` is set). Each product is computed in float32
    and cast back to ``a.dtype`` before being written to the output.

    Args:
        a: 2-D tensor holding the rows of all groups, concatenated along dim 0.
        b: 3-D tensor indexed by expert along dim 0. Each ``b[i]`` is used
            as-is by default and transposed when ``trans_b`` is true.
        batch_sizes: 1-D integer tensor with the number of rows per expert.
        trans_b: Multiply by ``b[i]`` transposed instead of ``b[i]``.

    Returns:
        Tensor of shape ``(sum(batch_sizes), hidden_in)`` with ``a``'s dtype
        and device, where ``hidden_in`` is ``b.size(2)`` by default and
        ``b.size(1)`` when ``trans_b`` is true.
    """
    print("=== Manual GMM computation ===")
    # Batch sizes are read on the host because they drive the Python loop.
    batch_sizes_cpu = batch_sizes.cpu()
    counts_ptr = batch_sizes_cpu.numpy()
    num_experts = len(counts_ptr)

    # Inclusive prefix sums: prefix[i] is the exclusive end row of expert i.
    # Plain Python ints are used so the indices passed to torch are not
    # NumPy scalars.
    prefix = []
    running = 0
    for count in counts_ptr.tolist():
        running += count
        prefix.append(running)
    tokens = prefix[-1] if prefix else 0

    print(f"num_experts: {num_experts}, tokens: {tokens}")
    print(f"a.shape: {a.shape}, b.shape: {b.shape}")
    print(f"batch_sizes: {counts_ptr}")

    # The output width is whichever dimension of b[i] is not contracted.
    hidden_in = b.size(1) if trans_b else b.size(2)
    out = torch.empty((tokens, hidden_in), dtype=a.dtype, device=a.device)
    print(f"Output shape: {out.shape} (tokens={tokens}, hidden_in={hidden_in})")

    b_contig = b.contiguous()
    start = 0
    for expert, end in enumerate(prefix):
        rows = end - start
        print(f"\nExpert {expert}: start={start}, end={end}, rows={rows}")
        if rows == 0:
            # An expert with no rows contributes nothing to the output.
            start = end
            continue

        a_slice = a.narrow(0, start, rows)
        b_slice = b_contig.select(0, expert)
        if trans_b:
            b_slice = b_slice.transpose(0, 1)
        out_slice = out.narrow(0, start, rows)  # view into `out`

        print(f" a_slice.shape: {a_slice.shape}")
        print(f" b_slice.shape: {b_slice.shape}")
        print(f" a_slice range: [{a_slice.min().item():.8f}, {a_slice.max().item():.8f}]")
        print(f" b_slice range: [{b_slice.min().item():.8f}, {b_slice.max().item():.8f}]")

        # Multiply in float32, then cast back to the input dtype.
        a_f32 = a_slice.to(torch.float32)
        b_f32 = b_slice.to(torch.float32)
        prod = torch.matmul(a_f32, b_f32)
        print(f" prod.shape: {prod.shape}")
        print(f" prod range: [{prod.min().item():.8f}, {prod.max().item():.8f}]")

        # Writing through the narrowed view fills the matching rows of `out`.
        out_slice.copy_(prod.to(a.dtype))
        start = end
    return out
def main() -> None:
    """Run the manual, reference and megablocks grouped GEMMs and compare them.

    Locates the staged build, makes it and the repository root importable,
    builds seeded inputs, and prints the value range of each result and the
    maximum absolute difference of the manual and megablocks results from
    the reference.
    """
    # This script lives in _dev/, one level below the repository root.
    repo_root = pathlib.Path(__file__).resolve().parents[1]
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant

    # Each missing entry is pushed to the front in turn, so the repository
    # root ends up ahead of the staged build directory.
    for entry in (str(staged_dir), str(repo_root)):
        if entry not in sys.path:
            sys.path.insert(0, entry)

    import megablocks  # type: ignore
    from tests.test_gg import gmm, randn  # type: ignore

    print(f"Using staged variant: {variant}")

    torch.manual_seed(0)
    z = m = n = k = 128
    trans_b = False
    # `a` is drawn before `b` so the seeded random stream is consumed in a
    # fixed order.
    a = randn(z, m, k).view(-1, k)
    b_dims = (z, n, k) if trans_b else (z, k, n)
    b = randn(*b_dims)
    batch_sizes = torch.tensor([m] * z, device="cpu")

    print(f"=== Input Information ===")
    print(f"a.shape: {a.shape}, dtype: {a.dtype}")
    print(f"b.shape: {b.shape}, dtype: {b.dtype}")
    print(f"batch_sizes: {batch_sizes}")
    print(f"Input a range: [{a.min().item():.8f}, {a.max().item():.8f}]")
    print(f"Input b range: [{b.min().item():.8f}, {b.max().item():.8f}]")

    # Step-by-step computation from this script.
    manual_out = manual_gmm_computation(a.clone(), b.clone(), batch_sizes, trans_b)
    print(f"\nManual output range: [{manual_out.min().item():.8f}, {manual_out.max().item():.8f}]")

    # Reference implementation imported from the test suite.
    ref = gmm(a.detach().clone(), b.detach().clone(), batch_sizes.cpu(), trans_b)
    print(f"Reference output range: [{ref.min().item():.8f}, {ref.max().item():.8f}]")

    # Operator under investigation.
    out = megablocks.gg_ops.gmm(a.clone(), b.clone(), batch_sizes, trans_b)
    print(f"Megablocks output range: [{out.min().item():.8f}, {out.max().item():.8f}]")

    # Largest element-wise deviation of each result from the reference.
    manual_vs_ref = (manual_out - ref).abs().max().item()
    megablocks_vs_ref = (out - ref).abs().max().item()
    print(f"\nManual vs Reference max diff: {manual_vs_ref:.8e}")
    print(f"Megablocks vs Reference max diff: {megablocks_vs_ref:.8e}")
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()