File size: 7,769 Bytes
e2b41e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 |
from dataclasses import dataclass
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import (ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
parallelize_module)
@dataclass
class ParallelDims:
    """Degrees of parallelism requested along each supported dimension."""

    # Degree used for the "dp_replicate" dimension.
    dp_replicate_degree: int
    # Degree used for the "dp_shard" dimension.
    dp_shard_degree: int
    # Degree used for the "tp" dimension.
    tp_degree: int

    def __str__(self) -> str:
        """Return a compact tag, e.g. ``dp_replicate-1_dp_shard-2_tp-4``."""
        parts = (
            f"dp_replicate-{self.dp_replicate_degree}",
            f"dp_shard-{self.dp_shard_degree}",
            f"tp-{self.tp_degree}",
        )
        return "_".join(parts)
def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh:
    """Build the CUDA device mesh described by ``parallel_dims``.

    Only dimensions whose degree is greater than 1 become mesh dimensions;
    they appear in the order "dp_replicate", "dp_shard", "tp".

    Args:
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        DeviceMesh: The constructed device mesh.

    Raises:
        ValueError: If the world size is smaller than the product of the
            three degrees.
    """
    available = dist.get_world_size()

    named_degrees = (
        ("dp_replicate", parallel_dims.dp_replicate_degree),
        ("dp_shard", parallel_dims.dp_shard_degree),
        ("tp", parallel_dims.tp_degree),
    )

    required = 1
    for _, degree in named_degrees:
        required *= degree

    if available < required:
        raise ValueError(
            f"Not enough devices: found {available}, "
            f"but expected at least {required}. ({parallel_dims})")

    # Dimensions of degree <= 1 are dropped from the mesh entirely.
    active = [(name, degree) for name, degree in named_degrees if degree > 1]

    return dist.init_device_mesh(
        "cuda",
        [degree for _, degree in active],
        mesh_dim_names=[name for name, _ in active],
    )
def _apply_tp(
model: torch.nn.Module,
tp_mesh: DeviceMesh,
):
"""Apply tensor parallelism."""
# Layer names must match Motif model definition
# https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py
assert type(model).__name__ == "MotifForCausalLM"
# 1. Parallelize the embedding and shard its outputs (which are the first
# transformer block's inputs)
# 2. Parallelize the root norm layer over the sequence dim
# 3. Parallelize the final linear output layer
parallelize_module(
model,
tp_mesh,
{
# This below separate tie_weights and make difficult to compare
# the answer with non-tensor-parallel version.
# TODO(jeesoo): check correctness for training semantic
#"model.embed_tokens":
#RowwiseParallel(
# input_layouts=Replicate(),
# output_layouts=Shard(1),
#),
"model.norm":
SequenceParallel(),
"output":
ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1), # loss_parallel
use_local_output=False,
),
},
)
# Apply tensor + sequence parallelism to every transformer block
for transformer_block in model.model.layers:
layer_plan = {
"input_layernorm":
SequenceParallel(),
"post_attention_layernorm":
SequenceParallel(),
"self_attn":
PrepareModuleInput(
# x, freqs_cis, attention_mask, position_ids, qk_clip
input_layouts=(Shard(1), Replicate(), None, None, None),
desired_input_layouts=(Replicate(), Replicate(), None, None,
None),
),
"self_attn.q_proj":
ColwiseParallel(),
"self_attn.k_proj":
ColwiseParallel(),
"self_attn.v_proj":
ColwiseParallel(),
"self_attn.o_proj":
RowwiseParallel(output_layouts=Shard(1)),
"mlp":
PrepareModuleInput(
input_layouts=(Shard(1), ),
desired_input_layouts=(Replicate(), ),
),
"mlp.gate_proj":
ColwiseParallel(),
"mlp.down_proj":
RowwiseParallel(output_layouts=Shard(1)),
"mlp.up_proj":
ColwiseParallel(),
}
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)
def _apply_fsdp(
model: torch.nn.Module,
dp_mesh: DeviceMesh,
):
for layer in model.model.layers:
fully_shard(layer, mesh=dp_mesh)
layer.reshard()
fully_shard(model, mesh=dp_mesh)
model.reshard()
def parallelize_motif(model: torch.nn.Module,
                      parallel_dims: ParallelDims) -> torch.nn.Module:
    """Parallelize the Motif model according to ``parallel_dims``.

    Tensor parallelism is applied when ``tp_degree > 1``. FSDP is applied
    only when ``dp_shard_degree > 1``; in that case the "dp_replicate" mesh
    dimension is included as well if ``dp_replicate_degree > 1``.

    Args:
        model (torch.nn.Module): The Motif model to be parallelized.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        torch.nn.Module: The same ``model`` object, after parallelization.
    """
    mesh = _construct_device_mesh(parallel_dims)

    if parallel_dims.tp_degree > 1:
        _apply_tp(model, mesh["tp"])

    # Without a sharding degree there is nothing for FSDP to do.
    if parallel_dims.dp_shard_degree <= 1:
        return model

    replicated = parallel_dims.dp_replicate_degree > 1
    dp_dim_names = (("dp_replicate", "dp_shard") if replicated else
                    ("dp_shard", ))
    _apply_fsdp(model, mesh[dp_dim_names])
    return model
def parallelize_qk_logits(
    qk_logits: dict[int, torch.Tensor],
    parallel_dims: ParallelDims,
) -> dict[int, torch.Tensor]:
    """Shard the QK logits over the "tp" mesh dimension.

    When ``tp_degree > 1`` every tensor is split along dim 0 into
    ``tp_degree`` chunks, the chunk of the local tp rank is kept, and it is
    wrapped as a DTensor that is ``Shard(0)`` on "tp" and replicated on all
    other mesh dimensions. Otherwise the tensors are left as they are.

    Args:
        qk_logits (dict[int, torch.Tensor]): The QK logits to be
            parallelized, keyed by layer index. Updated in place.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        dict[int, torch.Tensor]: The same dict object that was passed in.
    """
    mesh = _construct_device_mesh(parallel_dims)

    tp_degree = parallel_dims.tp_degree
    if tp_degree <= 1:
        return qk_logits

    tp_rank = mesh["tp"].get_local_rank()

    placements = []
    for dim_name in mesh.mesh_dim_names:
        placements.append(Shard(0) if dim_name == "tp" else Replicate())

    # Snapshot the keys; values are replaced while walking the dict.
    for layer_idx in list(qk_logits):
        full_logits = qk_logits[layer_idx]
        assert full_logits.size(0) % tp_degree == 0
        shards = torch.chunk(full_logits, tp_degree, dim=0)
        qk_logits[layer_idx] = DTensor.from_local(
            local_tensor=shards[tp_rank].contiguous(),
            device_mesh=mesh,
            placements=placements,
        )

    return qk_logits
def assert_params_equal(actual: torch.nn.Module,
expected: torch.nn.Module) -> None:
"""Asserts that the parameters of two models are equal.
Args:
actual (torch.nn.Module): The actual model.
expected (torch.nn.Module): The expected model.
Returns:
None
"""
def get_full_param(param: torch.nn.Parameter) -> torch.Tensor:
if isinstance(param.data, DTensor):
return param.data.full_tensor()
return param.data
for (name_p, p), (name_s, s) in zip(actual.named_parameters(),
expected.named_parameters()):
p = get_full_param(p.cuda())
s = get_full_param(s.cuda())
torch.testing.assert_close(p, s, atol=0, rtol=0)
|