# Source: megablocks-hip/_dev/debug-build-4-megablocks.sh
# Commit 2d8a802 (author: leonardlin): "Add ROCm build debugging utilities"
#!/usr/bin/env bash
# Debug script 4: MegaBlocks-specific build debugging
#
# Run from the repository root (all source paths below are relative to it).
# Steps:
#   1. Check that the MegaBlocks sources exist.
#   2. Compile each translation unit directly with hipcc.
#   3. Drive torch.utils.cpp_extension.load() on progressively larger source
#      sets, so a failing or hanging build can be narrowed to one stage.
#
# Environment (all optional; defaults are in the assignments below):
#   ROCM_PATH, ROCM_HOME, HIP_PATH, HIP_HOME  ROCm install location
#   TORCH_HIP_ARCH_LIST       GPU target(s) for the direct hipcc compiles
#   PYTORCH_ROCM_ARCH         GPU target(s) for torch's JIT builds
#                             (defaults to TORCH_HIP_ARCH_LIST)
#   HSA_OVERRIDE_GFX_VERSION  passed through only if already set
#   TORCH_EXTENSIONS_DIR      JIT build cache directory
set -euo pipefail
echo "=== MegaBlocks Build Debug Script 4 ==="
echo "Testing MegaBlocks-specific compilation components"
echo
# Set ROCm environment variables
export ROCM_PATH="${ROCM_PATH:-/opt/rocm-7.0.1}"
export ROCM_HOME="${ROCM_HOME:-$ROCM_PATH}"
export HIP_PATH="${HIP_PATH:-$ROCM_PATH}"
export HIP_HOME="${HIP_HOME:-$ROCM_PATH}"
export PATH="$ROCM_HOME/bin:$PATH"
export TORCH_HIP_ARCH_LIST="${TORCH_HIP_ARCH_LIST:-gfx942}"
# torch.utils.cpp_extension picks its HIP offload targets from
# PYTORCH_ROCM_ARCH. Without this, the JIT builds below would use torch's own
# default arch list instead of the target chosen above.
export PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH:-$TORCH_HIP_ARCH_LIST}"
# HSA_OVERRIDE_GFX_VERSION takes a dotted "major.minor.stepping" version
# (e.g. 9.4.2), not a target name such as gfx942. Respect a caller-provided
# value, but do not invent one.
if [[ -n "${HSA_OVERRIDE_GFX_VERSION:-}" ]]; then
  export HSA_OVERRIDE_GFX_VERSION
fi
export TORCH_EXTENSIONS_DIR="${TORCH_EXTENSIONS_DIR:-$PWD/.torch_extensions_debug}"
if ! command -v hipcc >/dev/null 2>&1; then
  echo "WARNING: hipcc not found on PATH (searched $ROCM_HOME/bin first)" >&2
fi
printf 'Working directory: %s\n\n' "$(pwd)"
printf '%s\n' "=== Checking MegaBlocks Source Files ===" "Verifying all source files exist:"
# Every translation unit the extension is built from. If any is absent, none
# of the compile tests below are meaningful, so stop early.
required_sources=(
  "torch-ext/torch_binding.cpp"
  "csrc/new_cumsum.cu"
  "csrc/new_histogram.cu"
  "csrc/new_indices.cu"
  "csrc/new_replicate.cu"
  "csrc/new_sort.cu"
  "csrc/grouped_gemm/grouped_gemm.cu"
)
missing=0
for path in "${required_sources[@]}"; do
  if [[ ! -f "$path" ]]; then
    printf 'βœ— %s missing\n' "$path"
    missing=$((missing + 1))
    continue
  fi
  printf 'βœ“ %s exists (%s lines)\n' "$path" "$(wc -l < "$path")"
done
if (( missing > 0 )); then
  echo "Cannot proceed - missing source files"
  exit 1
fi
echo
echo "=== Checking Include Directories ==="
#######################################
# Print up to LIMIT header files (*.h / *.hpp) found under DIR.
# The list is collected first and truncated afterwards, instead of piping
# find into head. Under `set -o pipefail`, a find killed by SIGPIPE after
# head exits would fail the pipeline and abort the script.
# Arguments: $1 - directory to search; $2 - max lines to print (default 10)
# Outputs:   matching paths on stdout, in find order
#######################################
show_header_sample() {
  local dir=$1 limit=${2:-10} found
  # find exits non-zero on e.g. unreadable subdirectories. The listing is
  # informational only, so keep whatever it printed.
  found=$(find "$dir" -type f \( -name '*.h' -o -name '*.hpp' \)) || true
  if [[ -n "$found" ]]; then
    head -n "$limit" <<<"$found"
  fi
}
if [[ -d csrc ]]; then
  echo "βœ“ csrc include directory exists"
  echo "Headers in csrc/:"
  show_header_sample csrc
else
  echo "βœ— csrc include directory missing"
fi
echo
echo "=== Testing Individual Source Compilation ==="
# Scratch directory for the objects produced below. It is removed at the end
# of this section, so nothing is written to (or deleted from) a shared /tmp.
obj_dir=$(mktemp -d)

# hipcc target flags derived from TORCH_HIP_ARCH_LIST (";"- or
# space-separated). --offload-arch supersedes the deprecated --amdgpu-target.
IFS='; ' read -r -a hip_archs <<<"${TORCH_HIP_ARCH_LIST:-gfx942}"
arch_flags=()
for arch in ${hip_archs[@]+"${hip_archs[@]}"}; do
  [[ -n "$arch" ]] || continue
  arch_flags+=("--offload-arch=$arch")
done

# Pass every torch extension include directory, not just the first:
# torch/extension.h also needs the torch/csrc/api/include entry.
# torch.utils.cpp_extension must be imported explicitly; `import torch` alone
# does not load it.
torch_inc_flags=()
if torch_inc_out=$(python3 -c 'from torch.utils import cpp_extension; print("\n".join(cpp_extension.include_paths()))'); then
  while IFS= read -r inc_dir; do
    if [[ -n "$inc_dir" ]]; then
      torch_inc_flags+=("-I$inc_dir")
    fi
  done <<<"$torch_inc_out"
else
  echo "WARNING: could not query torch include paths; compiling without them" >&2
fi

#######################################
# Compile one source file to an object with hipcc and report the outcome.
# Globals:   obj_dir, torch_inc_flags (read)
# Arguments: $1 - source path; $2 - timeout in seconds;
#            any remaining args are extra hipcc flags
# Outputs:   a success/failure line on stdout
# Returns:   0 always, so a failed compile does not abort the script
#######################################
compile_object() {
  local src=$1 secs=$2 name rc=0
  shift 2
  name=$(basename "$src")
  timeout "$secs" hipcc -c "$src" -o "$obj_dir/${name%.*}.o" \
    -I./csrc \
    ${torch_inc_flags[@]+"${torch_inc_flags[@]}"} \
    -std=c++17 \
    -O3 \
    -fPIC \
    "$@" || rc=$?
  if (( rc == 0 )); then
    echo "βœ“ $name compiled successfully"
  else
    echo "βœ— $name compilation failed"
    # 124 is timeout(1)'s "time limit reached" status.
    if (( rc == 124 )); then
      echo "  (timed out after ${secs}s - possible compiler hang)"
    fi
  fi
}

# Test compiling each top-level .cu file individually.
for src in csrc/*.cu; do
  if [[ -f "$src" ]]; then
    echo "Testing compilation of $(basename "$src")..."
    compile_object "$src" 60 ${arch_flags[@]+"${arch_flags[@]}"}
  fi
done

echo
echo "=== Testing grouped_gemm.cu Specifically ==="
echo "This is often the most complex kernel..."
# -v prints the underlying clang invocations. -lhipblaslt is omitted because
# it is a link flag and has no effect together with -c.
compile_object csrc/grouped_gemm/grouped_gemm.cu 120 ${arch_flags[@]+"${arch_flags[@]}"} -v

echo
echo "=== Testing torch_binding.cpp ==="
compile_object torch-ext/torch_binding.cpp 60

rm -rf -- "${obj_dir:?}"
echo
echo "=== Testing Incremental PyTorch Extension Build ==="
echo "Running incremental build test..."
# The Python driver is fed on stdin instead of being written to
# ./debug_build.py, so nothing is left in the checkout if the run is
# interrupted. -u keeps Python's prints in order with the ninja/compiler
# output that torch writes straight to the same stdout.
# The exit status is captured rather than letting `set -e` abort, so the
# full-build test and cleanup below still run.
incremental_rc=0
python3 -u - <<'EOF' || incremental_rc=$?
import os
import pathlib
import signal
import sys

from torch.utils.cpp_extension import load

# Wall-clock limit per stage, in seconds (overridable for slow machines).
STAGE_TIMEOUT = int(os.environ.get("MEGABLOCKS_DEBUG_STAGE_TIMEOUT", "180"))


class BuildTimeout(BaseException):
    """Raised from the SIGALRM handler when a stage exceeds STAGE_TIMEOUT.

    Derives from BaseException so no `except Exception` clause can
    swallow it.
    """


def timeout_handler(signum, frame):
    raise BuildTimeout()


signal.signal(signal.SIGALRM, timeout_handler)

repo = pathlib.Path(".").resolve()
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))


def run_stage(label, name, sources, extra_ldflags=None):
    """JIT-build `sources` as extension `name`; return True on success.

    Each stage gets its own STAGE_TIMEOUT, so a hang is attributed to the
    stage that caused it and the remaining stages still run.
    """
    signal.alarm(STAGE_TIMEOUT)
    try:
        load(
            name=name,
            sources=sources,
            extra_include_paths=["csrc"],
            extra_cflags=["-O3", "-std=c++17"],
            extra_cuda_cflags=["-O3"],
            extra_ldflags=extra_ldflags or [],
            verbose=True,
            is_python_module=False,
        )
        print(f"βœ“ {label} build successful")
        return True
    except BuildTimeout:
        print(f"βœ— {label} build timed out after {STAGE_TIMEOUT}s - this indicates a hanging issue")
        return False
    except Exception as e:
        print(f"βœ— {label} build failed: {e}")
        return False
    finally:
        signal.alarm(0)


results = []

print("=== Testing with Single Source File ===")
print("Building with just new_cumsum.cu...")
results.append(run_stage(
    "Single source",
    "_megablocks_debug_single",
    ["csrc/new_cumsum.cu"],
))

print("\n=== Testing with Two Source Files ===")
print("Building with new_cumsum.cu and new_histogram.cu...")
results.append(run_stage(
    "Double source",
    "_megablocks_debug_double",
    ["csrc/new_cumsum.cu", "csrc/new_histogram.cu"],
))

print("\n=== Testing with grouped_gemm.cu Only ===")
print("Building with just grouped_gemm.cu (most complex)...")
results.append(run_stage(
    "grouped_gemm",
    "_megablocks_debug_gemm",
    ["csrc/grouped_gemm/grouped_gemm.cu"],
    extra_ldflags=["-lhipblaslt"],
))

sys.exit(0 if all(results) else 1)
EOF
if (( incremental_rc != 0 )); then
  echo "Incremental build test reported failures (exit code $incremental_rc)"
fi
echo
echo "=== Testing Full Build with Timeout ==="
echo "Running full build test (with timeout)..."
# Fed on stdin (no debug_full_build.py left in the checkout). -u keeps prints
# ordered relative to the compiler output. The status is captured so a
# timeout (124) or failure (1) is reported without aborting before cleanup.
full_build_rc=0
python3 -u - <<'EOF' || full_build_rc=$?
import os
import pathlib
import signal
import sys
import traceback

from torch.utils.cpp_extension import load


def timeout_handler(signum, frame):
    print("Full build timed out - this confirms the hanging issue")
    sys.exit(124)  # same status coreutils `timeout` uses


# Set up a 5 minute timeout for the whole build.
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(300)

repo = pathlib.Path(".").resolve()
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug"))

sources = [
    "torch-ext/torch_binding.cpp",
    "csrc/new_cumsum.cu",
    "csrc/new_histogram.cu",
    "csrc/new_indices.cu",
    "csrc/new_replicate.cu",
    "csrc/new_sort.cu",
    "csrc/grouped_gemm/grouped_gemm.cu",
]
print("=== Attempting Full MegaBlocks Build ===")
print("This mimics the exact build.py process...")
print("Sources:", sources)
try:
    mod = load(
        name="_megablocks_debug_full",
        sources=sources,
        extra_include_paths=["csrc"],
        extra_cflags=["-O3", "-std=c++17"],
        extra_cuda_cflags=["-O3"],
        extra_ldflags=["-lhipblaslt"],
        verbose=True,
        is_python_module=False,
    )
    print("βœ“ Full build successful!")
    print("Built:", mod)
except Exception as e:
    print(f"βœ— Full build failed: {e}")
    traceback.print_exc()
    # Make the failure visible to the calling shell instead of exiting 0.
    sys.exit(1)
finally:
    signal.alarm(0)
EOF
case "$full_build_rc" in
  0) ;;
  124) echo "Full build test timed out (exit code 124)" ;;
  *) echo "Full build test failed (exit code $full_build_rc)" ;;
esac
echo
echo "=== Cleanup ==="
# Delete only the objects the direct hipcc tests write (/tmp/<source stem>.o),
# never every *.o in the shared /tmp, which may belong to other jobs or users.
cleanup_objs=(/tmp/grouped_gemm.o /tmp/torch_binding.o)
for src in csrc/*.cu; do
  [[ -e "$src" ]] || continue   # glob matched nothing
  cleanup_objs+=("/tmp/$(basename "$src" .cu).o")
done
rm -f -- "${cleanup_objs[@]}"
rm -f debug_build.py debug_full_build.py
echo
echo "=== Debug Script 4 Complete ==="