# megablocks-hip / _dev/debug-tensor-copy.py
# leonardlin — commit 104fd3c: "Fix ROCm grouped_gemm accumulation corruption"
#!/usr/bin/env python3
"""
Isolated test to investigate tensor copy corruption.
Focuses specifically on the z=2 state contamination issue.
"""
import pathlib
import sys
from typing import Optional
import torch
def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant under ``root / "build"``.

    Resolution order:
      1. ``kernels.utils.build_variant()`` when ``root/kernels/utils.py``
         exists and importing/calling it succeeds (any exception falls
         through to the directory scan).
      2. The first (sorted) ``build/torch*-rocm64-*`` entry.
      3. The first (sorted) ``build/torch*-cu*`` entry.

    Args:
        root: Repository root that contains ``build/`` (and optionally
            ``kernels/utils.py``).

    Returns:
        The variant name (for the directory scan, the entry's basename).

    Raises:
        SystemExit: If no variant can be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None
    if (root / "kernels" / "utils.py").exists():
        root_str = str(root)
        sys.path.insert(0, root_str)
        try:
            from kernels.utils import build_variant as _build_variant  # type: ignore
            variant = _build_variant()
        except Exception:
            # Best effort: fall back to scanning build/ below.
            variant = None
        finally:
            # Remove exactly the entry we added; pop(0) would drop someone
            # else's entry if the imported module touched sys.path itself.
            try:
                sys.path.remove(root_str)
            except ValueError:
                pass
    if variant is None:
        # Path.glob returns a lazy generator, which is always truthy, so
        # ``glob(a) or glob(b)`` would never fall back to the CUDA pattern.
        # Materialize each candidate list before testing it.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(
            build_dir.glob("torch*-cu*")
        )
        if candidates:
            variant = candidates[0].name
    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")
    return variant
def randn(bs, x, y):
    """Build a ``(bs, x, y)`` bfloat16 CUDA tensor of small values.

    Meant to match ``randn`` in tests/test_gg.py exactly, including its
    operator precedence. ``0.5 * 2`` binds first, so the numerator is
    ``torch.rand(...) - 1.0``, which lies in [-1, 0) because ``torch.rand``
    samples [0, 1). It is NOT ``(rand - 0.5) * 2``. The numerator is then
    divided by ``x * y``.

    NOTE(review): this is presumably left unchanged on purpose so the inputs
    reproduce the original test. Do not "fix" it here without also changing
    that test.
    """
    numerator = torch.rand(bs, x, y) - 0.5 * 2  # == rand - 1.0, see docstring
    scale = y * x
    host_values = numerator / scale
    return host_values.cuda().to(torch.bfloat16)
def test_single_z(z, description, megablocks, seed=0, threshold=1e10):
    """Run one grouped GEMM with ``z`` groups and report output corruption.

    The function:
      - seeds torch's RNG with ``seed``;
      - builds ``a`` of shape ``(z * 4, 4)`` and ``b`` of shape ``(z, 4, 4)``
        via ``randn``;
      - calls ``megablocks.gg_ops.gmm`` with a CPU ``batch_sizes`` tensor of
        ``[4] * z``;
      - prints the input and output ranges, plus up to three corrupted
        output positions.

    Args:
        z: Number of groups in the batch.
        description: Label printed as the section header.
        megablocks: Imported megablocks module providing ``gg_ops.gmm``.
        seed: Value passed to ``torch.manual_seed`` before building inputs.
        threshold: Absolute values above this count as corruption.
            NaN and Inf always count.

    Returns:
        A tuple ``(has_corruption, out)``. ``has_corruption`` is a plain
        Python bool and ``out`` is a clone of the gmm output.
    """
    print(f"\n=== {description} ===")
    torch.manual_seed(seed)
    m, n, k = 4, 4, 4
    # Create test tensors (matching EXACT pattern from debug-gg-small.py)
    trans_b = False
    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    batch_sizes = torch.tensor([m] * z, device="cpu")
    print(f"Input a range: [{a.min().item():.8f}, {a.max().item():.8f}]")
    print(f"Input b range: [{b.min().item():.8f}, {b.max().item():.8f}]")
    # Pass trans_b through so the call agrees with how `b` was laid out.
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print(f"Output range: [{out.min().item():.8f}, {out.max().item():.8f}]")
    # Corruption means either a huge magnitude or NaN/Inf. NaN compares False
    # against any threshold, so it needs an explicit isfinite check.
    bad_values = (torch.abs(out) > threshold) | ~torch.isfinite(out)
    has_corruption = bool(bad_values.any())
    if has_corruption:
        print(
            f"CORRUPTION DETECTED: {bad_values.sum().item()} huge or non-finite values "
            f"out of {out.numel()}"
        )
        print(f"Max absolute value: {torch.abs(out).max().item():.2e}")
        # Show the first 3 corrupted positions (indexing assumes a 2-D output,
        # as in the original script).
        corrupted_indices = torch.where(bad_values)
        rows = corrupted_indices[0][:3].tolist()
        cols = corrupted_indices[1][:3].tolist()
        for row, col in zip(rows, cols):
            value = out[row, col].item()
            print(f" Corrupted at [{row}, {col}]: {value:.2e}")
    else:
        print("✓ No corruption detected")
    return has_corruption, out.clone()
def main():
    """Reproduce the suspected grouped-GEMM state contamination.

    The function:
      - imports ``megablocks`` from the staged ``build/<variant>`` directory;
      - runs ``test_single_z`` for z=1 and then z=2, both with seed 0;
      - prints a summary of which run(s) showed corruption.
    """
    repo_root = pathlib.Path(__file__).resolve().parent.parent  # Go up from _dev/ to repo root
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant
    # Insert repo_root first and staged_dir last. The staged build then sits at
    # sys.path[0] and takes precedence for `import megablocks`, as the
    # "Using staged variant" message implies.
    for path in (str(repo_root), str(staged_dir)):
        if path not in sys.path:
            sys.path.insert(0, path)
    import megablocks  # type: ignore
    print(f"Using staged variant: {variant}")
    print(f"megablocks imported from: {getattr(megablocks, '__file__', '<unknown>')}")
    print("Testing tensor copy corruption with sequence: z=1, z=2 (using SAME seed like debug-gg-small.py)")
    # Both runs pass seed=0, and test_single_z calls torch.manual_seed(seed)
    # on entry. The RNG is therefore reset before each run (matching
    # debug-gg-small.py). Whatever carries over from the z=1 run into the
    # z=2 run is not RNG state.
    # First run z=1 (expected to be clean).
    z1_corrupted, _ = test_single_z(1, "z=1 (baseline)", megablocks, seed=0)
    # Then run z=2 right after z=1 in the same process.
    z2_corrupted, _ = test_single_z(2, "z=2 (after z=1 - should show corruption)", megablocks, seed=0)
    # bool() so a 0-dim tensor flag prints as True/False, not tensor(...).
    z1_corrupted = bool(z1_corrupted)
    z2_corrupted = bool(z2_corrupted)
    print("\n=== SUMMARY ===")
    print(f"z=1 corrupted: {z1_corrupted}")
    print(f"z=2 corrupted: {z2_corrupted}")
    if z2_corrupted and not z1_corrupted:
        print("✓ Successfully reproduced state contamination bug!")
        print("The corruption happens specifically with z=2 after z=1 has been called.")
    elif z2_corrupted and z1_corrupted:
        print("Both z=1 and z=2 show corruption - this is a different issue.")
    elif z1_corrupted:
        print("Only z=1 shows corruption - the baseline itself is broken; this is a different issue.")
    else:
        print("No corruption detected - issue may have been fixed or conditions changed.")
# Run the reproduction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()