import logging

import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

SEED = 0xdeadbeef


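# Custom command-line options controlling performance measurement,
# profiling, and correctness verification.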
def pytest_addoption(parser):
    parser.addoption(
        "--measure-perf",
        action="store_true",
        default=False,
        help=
        "Measure execution time and peak memory usage during the optimizer step.",
    )

    parser.addoption(
        "--do-profile",
        action="store_true",
        default=False,
        help="Enable profiling during tests.",
    )

    parser.addoption(
        "--skip-verify",
        action="store_true",
        default=False,
        help=
        "Skip verification of optimizer step correctness against the sequential implementation.\n"
        "This can be useful when GPU memory is limited.",
    )


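# Fail fast on an invalid flag combination: profiling only makes sense
# when performance measurement is also enabled.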
def pytest_configure(config):
    if config.getoption(
            "--do-profile") and not config.getoption("--measure-perf"):
        raise pytest.UsageError(
            "--do-profile requires --measure-perf. Please enable both flags.")


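# Session-scoped fixtures exposing the command-line flags to the tests.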
@pytest.fixture(scope="session")
def measure_perf(request):
    return request.config.getoption("--measure-perf")


@pytest.fixture(scope="session")
def do_profile(request):
    return request.config.getoption("--do-profile")


@pytest.fixture(scope="session")
def skip_verify(request):
    return request.config.getoption("--skip-verify")


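# Autouse, session-wide fixture: set up torch.distributed for the whole
# test session and tear it down afterwards.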
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    if version.parse(torch.__version__) < version.parse("2.8"):
        pytest.skip("torch>=2.8.0 is required for parallel muon")
        return

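    # Initialize the NCCL process group and pin this rank to one of the local GPUs.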
    try:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        logger.error(f"Failed to initialize torch.distributed: {e}")
        pytest.skip("Failed to initialize torch.distributed")

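    # The test cases assume an 8-rank process group.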
    if dist.get_world_size() != 8:
        pytest.skip("Need 8 processes in the dist group. "
                    "You can run with `torchrun --nproc-per-node=8 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`. "
                    "To run with fewer than 8 GPUs, modify "
                    "the test cases accordingly.")

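    # Hand control to the test session, then tear down the process group on exit.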
    yield
    dist.destroy_process_group()


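# Shared test inputs: the Motif-2.6B test model plus synthetic gradients and
# QK logits, generated once per session with a fixed seed.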
@pytest.fixture(scope="session")
def inputs():
    """Load the Motif-2.6B model and generate random gradients for testing.

    Returns:
        tuple[torch.nn.Module, list[torch.Tensor], dict[int, torch.Tensor]]:
            - torch.nn.Module: The Motif-2.6B model.
            - list[torch.Tensor]: A list of random gradients, one per model parameter.
            - dict[int, torch.Tensor]: A dictionary mapping layer indices to random QK logits.
    """
    model_name = "Motif-Technologies/Motif-2.6B-4layer-random"

    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    logger.info(
        f"Loaded model {model_name}. ({len(list(model.parameters()))} parameter tensors)"
    )

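    # One random gradient per parameter, matching its shape, dtype, and device.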
    grads: list[torch.Tensor] = [
        torch.randn_like(param) for param in model.parameters()
    ]

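    # Random per-head QK logits for each transformer layer.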
    qk_logits: dict[int, torch.Tensor] = {
        i: torch.randn(model.config.num_attention_heads,
                       device=model.device,
                       dtype=torch.bfloat16)
        for i in range(model.config.num_hidden_layers)
    }

    return model, grads, qk_logits