#!/usr/bin/env bash
set -euo pipefail

KERNEL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
cd "$KERNEL_DIR"

export KERNEL_DIR
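
# detect_variant prints the name of the staged build directory under build/.
# It prefers kernels.utils.build_variant() when it is importable and otherwise
# falls back to globbing for the variant directories that build.py produces.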
detect_variant() {
    python - <<'PY'
import os
import pathlib

root = pathlib.Path(os.environ["KERNEL_DIR"])
build_dir = root / "build"
variant = None

try:
    from kernels.utils import build_variant as _build_variant
except Exception:
    _build_variant = None

if _build_variant is not None:
    try:
        variant = _build_variant()
    except Exception:
        variant = None

if variant is None:
    # Path.glob() returns a generator, which is always truthy, so materialize
    # each pattern before falling back from the ROCm glob to the CUDA one.
    candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(build_dir.glob("torch*-cu*"))
    if candidates:
        variant = candidates[0].name

if variant is None:
    raise SystemExit("Could not determine MegaBlocks build variant. Run build.py first.")

print(variant)
PY
}

VARIANT=$(detect_variant)
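
# build.py stages the compiled extension (and the megablocks package) under
# build/<variant>; find_staged_lib returns the path of the staged .so, if any.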
STAGED_DIR="$KERNEL_DIR/build/$VARIANT"

find_staged_lib() {
    local base="$1"
    local candidates=(
        "$base/_megablocks_rocm.so"
        "$base/megablocks/_megablocks_rocm.so"
    )
    for path in "${candidates[@]}"; do
        if [[ -f "$path" ]]; then
            echo "$path"
            return 0
        fi
    done
    return 1
}

STAGED_LIB=$(find_staged_lib "$STAGED_DIR") || true

if [[ -z "${STAGED_LIB:-}" ]]; then
    echo "Staged ROCm extension not found under $STAGED_DIR; rebuilding kernels..."
    python build.py
    VARIANT=$(detect_variant)
    STAGED_DIR="$KERNEL_DIR/build/$VARIANT"
    STAGED_LIB=$(find_staged_lib "$STAGED_DIR") || true
    if [[ -z "${STAGED_LIB:-}" ]]; then
        echo "ERROR: build.py completed but no extension was found under $STAGED_DIR" >&2
        exit 1
    fi
fi

# Put the staged build first on PYTHONPATH so the freshly built extension wins.
export PYTHONPATH="$STAGED_DIR:${PYTHONPATH:-}"

echo "Using MegaBlocks build variant: $VARIANT"
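
# Probe GPU visibility through torch; ROCm devices surface via the torch.cuda API.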
declare -i GPU_COUNT
GPU_COUNT=$(python - <<'PY'
import torch
print(torch.cuda.device_count() if torch.cuda.is_available() else 0)
PY
)

if (( GPU_COUNT == 0 )); then
    echo "ERROR: No HIP/CUDA GPUs detected. Tests require at least one visible accelerator." >&2
    exit 1
fi

echo "Detected $GPU_COUNT visible GPU(s)."
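
# Helpers: log prints a section banner, and run_pytest labels a command and
# runs it with shell tracing enabled so the exact invocation is visible.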
log() {
    echo
    echo "==> $1"
}

run_pytest() {
    local label="$1"
    shift
    log "$label"
    set -x
    "$@"
    { set +x; } 2>/dev/null || true
}
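
# Device-visibility and WORLD_SIZE settings passed via `env` to each pytest run.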
SINGLE_GPU_ENV=(HIP_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1)
MULTI2_GPU_ENV=(HIP_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2)
MULTI8_GPU_ENV=(HIP_VISIBLE_DEVICES=$(seq -s, 0 7) CUDA_VISIBLE_DEVICES=$(seq -s, 0 7) WORLD_SIZE=8)

SINGLE_TESTS=(
    "test_mb_moe.py"
    "test_mb_moe_shared_expert.py"
    "layer_test.py"
    "test_gg.py"
    "ops_test.py"
)
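
# Run each single-GPU suite as its own pytest invocation on GPU 0.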
for test in "${SINGLE_TESTS[@]}"; do
    run_pytest "Single-GPU pytest ${test}" env "${SINGLE_GPU_ENV[@]}" python -m pytest "tests/${test}" -q
done

if (( GPU_COUNT >= 2 )); then
    run_pytest "Distributed layer smoke (2 GPUs)" env "${MULTI2_GPU_ENV[@]}" python -m pytest "tests/parallel_layer_test.py::test_megablocks_moe_mlp_functionality" -q
else
    log "Skipping 2-GPU distributed layer test (requires >=2 GPUs, detected ${GPU_COUNT})."
fi

run_pytest "Shared expert functionality (world_size=1)" env "${SINGLE_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[1]' -q
run_pytest "Shared expert weighted sum (world_size=1)" env "${SINGLE_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_weighted_sum[1]' -q

if (( GPU_COUNT >= 8 )); then
    run_pytest "Shared expert functionality (world_size=8)" env "${MULTI8_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[8]' -q
    run_pytest "Shared expert weighted sum (world_size=8)" env "${MULTI8_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_weighted_sum[8]' -q
else
    log "Skipping 8-GPU shared expert tests (requires >=8 GPUs, detected ${GPU_COUNT})."
fi

echo
echo "All requested tests completed."