"""
Advanced distributed functions for sequence parallel.
"""

from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import ShardingStrategy

from .basic import get_global_rank, get_world_size
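
# Module-level parallel state. These globals are populated by
# init_sequence_parallel() and init_model_shard_group() and exposed through the
# getter functions below.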

_DATA_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_CPU_GROUP = None
_MODEL_SHARD_CPU_INTER_GROUP = None
_MODEL_SHARD_CPU_INTRA_GROUP = None
_MODEL_SHARD_INTER_GROUP = None
_MODEL_SHARD_INTRA_GROUP = None
_SEQUENCE_PARALLEL_GLOBAL_RANKS = None


def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get data parallel process group.
    """
    return _DATA_PARALLEL_GROUP


def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get sequence parallel process group.
    """
    return _SEQUENCE_PARALLEL_GROUP


def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
    """
    Get sequence parallel CPU process group.
    """
    return _SEQUENCE_PARALLEL_CPU_GROUP


def get_data_parallel_rank() -> int:
    """
    Get data parallel rank.
    """
    group = get_data_parallel_group()
    return dist.get_rank(group) if group is not None else get_global_rank()


def get_data_parallel_world_size() -> int:
    """
    Get data parallel world size.
    """
    group = get_data_parallel_group()
    return dist.get_world_size(group) if group is not None else get_world_size()


def get_sequence_parallel_rank() -> int:
    """
    Get sequence parallel rank.
    """
    group = get_sequence_parallel_group()
    return dist.get_rank(group) if group is not None else 0


def get_sequence_parallel_world_size() -> int:
    """
    Get sequence parallel world size.
    """
    group = get_sequence_parallel_group()
    return dist.get_world_size(group) if group is not None else 1
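

# Illustrative sketch (hypothetical helper, not part of the original module):
# the rank/world-size getters above are typically used to pick out the local
# shard of a sequence. The name shard_sequence and the assumption that the
# sequence length is divisible by the sequence parallel world size are
# illustrative assumptions.
def shard_sequence(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Return the caller rank's contiguous chunk of ``x`` along ``dim``.
    """
    sp_size = get_sequence_parallel_world_size()
    if sp_size == 1:
        return x
    assert x.size(dim) % sp_size == 0, "sequence length must be divisible by sp_size"
    return x.chunk(sp_size, dim=dim)[get_sequence_parallel_rank()].contiguous()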


def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]:
    """
    Get the CPU intra process group of model sharding.
    """
    return _MODEL_SHARD_CPU_INTRA_GROUP


def get_model_shard_cpu_inter_group() -> Optional[dist.ProcessGroup]:
    """
    Get the CPU inter process group of model sharding.
    """
    return _MODEL_SHARD_CPU_INTER_GROUP


def get_model_shard_intra_group() -> Optional[dist.ProcessGroup]:
    """
    Get the GPU intra process group of model sharding.
    """
    return _MODEL_SHARD_INTRA_GROUP


def get_model_shard_inter_group() -> Optional[dist.ProcessGroup]:
    """
    Get the GPU inter process group of model sharding.
    """
    return _MODEL_SHARD_INTER_GROUP


def init_sequence_parallel(sequence_parallel_size: int):
    """
    Initialize sequence parallel.
    """
    global _DATA_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_CPU_GROUP
    global _SEQUENCE_PARALLEL_GLOBAL_RANKS
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    data_parallel_size = world_size // sequence_parallel_size
    for i in range(data_parallel_size):
        start_rank = i * sequence_parallel_size
        end_rank = (i + 1) * sequence_parallel_size
        ranks = range(start_rank, end_rank)
        group = dist.new_group(ranks)
        cpu_group = dist.new_group(ranks, backend="gloo")
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group
            _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group
            _SEQUENCE_PARALLEL_GLOBAL_RANKS = list(ranks)
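

# Illustrative layout (derived from the loop above): with world_size == 8 and
# sequence_parallel_size == 4, two sequence parallel groups are created,
# [0, 1, 2, 3] and [4, 5, 6, 7]; rank 5 therefore ends up with
# _SEQUENCE_PARALLEL_GLOBAL_RANKS == [4, 5, 6, 7] and a sequence parallel
# rank of 1.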


def init_model_shard_group(
    *,
    sharding_strategy: ShardingStrategy,
    device_mesh: Optional[DeviceMesh] = None,
):
    """
    Initialize process group of model sharding.
    """
    global _MODEL_SHARD_INTER_GROUP
    global _MODEL_SHARD_INTRA_GROUP
    global _MODEL_SHARD_CPU_INTER_GROUP
    global _MODEL_SHARD_CPU_INTRA_GROUP
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    if device_mesh is not None:
        num_shards_per_group = device_mesh.shape[1]
    elif sharding_strategy == ShardingStrategy.NO_SHARD:
        num_shards_per_group = 1
    elif sharding_strategy in [
        ShardingStrategy.HYBRID_SHARD,
        ShardingStrategy._HYBRID_SHARD_ZERO2,
    ]:
        num_shards_per_group = torch.cuda.device_count()
    else:
        num_shards_per_group = world_size
    num_groups = world_size // num_shards_per_group
    device_mesh = (num_groups, num_shards_per_group)

    gpu_mesh_2d = init_device_mesh("cuda", device_mesh, mesh_dim_names=("inter", "intra"))
    cpu_mesh_2d = init_device_mesh("cpu", device_mesh, mesh_dim_names=("inter", "intra"))

    _MODEL_SHARD_INTER_GROUP = gpu_mesh_2d.get_group("inter")
    _MODEL_SHARD_INTRA_GROUP = gpu_mesh_2d.get_group("intra")
    _MODEL_SHARD_CPU_INTER_GROUP = cpu_mesh_2d.get_group("inter")
    _MODEL_SHARD_CPU_INTRA_GROUP = cpu_mesh_2d.get_group("intra")
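

# Illustrative usage sketch (assumption, not shown in this module): for the
# hybrid sharding strategies, FSDP accepts an (intra, inter) pair of process
# groups, so the groups created above could be consumed roughly like this:
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     init_model_shard_group(sharding_strategy=ShardingStrategy.HYBRID_SHARD)
#     model = FSDP(
#         model,
#         sharding_strategy=ShardingStrategy.HYBRID_SHARD,
#         process_group=(get_model_shard_intra_group(), get_model_shard_inter_group()),
#     )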


def get_sequence_parallel_global_ranks() -> List[int]:
    """
    Get all global ranks of the sequence parallel process group
    that the caller rank belongs to.
    """
    if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None:
        return [dist.get_rank()]
    return _SEQUENCE_PARALLEL_GLOBAL_RANKS


def get_next_sequence_parallel_rank() -> int:
    """
    Get the next global rank of the sequence parallel process group
    that the caller rank belongs to.
    """
    sp_global_ranks = get_sequence_parallel_global_ranks()
    sp_rank = get_sequence_parallel_rank()
    sp_size = get_sequence_parallel_world_size()
    return sp_global_ranks[(sp_rank + 1) % sp_size]


def get_prev_sequence_parallel_rank() -> int:
    """
    Get the previous global rank of the sequence parallel process group
    that the caller rank belongs to.
    """
    sp_global_ranks = get_sequence_parallel_global_ranks()
    sp_rank = get_sequence_parallel_rank()
    sp_size = get_sequence_parallel_world_size()
    return sp_global_ranks[(sp_rank + sp_size - 1) % sp_size]
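

# Illustrative sketch (hypothetical helper, not part of the original module):
# the next/prev helpers above are typically used to build a ring exchange over
# the sequence parallel group, e.g. for ring-style attention. A minimal version
# using batched point-to-point ops could look like this.
def ring_exchange(tensor: torch.Tensor) -> torch.Tensor:
    """
    Send ``tensor`` to the next sequence parallel rank and receive the
    corresponding tensor from the previous one.
    """
    recv_buffer = torch.empty_like(tensor)
    ops = [
        dist.P2POp(dist.isend, tensor, get_next_sequence_parallel_rank()),
        dist.P2POp(dist.irecv, recv_buffer, get_prev_sequence_parallel_rank()),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv_buffer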