#!/usr/bin/env python3
"""Numerical diagnostic for megablocks.gg_ops.gmm on ROCm builds."""

import pathlib
import sys
from typing import Optional

import torch


def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant found under ``root``.

    Resolution order:

    1. If ``root/kernels/utils.py`` exists, import ``kernels.utils`` with
       ``root`` temporarily on ``sys.path`` and call its ``build_variant()``.
       Any failure there is treated as "unknown" (best-effort).
    2. Otherwise (or if step 1 yielded ``None``), pick the alphabetically
       first directory in ``root/build`` matching ``torch*-rocm64-*``; if no
       such directory exists, fall back to ``torch*-cu*``.

    Args:
        root: Repository root directory.

    Returns:
        The variant name (a directory name under ``root/build``, or whatever
        ``build_variant()`` returned).

    Raises:
        SystemExit: If no variant could be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None

    if (root / "kernels" / "utils.py").exists():
        root_entry = str(root)
        # Insert before the ``try`` so the cleanup below only runs when the
        # entry was actually added.
        sys.path.insert(0, root_entry)
        try:
            from kernels.utils import build_variant as _build_variant  # type: ignore

            variant = _build_variant()
        except Exception:
            # Deliberate best-effort: any problem with the helper falls
            # through to the directory-glob detection below.
            variant = None
        finally:
            # Remove by value rather than ``pop(0)``: the import may itself
            # have prepended to ``sys.path``, in which case index 0 would no
            # longer be our entry.
            try:
                sys.path.remove(root_entry)
            except ValueError:
                pass

    if variant is None:
        # ``Path.glob`` returns a generator, which is always truthy, so each
        # pattern must be materialised before ``or`` can select the fallback.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(
            build_dir.glob("torch*-cu*")
        )
        if candidates:
            variant = candidates[0].name

    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")

    return variant


def main() -> None:
    """Compare ``megablocks.gg_ops.gmm`` against the reference ``gmm``.

    Puts the staged build directory and the repository root on ``sys.path``,
    runs one forward and backward pass through both implementations on the
    same inputs, and prints NaN checks plus the maximum absolute (and, for
    the forward pass, relative) differences.
    """
    repo_root = pathlib.Path(__file__).resolve().parent.parent  # Go up from _dev/ to repo root
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant

    # NOTE: when both entries are new, repo_root is inserted last and so ends
    # up ahead of staged_dir on sys.path; the module path printed below shows
    # which copy of megablocks was actually imported.
    if str(staged_dir) not in sys.path:
        sys.path.insert(0, str(staged_dir))
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))

    import megablocks  # type: ignore
    from tests.test_gg import gmm, randn  # type: ignore

    print(f"Using staged variant: {variant}")
    print(f"megablocks module: {megablocks.__file__}")

    torch.manual_seed(0)

    z = m = n = k = 128
    trans_b = False

    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    batch_sizes = torch.tensor([m] * z, device="cpu")

    a.requires_grad_(True)
    b.requires_grad_(True)

    # Independent copies so each implementation accumulates its own grads.
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    ref = gmm(a_ref, b_ref, batch_sizes.cpu(), trans_b)

    print(f"out has NaNs: {torch.isnan(out).any().item()}")
    print(f"ref has NaNs: {torch.isnan(ref).any().item()}")

    forward_abs = (out - ref).abs().max().item()
    # Small epsilon keeps the relative error finite where ref is exactly 0.
    forward_rel = ((out - ref).abs() / (ref.abs() + 1e-9)).max().item()

    out.sum().backward()
    ref.sum().backward()

    print(f"a.grad has NaNs: {torch.isnan(a.grad).any().item()}")
    print(f"b.grad has NaNs: {torch.isnan(b.grad).any().item()}")

    a_grad_abs = (a.grad - a_ref.grad).abs().max().item()
    b_grad_abs = (b.grad - b_ref.grad).abs().max().item()

    print(f"forward max abs diff: {forward_abs:.6e}")
    print(f"forward max rel diff: {forward_rel:.6e}")
    print(f"a grad max abs diff: {a_grad_abs:.6e}")
    print(f"b grad max abs diff: {b_grad_abs:.6e}")


if __name__ == "__main__":
    main()