from dataclasses import dataclass

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import (ColwiseParallel,
                                               PrepareModuleInput,
                                               RowwiseParallel,
                                               SequenceParallel,
                                               parallelize_module)


@dataclass
class ParallelDims:
    dp_replicate_degree: int
    dp_shard_degree: int
    tp_degree: int

    def __str__(self) -> str:
        return (f"dp_replicate-{self.dp_replicate_degree}_"
                f"dp_shard-{self.dp_shard_degree}_"
                f"tp-{self.tp_degree}")


def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh:
    """Constructs a DeviceMesh based on the given parallel dimensions.

    Args:
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        DeviceMesh: The constructed device mesh.
    """
    world_size = dist.get_world_size()
    expected_devices = (parallel_dims.dp_replicate_degree *
                        parallel_dims.dp_shard_degree *
                        parallel_dims.tp_degree)
    if world_size < expected_devices:
        raise ValueError(
            f"Not enough devices: found {world_size}, "
            f"but expected at least {expected_devices}. ({parallel_dims})")

    degrees = [
        parallel_dims.dp_replicate_degree,
        parallel_dims.dp_shard_degree,
        parallel_dims.tp_degree,
    ]
    dim_names = ["dp_replicate", "dp_shard", "tp"]

    # Only keep mesh dimensions whose degree is greater than 1.
    mesh_shape = []
    mesh_dim_names = []
    for degree, dim_name in zip(degrees, dim_names):
        if degree > 1:
            mesh_shape.append(degree)
            mesh_dim_names.append(dim_name)

    device_mesh = dist.init_device_mesh("cuda",
                                        mesh_shape,
                                        mesh_dim_names=mesh_dim_names)
    return device_mesh


def _apply_tp(
    model: torch.nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply tensor parallelism."""
    # Layer names must match the Motif model definition:
    # https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py
    assert type(model).__name__ == "MotifForCausalLM"

    # 1. Parallelize the embedding and shard its outputs (which are the first
    #    transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            # Parallelizing the embedding (commented out below) separates the
            # tied weights and makes it difficult to compare outputs with the
            # non-tensor-parallel version.
            # TODO(jeesoo): check correctness for training semantics
            # "model.embed_tokens":
            # RowwiseParallel(
            #     input_layouts=Replicate(),
            #     output_layouts=Shard(1),
            # ),
            "model.norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1),  # loss_parallel
                use_local_output=False,
            ),
        },
    )

    # Apply tensor + sequence parallelism to every transformer block
    for transformer_block in model.model.layers:
        layer_plan = {
            "input_layernorm": SequenceParallel(),
            "post_attention_layernorm": SequenceParallel(),
            "self_attn": PrepareModuleInput(
                # x, freqs_cis, attention_mask, position_ids, qk_clip
                input_layouts=(Shard(1), Replicate(), None, None, None),
                desired_input_layouts=(Replicate(), Replicate(), None, None,
                                       None),
            ),
            "self_attn.q_proj": ColwiseParallel(),
            "self_attn.k_proj": ColwiseParallel(),
            "self_attn.v_proj": ColwiseParallel(),
            "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
            "mlp": PrepareModuleInput(
                input_layouts=(Shard(1), ),
                desired_input_layouts=(Replicate(), ),
            ),
            "mlp.gate_proj": ColwiseParallel(),
            "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
            "mlp.up_proj": ColwiseParallel(),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )


def _apply_fsdp(
    model: torch.nn.Module,
    dp_mesh: DeviceMesh,
):
    """Apply FSDP2 sharding to each transformer block and then to the root."""
    for layer in model.model.layers:
        fully_shard(layer, mesh=dp_mesh)
        layer.reshard()
    fully_shard(model, mesh=dp_mesh)
    model.reshard()


def parallelize_motif(model: torch.nn.Module,
                      parallel_dims: ParallelDims) -> torch.nn.Module:
    """Parallelize the Motif model according to the given parallel dimensions.

    Args:
        model (torch.nn.Module): The Motif model to be parallelized.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        torch.nn.Module: The parallelized Motif model.
    """
    mesh = _construct_device_mesh(parallel_dims)
    if parallel_dims.tp_degree > 1:
        _apply_tp(model, mesh["tp"])
    if parallel_dims.dp_shard_degree > 1:
        if parallel_dims.dp_replicate_degree > 1:
            dp_dim_names = ("dp_replicate", "dp_shard")
        else:
            dp_dim_names = ("dp_shard", )
        _apply_fsdp(model, mesh[dp_dim_names])
    return model


def parallelize_qk_logits(
    qk_logits: dict[int, torch.Tensor],
    parallel_dims: ParallelDims,
) -> dict[int, torch.Tensor]:
    """Parallelize the QK logits according to the given parallel dimensions.

    Args:
        qk_logits (dict[int, torch.Tensor]): The QK logits to be parallelized.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        dict[int, torch.Tensor]: The parallelized QK logits.
    """
    mesh = _construct_device_mesh(parallel_dims)
    if parallel_dims.tp_degree > 1:
        tp_rank = mesh["tp"].get_local_rank()
        # Shard dim 0 over the TP mesh dimension; replicate over the
        # data-parallel dimensions.
        placements = [
            Shard(0) if dim_name == "tp" else Replicate()
            for dim_name in mesh.mesh_dim_names
        ]
        for layer_idx, logits in qk_logits.items():
            assert logits.size(0) % parallel_dims.tp_degree == 0
            local_logits = logits.chunk(parallel_dims.tp_degree,
                                        dim=0)[tp_rank].contiguous()
            qk_logits[layer_idx] = DTensor.from_local(
                local_tensor=local_logits,
                device_mesh=mesh,
                placements=placements,
            )
    return qk_logits


def assert_params_equal(actual: torch.nn.Module,
                        expected: torch.nn.Module) -> None:
    """Asserts that the parameters of two models are equal.

    Args:
        actual (torch.nn.Module): The actual model.
        expected (torch.nn.Module): The expected model.
    Returns:
        None
    """

    def get_full_param(param: torch.nn.Parameter) -> torch.Tensor:
        # Gather the full tensor if the parameter is sharded as a DTensor.
        if isinstance(param.data, DTensor):
            return param.data.full_tensor()
        return param.data

    for (name_p, p), (name_s, s) in zip(actual.named_parameters(),
                                        expected.named_parameters()):
        p = get_full_param(p.cuda())
        s = get_full_param(s.cuda())
        torch.testing.assert_close(p, s, atol=0, rtol=0)
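

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch added for illustration, not part of the
# original module. It assumes 4 GPUs launched with
# `torchrun --nproc_per_node=4 <this_file>.py`; the parallel degrees and the
# QK-logit shapes below are placeholders, not taken from the Motif config.
# `parallelize_motif` would be applied analogously, but needs an actual
# MotifForCausalLM instance (checkpoint loading is left out here).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 2-way FSDP sharding x 2-way tensor parallelism on 4 GPUs (illustrative).
    parallel_dims = ParallelDims(dp_replicate_degree=1,
                                 dp_shard_degree=2,
                                 tp_degree=2)

    # Placeholder per-layer QK logits; dim 0 must be divisible by tp_degree.
    # Seed so every rank builds identical tensors before sharding.
    torch.manual_seed(0)
    qk_logits = {
        layer_idx: torch.randn(8, device="cuda")
        for layer_idx in range(2)
    }

    sharded = parallelize_qk_logits(qk_logits, parallel_dims)
    if dist.get_rank() == 0:
        for layer_idx, logits in sharded.items():
            print(layer_idx, logits.placements, logits.to_local().shape)

    dist.destroy_process_group()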