File size: 3,276 Bytes
104fd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
"""Debug with smaller tensor sizes to isolate the issue."""

import pathlib
import sys
from typing import Optional

import torch


def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant under ``root / "build"``.

    Resolution order:

    1. If ``root/kernels/utils.py`` exists, prepend ``root`` to ``sys.path``
       temporarily, import ``build_variant`` from ``kernels.utils`` and call it.
       Any exception raised by the import or the call is treated as "unknown",
       so the directory scan below can still be tried.
    2. If that produced nothing, pick the lexicographically first directory in
       ``root/build`` matching ``torch*-rocm64-*``. If there is none, pick the
       first one matching ``torch*-cu*``.

    Args:
        root: Repository root that contains the ``build`` directory.

    Returns:
        The variant directory name.

    Raises:
        SystemExit: If no variant can be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None

    if (root / "kernels" / "utils.py").exists():
        root_str = str(root)
        sys.path.insert(0, root_str)
        try:
            from kernels.utils import build_variant as _build_variant  # type: ignore

            variant = _build_variant()
        except Exception:
            # Best effort: fall back to scanning build/ below.
            variant = None
        finally:
            # Remove exactly the entry we added. ``pop(0)`` would drop the wrong
            # entry if the imported module itself prepended to sys.path.
            try:
                sys.path.remove(root_str)
            except ValueError:
                pass

    if not variant:
        # Path.glob returns a lazy generator, which is always truthy, so
        # ``glob(a) or glob(b)`` could never reach the second pattern.
        # Materialize each pattern's matches before deciding to fall back.
        for pattern in ("torch*-rocm64-*", "torch*-cu*"):
            candidates = sorted(p for p in build_dir.glob(pattern) if p.is_dir())
            if candidates:
                variant = candidates[0].name
                break

    if not variant:
        raise SystemExit("Could not determine build variant; run build.py first.")

    return variant


def main() -> None:
    """Run megablocks' grouped GEMM on small shapes and compare it with a reference.

    Resolves the staged build directory ``<repo>/build/<variant>`` and makes
    sure it is the first entry on ``sys.path``. It then imports ``megablocks``
    and the ``gmm``/``randn`` helpers from ``tests.test_gg``. For each
    ``(z, m, n, k)`` case it prints the input and output value ranges, reports
    how many outputs have magnitude above 1e10, and prints PASS when the max
    absolute difference to the reference ``gmm`` is below 1e-2.

    Raises:
        SystemExit: If the build variant or its staged directory cannot be found.
    """
    repo_root = pathlib.Path(__file__).resolve().parent.parent  # Go up from _dev/ to repo root
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant
    if not staged_dir.is_dir():
        raise SystemExit(f"Staged build directory {staged_dir} does not exist; run build.py first.")

    # Put repo_root on sys.path first and staged_dir last, so staged_dir ends up
    # at sys.path[0]. The point of this script is to exercise the built
    # extension, so nothing importable from the repo root may shadow it.
    # repo_root is still needed for ``tests.test_gg``.
    for entry in (str(repo_root), str(staged_dir)):
        if entry in sys.path:
            sys.path.remove(entry)
        sys.path.insert(0, entry)

    import megablocks  # type: ignore
    from tests.test_gg import gmm, randn  # type: ignore

    print(f"Using staged variant: {variant}")

    # Only the non-transposed B layout is exercised by this script.
    trans_b = False

    # Test with very small sizes first
    for z, m, n, k in [(1, 4, 4, 4), (2, 4, 4, 4), (1, 16, 16, 16), (4, 16, 16, 16)]:
        print(f"\n=== Testing z={z}, m={m}, n={n}, k={k} ===")

        # Re-seed per case so each case's inputs do not depend on earlier cases.
        torch.manual_seed(0)

        # a is flattened to (z * m, k); b is per-group (z, k, n), or (z, n, k) if trans_b.
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        # Every group holds m rows; the sizes tensor lives on the CPU.
        batch_sizes = torch.tensor([m] * z, device="cpu")

        print(f"a.shape: {a.shape}, b.shape: {b.shape}")
        print(f"Input a range: [{a.min().item():.8f}, {a.max().item():.8f}]")
        print(f"Input b range: [{b.min().item():.8f}, {b.max().item():.8f}]")

        # Reference computation on detached copies of the same inputs.
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)
        ref = gmm(a_ref, b_ref, batch_sizes, trans_b)
        print(f"Reference output range: [{ref.min().item():.8f}, {ref.max().item():.8f}]")

        # Megablocks computation
        a.requires_grad_(True)
        b.requires_grad_(True)
        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        print(f"Megablocks output range: [{out.min().item():.8f}, {out.max().item():.8f}]")

        # Flag implausibly large outputs (e.g. reads of uninitialized memory).
        huge_values = torch.abs(out) > 1e10
        if huge_values.any():
            print(f"Found {huge_values.sum().item()} huge values out of {out.numel()} total")
            print(f"Max absolute value: {torch.abs(out).max().item():.2e}")

        # Compare against the reference only when the output is finite.
        if not torch.isnan(out).any() and not torch.isinf(out).any():
            diff = (out - ref).abs().max().item()
            print(f"Max abs diff: {diff:.2e}")
            if diff < 1e-2:
                print("✓ PASS")
            else:
                print("✗ FAIL")
        else:
            print("✗ FAIL (NaN/Inf detected)")


if __name__ == "__main__":
    # Run the debug sweep only when executed as a script, not when imported.
    main()