|
|
|
|
|
"""Step-by-step debugging of the grouped GEMM computation.""" |
|
|
|
|
|
import pathlib |
|
|
import sys |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
def detect_variant(root: pathlib.Path) -> str:
    """Return the name of the staged build variant for the repo at ``root``.

    Resolution order:

    1. If ``root/kernels/utils.py`` exists, ``root`` is temporarily put on
       ``sys.path`` and ``kernels.utils.build_variant()`` is called. Any
       exception raised there is swallowed and treated as "unknown".
    2. Otherwise (or if step 1 produced ``None``), ``root/build`` is searched
       for directories matching ``torch*-rocm64-*`` and, only if none exist,
       ``torch*-cu*``. The alphabetically first match is used.

    Args:
        root: Repository root containing ``build/`` and optionally
            ``kernels/utils.py``.

    Returns:
        The variant name (a directory name under ``root/build`` when found
        via the fallback search).

    Raises:
        SystemExit: If no variant could be determined.
    """
    build_dir = root / "build"
    variant: Optional[str] = None

    if (root / "kernels" / "utils.py").exists():
        root_entry = str(root)
        sys.path.insert(0, root_entry)
        try:
            from kernels.utils import build_variant as _build_variant

            variant = _build_variant()
        except Exception:
            # Deliberate best effort: fall through to the directory search.
            variant = None
        finally:
            # Remove by value rather than by index: the import may itself
            # have prepended to sys.path, so index 0 is not guaranteed to
            # be the entry inserted above.
            try:
                sys.path.remove(root_entry)
            except ValueError:
                pass

    if variant is None:
        # Path.glob() returns a generator, which is always truthy, so the
        # patterns cannot be chained with `or`; each one has to be
        # materialised before deciding whether to try the next.
        for pattern in ("torch*-rocm64-*", "torch*-cu*"):
            candidates = sorted(build_dir.glob(pattern))
            if candidates:
                variant = candidates[0].name
                break

    if variant is None:
        raise SystemExit("Could not determine build variant; run build.py first.")

    return variant
|
|
|
|
|
|
|
|
def manual_gmm_computation(a, b, batch_sizes, trans_b=False):
    """Compute a grouped GEMM one expert at a time, printing every step.

    ``batch_sizes`` gives the number of consecutive rows of ``a`` that belong
    to each expert. For expert ``i`` the matching row slice of ``a`` is
    multiplied by ``b[i]`` in float32, cast back to ``a.dtype`` and written
    into the same row range of the output.

    Args:
        a: Input tensor, sliced along dim 0 per expert.
            (assumes 2-D ``(tokens, k)`` -- TODO confirm against callers)
        b: Per-expert weights, indexed along dim 0; ``b.size(2)`` is the
            output width.
        batch_sizes: 1-D integer tensor of per-expert row counts.
        trans_b: Only ``False`` is supported.

    Returns:
        Tensor of shape ``(sum(batch_sizes), b.size(2))`` with ``a``'s dtype
        and device.

    Raises:
        NotImplementedError: If ``trans_b`` is truthy (raised after the
            summary lines have been printed).
    """
    print("=== Manual GMM computation ===")

    counts_ptr = batch_sizes.cpu().numpy()
    num_experts = len(counts_ptr)

    # Inclusive running totals: boundaries[i] is one past expert i's last row.
    boundaries = []
    total = 0
    for count in counts_ptr:
        total = total + count
        boundaries.append(total)

    tokens = boundaries[-1] if boundaries else 0
    print(f"num_experts: {num_experts}, tokens: {tokens}")
    print(f"a.shape: {a.shape}, b.shape: {b.shape}")
    print(f"batch_sizes: {counts_ptr}")

    if trans_b:
        raise NotImplementedError("trans_b case not implemented")

    # Not used below; kept so that an `a` without a dim 1 still fails here.
    hidden_out = a.size(1)
    hidden_in = b.size(2)
    out = torch.empty((tokens, hidden_in), dtype=a.dtype, device=a.device)
    print(f"Output shape: {out.shape} (tokens={tokens}, hidden_in={hidden_in})")

    weights = b.contiguous()

    # Each expert starts where the previous one ended (0 for the first).
    starts = [0] + boundaries[:-1]
    for expert, (start, end) in enumerate(zip(starts, boundaries)):
        rows = end - start
        print(f"\nExpert {expert}: start={start}, end={end}, rows={rows}")

        if rows == 0:
            continue

        lhs = a.narrow(0, start, rows)
        rhs = weights.select(0, expert)
        dst = out.narrow(0, start, rows)

        print(f"  a_slice.shape: {lhs.shape}")
        print(f"  b_slice.shape: {rhs.shape}")
        print(f"  a_slice range: [{lhs.min().item():.8f}, {lhs.max().item():.8f}]")
        print(f"  b_slice range: [{rhs.min().item():.8f}, {rhs.max().item():.8f}]")

        # Multiply in float32, then cast back to the input dtype.
        prod = torch.matmul(lhs.to(torch.float32), rhs.to(torch.float32))
        print(f"  prod.shape: {prod.shape}")
        print(f"  prod range: [{prod.min().item():.8f}, {prod.max().item():.8f}]")

        dst.copy_(prod.to(a.dtype))

    return out
|
|
|
|
|
|
|
|
def main() -> None:
    """Compare the manual, reference and megablocks grouped GEMM outputs.

    Locates the staged build for the repository two directory levels above
    this file, makes it and the repository root importable, builds seeded
    random inputs, and prints value ranges plus max absolute differences of
    the manual and megablocks results against the reference ``gmm``.
    """

    def _span(t) -> str:
        # Formats "[min, max]" with 8 decimals, as used in every range line.
        return f"[{t.min().item():.8f}, {t.max().item():.8f}]"

    repo_root = pathlib.Path(__file__).resolve().parent.parent
    variant = detect_variant(repo_root)
    staged_dir = repo_root / "build" / variant

    # Prepend in this order so repo_root ends up ahead of staged_dir.
    for entry in (str(staged_dir), str(repo_root)):
        if entry not in sys.path:
            sys.path.insert(0, entry)

    import megablocks
    from tests.test_gg import gmm, randn

    print(f"Using staged variant: {variant}")

    torch.manual_seed(0)

    z = 128
    m = n = k = z
    trans_b = False

    a = randn(z, m, k).view(-1, k)
    if trans_b:
        b = randn(z, n, k)
    else:
        b = randn(z, k, n)
    batch_sizes = torch.tensor([m] * z, device="cpu")

    print("=== Input Information ===")
    print(f"a.shape: {a.shape}, dtype: {a.dtype}")
    print(f"b.shape: {b.shape}, dtype: {b.dtype}")
    print(f"batch_sizes: {batch_sizes}")
    print(f"Input a range: {_span(a)}")
    print(f"Input b range: {_span(b)}")

    # Step-by-step path; gets its own copies of the inputs.
    manual_out = manual_gmm_computation(a.clone(), b.clone(), batch_sizes, trans_b)
    print(f"\nManual output range: {_span(manual_out)}")

    # Reference implementation from the test suite, on detached copies.
    a_ref = a.detach().clone()
    b_ref = b.detach().clone()
    ref = gmm(a_ref, b_ref, batch_sizes.cpu(), trans_b)
    print(f"Reference output range: {_span(ref)}")

    # Implementation under investigation.
    out = megablocks.gg_ops.gmm(a.clone(), b.clone(), batch_sizes, trans_b)
    print(f"Megablocks output range: {_span(out)}")

    manual_vs_ref = (manual_out - ref).abs().max().item()
    megablocks_vs_ref = (out - ref).abs().max().item()
    print(f"\nManual vs Reference max diff: {manual_vs_ref:.8e}")
    print(f"Megablocks vs Reference max diff: {megablocks_vs_ref:.8e}")
|
|
|
|
|
|
|
|
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()