|
|
|
|
|
"""Debug megablocks.gg_ops.gmm against the reference gmm using small tensor sizes to isolate numerical issues (huge values, NaN/Inf, mismatches)."""
|
|
|
|
|
import pathlib |
|
|
import sys |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant under ``root / "build"``.

    Resolution order:

    1. If ``root/kernels/utils.py`` exists, ``root`` is temporarily prepended
       to ``sys.path`` and ``kernels.utils.build_variant()`` is called. Any
       exception from the import or the call is treated as "unknown" so the
       directory scan below still runs.
    2. If step 1 produced nothing, scan ``root/build`` for directories
       matching ``torch*-rocm64-*`` first, then ``torch*-cu*``, and use the
       first match in sorted order.

    Args:
        root: Repository root expected to contain ``kernels/`` and ``build/``.

    Returns:
        The variant name (a directory name under ``root/build`` when found by
        the scan, or whatever ``build_variant()`` returned).

    Raises:
        SystemExit: If no variant can be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None

    if (root / "kernels" / "utils.py").exists():
        root_entry = str(root)
        sys.path.insert(0, root_entry)
        try:
            from kernels.utils import build_variant as _build_variant

            variant = _build_variant()
        except Exception:
            # Best effort only: fall back to scanning the build directory.
            variant = None
        finally:
            # Remove exactly the entry added above. Importing kernels.utils may
            # itself modify sys.path, so a blind pop(0) could drop the wrong one.
            try:
                sys.path.remove(root_entry)
            except ValueError:
                pass

    if variant is None:
        # Path.glob returns a generator, which is always truthy, so each
        # pattern must be materialized before deciding whether to fall back.
        for pattern in ("torch*-rocm64-*", "torch*-cu*"):
            candidates = sorted(p for p in build_dir.glob(pattern) if p.is_dir())
            if candidates:
                variant = candidates[0].name
                break

    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")

    return variant
|
|
|
|
|
|
|
|
def main() -> None:
    """Compare megablocks.gg_ops.gmm with the reference gmm on small shapes.

    Locates the staged build via ``detect_variant``, then prepends the staged
    directory and the repository root (for ``tests.test_gg``) to ``sys.path``.
    For each (z, m, n, k) case it:

    - builds inputs with ``tests.test_gg.randn``;
    - runs the reference ``gmm`` and ``megablocks.gg_ops.gmm``;
    - prints the value ranges and a count of huge (> 1e10) outputs;
    - prints the maximum absolute difference with a PASS/FAIL verdict
      (absolute threshold 1e-2).

    Results are only printed; nothing is returned.

    Raises:
        SystemExit: If no build variant can be determined, or if the staged
            build directory for that variant does not exist.
    """
    repo_root = pathlib.Path(__file__).resolve().parent.parent
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant

    if not staged_dir.is_dir():
        # Without this check `import megablocks` below could silently resolve
        # to an installed copy instead of the staged build being debugged.
        raise SystemExit(f"Staged build directory not found: {staged_dir}")

    # Prepend so the staged build and the repo-local test helpers shadow any
    # installed packages of the same name. The loop order leaves repo_root
    # ahead of staged_dir, exactly as two successive insert(0, ...) calls did.
    for path_entry in (str(staged_dir), str(repo_root)):
        if path_entry not in sys.path:
            sys.path.insert(0, path_entry)

    import megablocks
    from tests.test_gg import gmm, randn

    print(f"Using staged variant: {variant}")

    trans_b = False
    cases = [(1, 4, 4, 4), (2, 4, 4, 4), (1, 16, 16, 16), (4, 16, 16, 16)]

    for z, m, n, k in cases:
        print(f"\n=== Testing z={z}, m={m}, n={n}, k={k} ===")

        # Re-seed per case so every shape is reproducible on its own.
        torch.manual_seed(0)

        # a is flattened to (-1, k) rows; b is laid out (z, n, k) when
        # trans_b is set, else (z, k, n). Shapes presumably follow randn's
        # argument order -- confirm against tests.test_gg.randn.
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        # Every group gets exactly m rows; the tensor is created on CPU.
        batch_sizes = torch.tensor([m] * z, device="cpu")

        print(f"a.shape: {a.shape}, b.shape: {b.shape}")
        print(f"Input a range: [{a.min().item():.8f}, {a.max().item():.8f}]")
        print(f"Input b range: [{b.min().item():.8f}, {b.max().item():.8f}]")

        # Reference result computed on independent leaf copies, so its autograd
        # graph shares no leaves with the megablocks call below.
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)
        ref = gmm(a_ref, b_ref, batch_sizes, trans_b)
        print(f"Reference output range: [{ref.min().item():.8f}, {ref.max().item():.8f}]")

        a.requires_grad_(True)
        b.requires_grad_(True)
        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        print(f"Megablocks output range: [{out.min().item():.8f}, {out.max().item():.8f}]")

        # Flag implausibly large magnitudes, which usually indicate reads of
        # uninitialized or out-of-bounds memory rather than rounding error.
        abs_out = out.abs()
        huge_values = abs_out > 1e10
        if huge_values.any():
            print(f"Found {huge_values.sum().item()} huge values out of {out.numel()} total")
            print(f"Max absolute value: {abs_out.max().item():.2e}")

        if torch.isfinite(out).all():
            diff = (out - ref).abs().max().item()
            print(f"Max abs diff: {diff:.2e}")
            print("✓ PASS" if diff < 1e-2 else "✗ FAIL")
        else:
            print("✗ FAIL (NaN/Inf detected)")
|
|
|
|
|
|
|
|
if __name__ == "__main__":  # run the debug sweep only when executed as a script
    main()