# ovi/distributed_comms/parallel_states.py
import os

import torch.distributed as dist


class COMM_INFO:
    """Holds the process-group handle and rank bookkeeping for sequence parallelism."""

    def __init__(self):
        self.group = None              # torch.distributed process group for this SP group
        self.sp_size = 1               # number of ranks that share one sequence
        self.global_rank = 0           # rank in the full world
        self.rank_within_group = 0     # rank inside the SP group
        self.group_id = 0              # index of the SP group this rank belongs to


# Module-level singletons: the communicator info shared by the helpers below,
# and a flag recording whether sequence parallelism is active.
nccl_info = COMM_INFO()
_SEQUENCE_PARALLEL_STATE = False


def initialize_sequence_parallel_state(sequence_parallel_size):
    """Enable sequence parallelism and build the SP groups when size > 1."""
    global _SEQUENCE_PARALLEL_STATE
    if sequence_parallel_size > 1:
        _SEQUENCE_PARALLEL_STATE = True
        initialize_sequence_parallel_group(sequence_parallel_size)
    else:
        # No sequence parallelism: every rank is its own group of one.
        nccl_info.sp_size = 1
        nccl_info.global_rank = int(os.getenv("RANK", "0"))
        nccl_info.rank_within_group = 0
        nccl_info.group_id = int(os.getenv("RANK", "0"))
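# Example (illustrative): on an 8-rank job, initialize_sequence_parallel_state(4)
# enables SP and delegates to initialize_sequence_parallel_group(4), so ranks
# 0-3 and 4-7 each form a group; initialize_sequence_parallel_state(1) keeps SP
# off and leaves every rank as its own group of one (group_id equal to its rank).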


def set_sequence_parallel_state(state):
    global _SEQUENCE_PARALLEL_STATE
    _SEQUENCE_PARALLEL_STATE = state


def get_sequence_parallel_state():
    return _SEQUENCE_PARALLEL_STATE


def initialize_sequence_parallel_group(sequence_parallel_size):
    """Initialize the sequence parallel group."""
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    assert (
        world_size % sequence_parallel_size == 0
    ), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
        world_size, sequence_parallel_size)
    nccl_info.sp_size = sequence_parallel_size
    nccl_info.global_rank = rank
    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    # dist.new_group is a collective: every rank must take part in creating
    # every group, but each rank keeps only the group it belongs to.
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            nccl_info.group = group
            nccl_info.rank_within_group = rank - i * sequence_parallel_size
            nccl_info.group_id = i
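# Example (illustrative): with WORLD_SIZE=8 and sequence_parallel_size=4,
# ranks 0-3 form group 0 and ranks 4-7 form group 1; global rank 5 ends up
# with rank_within_group=1 and group_id=1.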


def initialize_sequence_parallel_group_custom(process_group):
    """Initialize an unsafe sequence parallel group from a pre-formed process group."""
    set_sequence_parallel_state(True)
    rank = dist.get_rank(group=process_group)
    sequence_parallel_size = dist.get_world_size(group=process_group)
    nccl_info.sp_size = sequence_parallel_size
    nccl_info.global_rank = dist.get_rank()  # global rank
    nccl_info.group = process_group
    nccl_info.rank_within_group = rank
    nccl_info.group_id = 0
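# Example (assumption, not part of the original file): adopt a group that was
# built elsewhere, e.g. from an explicit rank list:
#   sp_group = dist.new_group(ranks=[0, 1, 2, 3])
#   initialize_sequence_parallel_group_custom(sp_group)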


def destroy_sequence_parallel_group():
    """Destroy the sequence parallel group."""
    # Note: with no group argument this deinitializes torch.distributed as a
    # whole, not just the sequence parallel group.
    dist.destroy_process_group()
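# --- Usage sketch (illustrative; the launcher, device setup, and all_gather
# step below are assumptions, not part of the original module) ---
# Under torchrun, a caller might initialize the SP state and then communicate
# within its group roughly as follows:
#
#   import torch
#   import torch.distributed as dist
#
#   dist.init_process_group(backend="nccl")
#   torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
#   initialize_sequence_parallel_state(sequence_parallel_size=2)
#   if get_sequence_parallel_state():
#       shard = torch.full((4,), float(nccl_info.rank_within_group), device="cuda")
#       gathered = [torch.empty_like(shard) for _ in range(nccl_info.sp_size)]
#       dist.all_gather(gathered, shard, group=nccl_info.group)
#   destroy_sequence_parallel_group()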