# optimizer/test/utils.py
from dataclasses import dataclass
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import (ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
parallelize_module)
@dataclass
class ParallelDims:
dp_replicate_degree: int
dp_shard_degree: int
tp_degree: int
def __str__(self) -> str:
return (f"dp_replicate-{self.dp_replicate_degree}_"
f"dp_shard-{self.dp_shard_degree}_"
f"tp-{self.tp_degree}")
def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh:
"""Constructs a DeviceMesh based on the given parallel dimensions.
Args:
parallel_dims (ParallelDims): The parallelism configuration.
Returns:
DeviceMesh: The constructed device mesh.
"""
world_size = dist.get_world_size()
expected_devices = (parallel_dims.dp_replicate_degree *
parallel_dims.dp_shard_degree *
parallel_dims.tp_degree)
if world_size < expected_devices:
raise ValueError(
f"Not enough devices: found {world_size}, "
f"but expected at least {expected_devices}. ({parallel_dims})")
degrees = [
parallel_dims.dp_replicate_degree, parallel_dims.dp_shard_degree,
parallel_dims.tp_degree
]
dim_names = ["dp_replicate", "dp_shard", "tp"]
mesh_shape = []
mesh_dim_names = []
for degree, dim_name in zip(degrees, dim_names):
if degree > 1:
mesh_shape.append(degree)
mesh_dim_names.append(dim_name)
device_mesh = dist.init_device_mesh("cuda",
mesh_shape,
mesh_dim_names=mesh_dim_names)
return device_mesh
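# Example (hypothetical values): with ParallelDims(1, 4, 2) on 8 GPUs only the
# degrees > 1 are kept, so the mesh is 2-D with shape (4, 2) and dim names
# ("dp_shard", "tp"); with ParallelDims(1, 1, 8) it collapses to a 1-D
# ("tp",) mesh of size 8.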
def _apply_tp(
model: torch.nn.Module,
tp_mesh: DeviceMesh,
):
"""Apply tensor parallelism."""
# Layer names must match Motif model definition
# https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py
assert type(model).__name__ == "MotifForCausalLM"
    # 1. Parallelize the embedding and shard its outputs (which are the first
    #    transformer block's inputs); currently disabled, see the note below
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
parallelize_module(
model,
tp_mesh,
{
            # Parallelizing the embedding below would undo the weight tying
            # (tie_weights) and make it hard to compare results against the
            # non-tensor-parallel version, so it stays disabled for now.
            # TODO(jeesoo): check correctness for training semantics
#"model.embed_tokens":
#RowwiseParallel(
# input_layouts=Replicate(),
# output_layouts=Shard(1),
#),
"model.norm":
SequenceParallel(),
"output":
ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1), # loss_parallel
use_local_output=False,
),
},
)
# Apply tensor + sequence parallelism to every transformer block
for transformer_block in model.model.layers:
layer_plan = {
"input_layernorm":
SequenceParallel(),
"post_attention_layernorm":
SequenceParallel(),
"self_attn":
PrepareModuleInput(
# x, freqs_cis, attention_mask, position_ids, qk_clip
input_layouts=(Shard(1), Replicate(), None, None, None),
desired_input_layouts=(Replicate(), Replicate(), None, None,
None),
),
"self_attn.q_proj":
ColwiseParallel(),
"self_attn.k_proj":
ColwiseParallel(),
"self_attn.v_proj":
ColwiseParallel(),
"self_attn.o_proj":
RowwiseParallel(output_layouts=Shard(1)),
"mlp":
PrepareModuleInput(
input_layouts=(Shard(1), ),
desired_input_layouts=(Replicate(), ),
),
"mlp.gate_proj":
ColwiseParallel(),
"mlp.down_proj":
RowwiseParallel(output_layouts=Shard(1)),
"mlp.up_proj":
ColwiseParallel(),
}
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)
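# After parallelize_module, the parallelized weights live on tp_mesh as
# DTensors: a ColwiseParallel linear (e.g. q_proj) shards its weight along
# dim 0, while a RowwiseParallel linear (e.g. o_proj) shards along dim 1.
# A rough sanity check could look like this (hypothetical helper, not part
# of this file):
#
#   def _check_tp_placements(model):
#       w = model.model.layers[0].self_attn.q_proj.weight
#       assert isinstance(w, DTensor) and w.placements[0] == Shard(0)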
def _apply_fsdp(
model: torch.nn.Module,
dp_mesh: DeviceMesh,
):
    """Apply FSDP (fully_shard) to every transformer block and to the root model."""
    for layer in model.model.layers:
        fully_shard(layer, mesh=dp_mesh)
        # reshard() frees any currently unsharded (gathered) parameters.
        layer.reshard()
fully_shard(model, mesh=dp_mesh)
model.reshard()
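# Note: fully_shard (FSDP2) replaces each parameter with a DTensor sharded on
# dim 0 across the data-parallel mesh (and replicated across "dp_replicate"
# when both data-parallel dims are used), which is why assert_params_equal
# below reassembles parameters with DTensor.full_tensor() before comparing.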
def parallelize_motif(model: torch.nn.Module,
parallel_dims: ParallelDims) -> torch.nn.Module:
"""Parallelize the Motif model according to the given parallel dimensions.
Args:
model (torch.nn.Module): The Motif model to be parallelized.
parallel_dims (ParallelDims): The parallelism configuration.
Returns:
torch.nn.Module: The parallelized Motif model.
"""
mesh = _construct_device_mesh(parallel_dims)
if parallel_dims.tp_degree > 1:
_apply_tp(model, mesh["tp"])
if parallel_dims.dp_shard_degree > 1:
if parallel_dims.dp_replicate_degree > 1:
dp_dim_names = ("dp_replicate", "dp_shard")
else:
dp_dim_names = ("dp_shard", )
_apply_fsdp(model, mesh[dp_dim_names])
return model
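# Typical usage in a test (sketch, assuming torchrun has already initialized
# the process group and `model` is a MotifForCausalLM instance):
#
#   dims = ParallelDims(dp_replicate_degree=1, dp_shard_degree=2, tp_degree=2)
#   model = parallelize_motif(model, dims)  # needs world_size >= 4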
def parallelize_qk_logits(
qk_logits: dict[int, torch.Tensor],
parallel_dims: ParallelDims,
) -> dict[int, torch.Tensor]:
"""Parallelize the QK logits according to the given parallel dimensions.
Args:
qk_logits (dict[int, torch.Tensor]): The QK logits to be parallelized.
parallel_dims (ParallelDims): The parallelism configuration.
Returns:
dict[int, torch.Tensor]: The parallelized QK logits.
"""
mesh = _construct_device_mesh(parallel_dims)
if parallel_dims.tp_degree > 1:
tp_rank = mesh["tp"].get_local_rank()
placements = [
Shard(0) if dim_name == "tp" else Replicate()
for dim_name in mesh.mesh_dim_names
]
for layer_idx, logits in qk_logits.items():
assert logits.size(0) % parallel_dims.tp_degree == 0
local_logits = logits.chunk(parallel_dims.tp_degree,
dim=0)[tp_rank].contiguous()
qk_logits[layer_idx] = DTensor.from_local(
local_tensor=local_logits,
device_mesh=mesh,
placements=placements,
)
return qk_logits
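# Example (hypothetical shapes): with tp_degree=2 and a per-layer logits
# tensor whose dim 0 is the head dimension (say 16 heads), TP rank 0 keeps
# heads 0..7 and TP rank 1 keeps heads 8..15; each local chunk is wrapped back
# into a DTensor with Shard(0) on the "tp" mesh dim and Replicate() on any
# data-parallel dims, intended to line up with the head sharding of the
# column-wise-parallel q_proj/k_proj.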
def assert_params_equal(actual: torch.nn.Module,
expected: torch.nn.Module) -> None:
"""Asserts that the parameters of two models are equal.
Args:
actual (torch.nn.Module): The actual model.
expected (torch.nn.Module): The expected model.
Returns:
None
"""
def get_full_param(param: torch.nn.Parameter) -> torch.Tensor:
if isinstance(param.data, DTensor):
return param.data.full_tensor()
return param.data
    for (name_a, p_a), (name_e, p_e) in zip(actual.named_parameters(),
                                            expected.named_parameters()):
        assert name_a == name_e, f"Parameter name mismatch: {name_a} vs {name_e}"
        p_a = get_full_param(p_a.cuda())
        p_e = get_full_param(p_e.cuda())
        torch.testing.assert_close(p_a, p_e, atol=0, rtol=0)
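# Typical usage (sketch): after running the same training step on a
# parallelized copy and a single-device reference copy of the model,
#
#   assert_params_equal(actual=parallel_model, expected=reference_model)
#
# raises if any fully materialized parameter differs at all, since the
# comparison uses atol=0, rtol=0 (exact equality).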