from dataclasses import dataclass

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import (ColwiseParallel,
                                               PrepareModuleInput,
                                               RowwiseParallel,
                                               SequenceParallel,
                                               parallelize_module)


@dataclass
class ParallelDims:
    dp_replicate_degree: int
    dp_shard_degree: int
    tp_degree: int

    def __str__(self) -> str:
        return (f"dp_replicate-{self.dp_replicate_degree}_"
                f"dp_shard-{self.dp_shard_degree}_"
                f"tp-{self.tp_degree}")
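
# Example (hypothetical config): ParallelDims(dp_replicate_degree=1,
# dp_shard_degree=4, tp_degree=2) combines 4-way FSDP sharding with 2-way
# tensor parallelism; str() of it is "dp_replicate-1_dp_shard-4_tp-2".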


def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh:
    """Constructs a DeviceMesh based on the given parallel dimensions.

    Args:
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        DeviceMesh: The constructed device mesh.
    """
    world_size = dist.get_world_size()
    expected_devices = (parallel_dims.dp_replicate_degree *
                        parallel_dims.dp_shard_degree *
                        parallel_dims.tp_degree)
    if world_size < expected_devices:
        raise ValueError(
            f"Not enough devices: found {world_size}, "
            f"but expected at least {expected_devices}. ({parallel_dims})")

    degrees = [
        parallel_dims.dp_replicate_degree, parallel_dims.dp_shard_degree,
        parallel_dims.tp_degree
    ]
    dim_names = ["dp_replicate", "dp_shard", "tp"]

    mesh_shape = []
    mesh_dim_names = []
    for degree, dim_name in zip(degrees, dim_names):
        if degree > 1:
            mesh_shape.append(degree)
            mesh_dim_names.append(dim_name)

    device_mesh = dist.init_device_mesh("cuda",
                                        mesh_shape,
                                        mesh_dim_names=mesh_dim_names)

    return device_mesh
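
# Example (hypothetical): with 8 GPUs and ParallelDims(dp_replicate_degree=1,
# dp_shard_degree=4, tp_degree=2), only the non-trivial dimensions are kept,
# so the resulting mesh has shape (4, 2) with dim names ("dp_shard", "tp").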


def _apply_tp(
    model: torch.nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply tensor parallelism."""

    assert type(model).__name__ == "MotifForCausalLM"
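
    # Top-level plan (sketch of intent): the final norm runs sequence-parallel
    # over activations sharded on the sequence dim, and the output projection
    # is column-parallel with its logits kept as a DTensor sharded on the
    # vocab dim (Shard(-1)), presumably so a loss-parallel loss can consume
    # them without an all-gather.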
    parallelize_module(
        model,
        tp_mesh,
        {
            "model.norm":
            SequenceParallel(),
            "output":
            ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1),
                use_local_output=False,
            ),
        },
    )
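
    # Per-block plan: the norms are sequence-parallel; q/k/v and gate/up
    # projections are column-parallel while o_proj and down_proj are
    # row-parallel, so each attention and MLP block ends in a single collective
    # on its output. PrepareModuleInput redistributes the Shard(1) activations
    # to Replicate() at the block boundaries.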
    for transformer_block in model.model.layers:
        layer_plan = {
            "input_layernorm":
            SequenceParallel(),
            "post_attention_layernorm":
            SequenceParallel(),
            "self_attn":
            PrepareModuleInput(
                input_layouts=(Shard(1), Replicate(), None, None, None),
                desired_input_layouts=(Replicate(), Replicate(), None, None,
                                       None),
            ),
            "self_attn.q_proj":
            ColwiseParallel(),
            "self_attn.k_proj":
            ColwiseParallel(),
            "self_attn.v_proj":
            ColwiseParallel(),
            "self_attn.o_proj":
            RowwiseParallel(output_layouts=Shard(1)),
            "mlp":
            PrepareModuleInput(
                input_layouts=(Shard(1), ),
                desired_input_layouts=(Replicate(), ),
            ),
            "mlp.gate_proj":
            ColwiseParallel(),
            "mlp.down_proj":
            RowwiseParallel(output_layouts=Shard(1)),
            "mlp.up_proj":
            ColwiseParallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )


def _apply_fsdp(
    model: torch.nn.Module,
    dp_mesh: DeviceMesh,
):
    """Apply fully sharded data parallelism (FSDP2) over the given mesh."""
    for layer in model.model.layers:
        fully_shard(layer, mesh=dp_mesh)
        layer.reshard()
    fully_shard(model, mesh=dp_mesh)
    model.reshard()
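
# Note: fully_shard is applied to each transformer block before the root
# module, so parameters are grouped per block; calling reshard() right after
# wrapping makes sure the parameters sit in their sharded state before
# training starts (a no-op if they already are).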


def parallelize_motif(model: torch.nn.Module,
                      parallel_dims: ParallelDims) -> torch.nn.Module:
    """Parallelize the Motif model according to the given parallel dimensions.

    Args:
        model (torch.nn.Module): The Motif model to be parallelized.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        torch.nn.Module: The parallelized Motif model.
    """

    mesh = _construct_device_mesh(parallel_dims)

    if parallel_dims.tp_degree > 1:
        _apply_tp(model, mesh["tp"])

    if parallel_dims.dp_shard_degree > 1:
        if parallel_dims.dp_replicate_degree > 1:
            dp_dim_names = ("dp_replicate", "dp_shard")
        else:
            dp_dim_names = ("dp_shard", )
        _apply_fsdp(model, mesh[dp_dim_names])

    return model
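
# Hypothetical usage sketch (names outside this module are assumptions):
#
#     dist.init_process_group("nccl")
#     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
#     model = MotifForCausalLM(config).cuda()
#     dims = ParallelDims(dp_replicate_degree=1, dp_shard_degree=4, tp_degree=2)
#     model = parallelize_motif(model, dims)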


def parallelize_qk_logits(
    qk_logits: dict[int, torch.Tensor],
    parallel_dims: ParallelDims,
) -> dict[int, torch.Tensor]:
    """Parallelize the QK logits according to the given parallel dimensions.

    Args:
        qk_logits (dict[int, torch.Tensor]): The QK logits to be parallelized.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        dict[int, torch.Tensor]: The parallelized QK logits.
    """

    mesh = _construct_device_mesh(parallel_dims)

    if parallel_dims.tp_degree > 1:
        tp_rank = mesh["tp"].get_local_rank()
        placements = [
            Shard(0) if dim_name == "tp" else Replicate()
            for dim_name in mesh.mesh_dim_names
        ]
        for layer_idx, logits in qk_logits.items():
            assert logits.size(0) % parallel_dims.tp_degree == 0
            local_logits = logits.chunk(parallel_dims.tp_degree,
                                        dim=0)[tp_rank].contiguous()

            qk_logits[layer_idx] = DTensor.from_local(
                local_tensor=local_logits,
                device_mesh=mesh,
                placements=placements,
            )

    return qk_logits
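
# Example (hypothetical shapes): with tp_degree=2 and a per-layer logits
# tensor whose leading dim is the head dim, each TP rank keeps its half of the
# heads as the local shard of a DTensor that is Shard(0) on "tp" and
# Replicate() on any other mesh dimension.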


def assert_params_equal(actual: torch.nn.Module,
                        expected: torch.nn.Module) -> None:
    """Asserts that the parameters of two models are equal.

    Args:
        actual (torch.nn.Module): The actual model.
        expected (torch.nn.Module): The expected model.

    Returns:
        None
    """

    def get_full_param(param: torch.nn.Parameter) -> torch.Tensor:
        if isinstance(param.data, DTensor):
            return param.data.full_tensor()
        return param.data

    for (name_p, p), (name_s, s) in zip(actual.named_parameters(),
                                        expected.named_parameters()):
        # Guard against silently comparing mismatched parameters.
        assert name_p == name_s, f"Parameter mismatch: {name_p} vs {name_s}"
        p = get_full_param(p.cuda())
        s = get_full_param(s.cuda())

        torch.testing.assert_close(p, s, atol=0, rtol=0)
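
# In tests, a typical (hypothetical) pattern is to deepcopy a single-device
# reference model, parallelize the copy, and then verify nothing changed:
#
#     reference = MotifForCausalLM(config)
#     parallel = parallelize_motif(copy.deepcopy(reference), parallel_dims)
#     assert_params_equal(parallel, reference)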