|
|
import os |
|
|
import pathlib |
|
|
import sys |
|
|
import signal |
|
|
import time |
|
|
from torch.utils.cpp_extension import load |
|
|
|
|
|
def timeout_handler(signum, frame):
    """Report a build timeout and terminate the script.

    Intended as a signal handler (this file registers it for ``SIGALRM``).

    Args:
        signum: Number of the delivered signal (unused).
        frame: Interrupted stack frame (unused).

    Raises:
        SystemExit: Always, with exit status 1.
    """
    print("Build timed out - this indicates a hanging issue")
    # SystemExit is not an ``Exception`` subclass, so handlers written as
    # ``except Exception`` do not swallow it.
    sys.exit(1)
|
|
|
|
|
|
|
|
def _try_build(label, intro, name, sources, extra_ldflags=None):
    """Attempt one JIT extension build and print the outcome.

    Args:
        label: Short description used in the success/failure message.
        intro: Progress line printed before the build starts.
        name: Extension name passed to ``load``.
        sources: Source file paths passed to ``load`` (relative to the
            current working directory).
        extra_ldflags: Optional extra linker flags passed to ``load``.

    Returns:
        True if ``load`` returned without raising, False otherwise.
    """
    print(intro)
    try:
        # The return value of ``load`` is not used by this script.
        load(
            name=name,
            sources=sources,
            extra_include_paths=["csrc"],
            extra_cflags=["-O3", "-std=c++17"],
            extra_cuda_cflags=["-O3"],
            extra_ldflags=extra_ldflags,
            verbose=True,
            is_python_module=False,
        )
    except Exception as e:
        # Best-effort: report the failure and let the remaining builds run.
        print(f"✗ {label} build failed: {e}")
        return False
    print(f"✓ {label} build successful")
    return True


# A single 180-second watchdog covers all builds below together; it is armed
# once here and cancelled only after the last build attempt.
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(180)

try:
    # Use a repo-local extensions directory unless the caller already set
    # TORCH_EXTENSIONS_DIR in the environment.
    repo = pathlib.Path(".").resolve()
    os.environ.setdefault(
        "TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions_debug")
    )

    print("=== Testing with Single Source File ===")
    _try_build(
        "Single source",
        intro="Building with just new_cumsum.cu...",
        name="_megablocks_debug_single",
        sources=["csrc/new_cumsum.cu"],
    )

    print("\n=== Testing with Two Source Files ===")
    _try_build(
        "Double source",
        intro="Building with new_cumsum.cu and new_histogram.cu...",
        name="_megablocks_debug_double",
        sources=["csrc/new_cumsum.cu", "csrc/new_histogram.cu"],
    )

    print("\n=== Testing with grouped_gemm.cu Only ===")
    _try_build(
        "grouped_gemm",
        intro="Building with just grouped_gemm.cu (most complex)...",
        name="_megablocks_debug_gemm",
        sources=["csrc/grouped_gemm/grouped_gemm.cu"],
        extra_ldflags=["-lhipblaslt"],
    )
finally:
    # Always disarm the watchdog, even if something other than a build
    # failure propagates out of the section above.
    signal.alarm(0)
|
|
|