# File size: 4,416 Bytes
# 104fd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
"""
Isolated test to investigate tensor copy corruption.
Focuses specifically on the z=2 state contamination issue.
"""

import pathlib
import sys
from typing import Optional
import torch

def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant directory under ``root/build``.

    Resolution order:
      1. If ``root/kernels/utils.py`` exists, import it (with ``root`` temporarily
         on ``sys.path``) and call its ``build_variant()``. Any exception while
         importing or calling it is swallowed on purpose so step 2 can still run.
      2. Otherwise (or if step 1 failed), scan ``root/build``: directories
         matching ``torch*-rocm64-*`` are preferred; if there are none,
         directories matching ``torch*-cu*`` are used. The first match in
         sorted order wins.

    Args:
        root: Repository root directory.

    Returns:
        The variant name, suitable for ``root / "build" / variant``.

    Raises:
        SystemExit: If no variant can be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None

    if (root / "kernels" / "utils.py").exists():
        root_str = str(root)
        sys.path.insert(0, root_str)
        try:
            from kernels.utils import build_variant as _build_variant  # type: ignore
            variant = _build_variant()
        except Exception:
            # Best effort only: fall back to scanning the build directory.
            variant = None
        finally:
            # Remove our entry by value: the imported module may itself have
            # modified sys.path, so index 0 is not guaranteed to be ours.
            try:
                sys.path.remove(root_str)
            except ValueError:
                pass

    if variant is None:
        # Path.glob returns a generator, which is always truthy, so each
        # pattern must be materialized on its own for the `or` fallback to
        # ever reach the CUDA pattern.
        candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(
            build_dir.glob("torch*-cu*")
        )
        if candidates:
            variant = candidates[0].name

    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")

    return variant

def randn(bs, x, y):
    """Exact copy of randn from tests/test_gg.py.

    Draw a ``(bs, x, y)`` tensor from ``torch.rand``, shift and scale it,
    and return it on the GPU as bfloat16.

    NOTE: ``0.5 * 2`` binds tighter than ``-``, so the expression subtracts
    1.0 (values lie in [-1, 0) before scaling) rather than computing
    ``(u - 0.5) * 2``. It is kept unchanged because this helper is meant to
    reproduce the original test exactly.
    """
    shape = (bs, x, y)
    uniform = torch.rand(*shape)
    shifted = uniform - 0.5 * 2  # == uniform - 1.0 (see NOTE above)
    return (shifted / (x * y)).cuda().to(torch.bfloat16)

def test_single_z(z, description, megablocks, seed=0):
    """Run one grouped GEMM with ``z`` groups and report output corruption.

    Seeds torch's global RNG with ``seed``, builds ``a`` as ``(z*m, k)`` and
    ``b`` as ``(z, k, n)`` via ``randn`` (m = n = k = 4), with every group
    having ``m`` rows. It then calls ``megablocks.gg_ops.gmm`` and flags any
    output element that is NaN/Inf or whose magnitude exceeds 1e10.

    NOTE: despite the ``test_`` prefix this is a helper driven by ``main``;
    pytest would try to collect it and fail on its parameters.

    Args:
        z: Number of groups.
        description: Heading printed before the run.
        megablocks: The imported ``megablocks`` module under test.
        seed: Value passed to ``torch.manual_seed`` before generating inputs.

    Returns:
        ``(has_corruption, out)`` where ``has_corruption`` is a plain ``bool``
        and ``out`` is a clone of the gmm output.
    """
    print(f"\n=== {description} ===")

    torch.manual_seed(seed)
    m, n, k = 4, 4, 4

    # Create test tensors (matching EXACT pattern from debug-gg-small.py)
    trans_b = False
    a = randn(z, m, k).view(-1, k)
    b = randn(z, k, n) if not trans_b else randn(z, n, k)
    batch_sizes = torch.tensor([m] * z, device="cpu")

    print(f"Input a range: [{a.min().item():.8f}, {a.max().item():.8f}]")
    print(f"Input b range: [{b.min().item():.8f}, {b.max().item():.8f}]")

    # Call megablocks gmm; pass trans_b so the flag always matches b's layout.
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)

    print(f"Output range: [{out.min().item():.8f}, {out.max().item():.8f}]")

    # Corruption = huge magnitudes OR non-finite values. NaN compares False
    # against any threshold, so it has to be checked explicitly.
    bad_values = (torch.abs(out) > 1e10) | ~torch.isfinite(out)
    # Convert to a Python bool so callers/summary print True/False rather
    # than a device tensor repr.
    has_corruption = bool(bad_values.any())

    if has_corruption:
        print(f"CORRUPTION DETECTED: {bad_values.sum().item()} huge or non-finite values out of {out.numel()}")
        print(f"Max absolute value: {torch.abs(out).max().item():.2e}")

        # Show the first 3 corrupted positions (out is indexed as 2-D here,
        # as in the original code).
        rows, cols = torch.where(bad_values)
        for row, col in zip(rows[:3].tolist(), cols[:3].tolist()):
            value = out[row, col].item()
            print(f"  Corrupted at [{row}, {col}]: {value:.2e}")
    else:
        print("✓ No corruption detected")

    return has_corruption, out.clone()

def main():
    """Reproduce the z=1 -> z=2 gmm corruption sequence and print a summary.

    Locates the staged build at ``<repo>/build/<variant>``, prepends it and
    the repo root to ``sys.path`` (if not already present), and imports
    ``megablocks``. It then runs ``test_single_z`` for z=1 followed by z=2,
    both with ``seed=0``.
    """
    repo_root = pathlib.Path(__file__).resolve().parent.parent  # Go up from _dev/ to repo root
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant

    # Each insert goes to the front of sys.path, in this order.
    for path in (staged_dir, repo_root):
        if str(path) not in sys.path:
            sys.path.insert(0, str(path))

    import megablocks  # type: ignore

    print(f"Using staged variant: {variant}")
    print("Testing tensor copy corruption with sequence: z=1, z=2 (using SAME seed like debug-gg-small.py)")

    # Both runs pass seed=0, and test_single_z calls torch.manual_seed(seed)
    # itself, so the torch RNG is reset before EACH run (matching
    # debug-gg-small.py). Any difference between the runs therefore does not
    # come from the RNG state; what carries over is the fact that z=1 ran first
    # in the same process.

    # First run z=1 (expected to be clean).
    z1_corrupted, _ = test_single_z(1, "z=1 (baseline)", megablocks, seed=0)

    # Then run z=2 in the same process, right after z=1 (the reported failing case).
    z2_corrupted, _ = test_single_z(2, "z=2 (after z=1 - should show corruption)", megablocks, seed=0)

    print("\n=== SUMMARY ===")
    print(f"z=1 corrupted: {z1_corrupted}")
    print(f"z=2 corrupted: {z2_corrupted}")

    if z2_corrupted and not z1_corrupted:
        print("✓ Successfully reproduced state contamination bug!")
        print("The corruption happens specifically with z=2 after z=1 has been called.")
    elif z2_corrupted and z1_corrupted:
        print("Both z=1 and z=2 show corruption - this is a different issue.")
    elif z1_corrupted:
        # Previously fell through to "No corruption detected", which was false.
        print("Only z=1 shows corruption - this is a different issue.")
    else:
        print("No corruption detected - issue may have been fixed or conditions changed.")

if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0
    # unless main() itself raises (e.g. SystemExit from detect_variant).
    sys.exit(main())