megablocks-hip / _dev /debug-gg-detailed.py
leonardlin's picture
Fix ROCm grouped_gemm accumulation corruption
104fd3c
#!/usr/bin/env python3
"""Enhanced numerical diagnostic for megablocks.gg_ops.gmm on ROCm builds."""
import pathlib
import sys
from typing import Optional
import torch
def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant under ``root / "build"``.

    Resolution order:
      1. If ``root/kernels/utils.py`` exists, import ``kernels.utils`` with
         *root* temporarily on ``sys.path`` and call ``build_variant()``.
      2. Otherwise, or if step 1 fails, use the lexicographically first entry
         of ``build/`` matching ``torch*-rocm64-*``.
      3. Failing that, the first entry matching ``torch*-cu*``.

    Args:
        root: Repository root that contains the ``build/`` directory.

    Returns:
        The variant name (a directory name under ``build/``).

    Raises:
        SystemExit: If no variant can be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None
    if (root / "kernels" / "utils.py").exists():
        root_entry = str(root)
        sys.path.insert(0, root_entry)
        try:
            from kernels.utils import build_variant as _build_variant  # type: ignore
            variant = _build_variant()
        except Exception as exc:
            # Deliberate best effort: fall back to scanning build/ below, but
            # say why instead of failing silently.
            print(
                f"warning: kernels.utils.build_variant() failed ({exc!r}); "
                "scanning build/ instead",
                file=sys.stderr,
            )
            variant = None
        finally:
            # Remove exactly the entry added above. pop(0) would drop the
            # wrong path if the import itself prepended to sys.path.
            try:
                sys.path.remove(root_entry)
            except ValueError:
                pass
    if variant is None:
        # Path.glob returns a generator, and a generator object is always
        # truthy, so `glob(a) or glob(b)` never reached the CUDA pattern.
        # Materialize each pattern before applying the fallback.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(
            build_dir.glob("torch*-cu*")
        )
        if candidates:
            variant = candidates[0].name
    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")
    return variant
def _nan_free_range(t: torch.Tensor) -> str:
    """Format ``[min, max]`` of *t* over its non-NaN entries.

    ``Tensor.min``/``max`` propagate NaN, so a single NaN would otherwise hide
    the range of every other element. Infinities are kept so they remain
    visible. Returns ``"[all NaN]"`` if no non-NaN entries remain.
    """
    values = t.detach()
    values = values[~torch.isnan(values)]
    if values.numel() == 0:
        return "[all NaN]"
    return f"[{values.min().item():.6f}, {values.max().item():.6f}]"


def _nan_status(t: Optional[torch.Tensor]) -> str:
    """Return "True"/"False" for whether *t* contains NaN, or a notice if None."""
    if t is None:
        return "n/a (grad is None)"
    return str(torch.isnan(t).any().item())


def _report_grad_diff(
    name: str, grad: Optional[torch.Tensor], ref_grad: Optional[torch.Tensor]
) -> None:
    """Print the max abs difference between *grad* and *ref_grad*.

    Prints ``nan`` if either contains NaN. If autograd left either one
    unpopulated (``None``), this is reported explicitly instead of crashing.
    """
    if grad is None or ref_grad is None:
        print(f"{name} grad max abs diff: unavailable (grad is None)")
        return
    if torch.isnan(grad).any() or torch.isnan(ref_grad).any():
        print(f"{name} grad max abs diff: nan")
        return
    diff = (grad - ref_grad).abs().max().item()
    print(f"{name} grad max abs diff: {diff:.6e}")


def main() -> None:
    """Compare ``megablocks.gg_ops.gmm`` against ``tests.test_gg.gmm``.

    Makes the staged ``build/<variant>`` directory importable, runs a grouped
    GEMM (z = m = n = k = 128, ``trans_b=False``) through both the reference
    and the megablocks implementation, then backpropagates ``sum()`` through
    each. Prints NaN checks, value ranges, whether the megablocks call altered
    its inputs, and max abs/rel differences for the output and the gradients.
    """
    repo_root = pathlib.Path(__file__).resolve().parent.parent  # _dev/ -> repo root
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant
    # Prepend repo_root first and staged_dir second, so that staged_dir ends
    # up at sys.path[0]. `import megablocks` must then resolve to the staged
    # build rather than to anything importable from the repo root.
    for entry in (str(repo_root), str(staged_dir)):
        if entry not in sys.path:
            sys.path.insert(0, entry)
    import megablocks  # type: ignore
    from tests.test_gg import gmm, randn  # type: ignore

    print(f"Using staged variant: {variant}")
    print(f"megablocks module: {megablocks.__file__}")

    torch.manual_seed(0)
    z = m = n = k = 128
    trans_b = False
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k) if trans_b else randn(z, k, n)
    batch_sizes = torch.tensor([m] * z, device="cpu")

    # Check input tensors for NaNs
    print(f"Input a has NaNs: {torch.isnan(a).any().item()}")
    print(f"Input b has NaNs: {torch.isnan(b).any().item()}")
    print(f"Input a range: {_nan_free_range(a)}")
    print(f"Input b range: {_nan_free_range(b)}")

    a.requires_grad_(True)
    b.requires_grad_(True)
    # Independent leaf copies, so the two implementations accumulate separate grads.
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    # Reference computation first
    ref = gmm(a_ref, b_ref, batch_sizes.cpu(), trans_b)
    print("Reference computation completed")
    print(f"ref has NaNs: {torch.isnan(ref).any().item()}")
    print(f"ref range: {_nan_free_range(ref)}")

    # Implementation under test
    print("Running megablocks.gg_ops.gmm...")
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("megablocks computation completed")
    out_nan = torch.isnan(out)
    print(f"out has NaNs: {out_nan.any().item()}")
    if out_nan.all():
        print("out is all NaN")
    else:
        print(f"out range: {_nan_free_range(out)}")

    # Check whether the megablocks call altered its inputs in place. (The
    # original note expected this not to happen because of a NoGradGuard in
    # the kernel; not verifiable from here.) Compare whole tensors, since a
    # leading-slice check can miss writes elsewhere.
    print(f"Input a modified: {not torch.equal(a, a_ref)}")
    print(f"Input b modified: {not torch.equal(b, b_ref)}")

    if out_nan.any():
        print("forward max abs diff: nan")
        print("forward max rel diff: nan")
    else:
        # Pure diagnostics: don't extend the autograd graph of `out`.
        with torch.no_grad():
            delta = (out - ref).abs()
            forward_abs = delta.max().item()
            # 1e-9 keeps near-zero reference entries from dividing by zero.
            forward_rel = (delta / (ref.abs() + 1e-9)).max().item()
        print(f"forward max abs diff: {forward_abs:.6e}")
        print(f"forward max rel diff: {forward_rel:.6e}")

    # Test gradients
    out.sum().backward()
    ref.sum().backward()
    print(f"a.grad has NaNs: {_nan_status(a.grad)}")
    print(f"b.grad has NaNs: {_nan_status(b.grad)}")
    _report_grad_diff("a", a.grad, a_ref.grad)
    _report_grad_diff("b", b.grad, b_ref.grad)


if __name__ == "__main__":
    main()