|
|
|
|
|
""" |
|
|
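Isolated reproduction script for a suspected tensor-copy corruption bug in
megablocks' grouped GEMM (``gg_ops.gmm``).

Runs the GEMM with z=1 and then z=2 groups using the same seed, focusing
specifically on the z=2 state-contamination issue.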
|
|
""" |
|
|
|
|
|
import pathlib |
|
|
import sys |
|
|
from typing import Optional |
|
|
import torch |
|
|
|
|
|
def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build-variant directory under ``root/build``.

    Resolution order:

    1. If ``root/kernels/utils.py`` exists, temporarily prepend ``root`` to
       ``sys.path``, import ``kernels.utils.build_variant`` and call it. Any
       exception is treated as "unknown" so the directory fallback can run.
    2. Otherwise (or on failure), pick the lexicographically first entry in
       ``root/build`` matching ``torch*-rocm64-*``. If there is none, fall
       back to the first entry matching ``torch*-cu*``.

    Args:
        root: Repository root containing ``build/`` (and optionally ``kernels/``).

    Returns:
        The variant name, i.e. the name of an entry under ``root/build``.

    Raises:
        SystemExit: If no variant can be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None

    if (root / "kernels" / "utils.py").exists():
        root_entry = str(root)
        sys.path.insert(0, root_entry)
        try:
            from kernels.utils import build_variant as _build_variant

            variant = _build_variant()
        except Exception:
            # Best effort: fall through to the directory glob below.
            variant = None
        finally:
            # Remove the exact entry added above instead of popping index 0,
            # which could drop a path inserted during the import.
            try:
                sys.path.remove(root_entry)
            except ValueError:
                pass

    if variant is None:
        # Path.glob returns a lazy generator that is always truthy, so each
        # pattern is materialized with sorted() before `or` picks the fallback.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(
            build_dir.glob("torch*-cu*")
        )
        if candidates:
            variant = candidates[0].name

    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")

    return variant
|
|
|
|
|
def randn(bs, x, y):
    """Return a ``(bs, x, y)`` bfloat16 CUDA tensor of small random values.

    Mirrors ``randn`` from tests/test_gg.py (noted there as an exact copy), so
    the inputs generated here match that test for the same seed.

    NOTE: ``0.5 * 2`` binds before the subtraction, so the shift is
    ``rand - 1.0`` and the values land in [-1, 0) before being divided by
    ``x * y``. This presumably stays as-is to match the upstream helper
    exactly. Do not "fix" it without updating the reference test too.
    """
    uniform = torch.rand(bs, x, y)
    shifted = uniform - 0.5 * 2
    scaled = shifted / (y * x)
    return scaled.cuda().to(torch.bfloat16)
|
|
|
|
|
def test_single_z(z, description, megablocks, seed=0):
    """Run one grouped GEMM with ``z`` groups and report corrupted outputs.

    Seeds the RNG, builds ``z`` groups of 4x4 inputs with ``randn``, calls
    ``megablocks.gg_ops.gmm`` and flags every output element that is
    non-finite (NaN/Inf) or has magnitude above 1e10.

    NOTE(review): the ``test_`` prefix may cause pytest to try to collect this
    function; it is meant to be called from ``main()``.

    Args:
        z: Number of groups; each group contributes ``m = 4`` rows of ``a``.
        description: Heading printed before the run.
        megablocks: Imported ``megablocks`` module exposing ``gg_ops.gmm``.
        seed: Value passed to ``torch.manual_seed`` before generating inputs.

    Returns:
        ``(has_corruption, out)``: ``has_corruption`` is a Python ``bool`` and
        ``out`` is a clone of the GEMM output.
    """
    print(f"\n=== {description} ===")

    torch.manual_seed(seed)
    m, n, k = 4, 4, 4

    trans_b = False
    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    # Every group gets exactly m rows.
    batch_sizes = torch.tensor([m] * z, device="cpu")

    print(f"Input a range: [{a.min().item():.8f}, {a.max().item():.8f}]")
    print(f"Input b range: [{b.min().item():.8f}, {b.max().item():.8f}]")

    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)

    print(f"Output range: [{out.min().item():.8f}, {out.max().item():.8f}]")

    # abs(NaN) > 1e10 is False, so NaN must be flagged explicitly via isfinite;
    # Inf is caught by both conditions.
    bad_values = (torch.abs(out) > 1e10) | ~torch.isfinite(out)
    has_corruption = bool(bad_values.any())

    if has_corruption:
        print(
            f"CORRUPTION DETECTED: {bad_values.sum().item()} huge or non-finite "
            f"values out of {out.numel()}"
        )
        print(f"Max absolute value: {torch.abs(out).max().item():.2e}")

        # Show at most the first three offending positions (output indexed as 2-D).
        rows, cols = torch.where(bad_values)
        for row, col in zip(rows[:3].tolist(), cols[:3].tolist()):
            value = out[row, col].item()
            print(f"  Corrupted at [{row}, {col}]: {value:.2e}")
    else:
        print("✓ No corruption detected")

    return has_corruption, out.clone()
|
|
|
|
|
def main():
    """Reproduce the suspected z=2 state-contamination bug in ``gg_ops.gmm``.

    Resolves the staged build at ``<repo>/build/<variant>``, prepends it and the
    repo root to ``sys.path``, imports ``megablocks``, then runs
    ``test_single_z`` for z=1 followed by z=2 with the same seed and prints a
    summary classifying every combination of outcomes.
    """
    repo_root = pathlib.Path(__file__).resolve().parent.parent
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant

    # Each path is prepended once; repo_root is inserted last, so it is
    # searched before staged_dir.
    for entry in (staged_dir, repo_root):
        entry_str = str(entry)
        if entry_str not in sys.path:
            sys.path.insert(0, entry_str)

    import megablocks

    print(f"Using staged variant: {variant}")
    print("Testing tensor copy corruption with sequence: z=1, z=2 (using SAME seed like debug-gg-small.py)")

    # z=2 deliberately runs after z=1 to expose any state left over from the
    # first call.
    z1_corrupted, _z1_out = test_single_z(1, "z=1 (baseline)", megablocks, seed=0)
    z2_corrupted, _z2_out = test_single_z(2, "z=2 (after z=1 - should show corruption)", megablocks, seed=0)

    # Normalize to plain bools so the summary prints True/False even if a
    # 0-dim tensor is returned.
    z1_corrupted = bool(z1_corrupted)
    z2_corrupted = bool(z2_corrupted)

    print("\n=== SUMMARY ===")
    print(f"z=1 corrupted: {z1_corrupted}")
    print(f"z=2 corrupted: {z2_corrupted}")

    if z2_corrupted and not z1_corrupted:
        print("✓ Successfully reproduced state contamination bug!")
        print("The corruption happens specifically with z=2 after z=1 has been called.")
    elif z2_corrupted and z1_corrupted:
        print("Both z=1 and z=2 show corruption - this is a different issue.")
    elif z1_corrupted:
        print("Only z=1 shows corruption (z=2 was clean) - this is not the z=2 state contamination bug.")
    else:
        print("No corruption detected - issue may have been fixed or conditions changed.")
|
|
|
|
|
if __name__ == "__main__":
    # Run the reproduction only when executed as a script, not when imported.
    main()