""" |
|
|
Distributed basic functions. |
|
|
""" |
|
|
|
|
|
import os |
|
|
from datetime import timedelta |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from torch.nn.parallel import DistributedDataParallel |
|
|
|
|
|
|
|
|
def get_global_rank() -> int: |
|
|
""" |
|
|
Get the global rank, the global index of the GPU. |
|
|
""" |
|
|
return int(os.environ.get("RANK", "0")) |
|
|
|
|
|
|
|
|
def get_local_rank() -> int: |
|
|
""" |
|
|
Get the local rank, the local index of the GPU. |
|
|
""" |
|
|
return int(os.environ.get("LOCAL_RANK", "0")) |
|
|
|
|
|
|
|
|
def get_world_size() -> int: |
|
|
""" |
|
|
Get the world size, the total amount of GPUs. |
|
|
""" |
|
|
return int(os.environ.get("WORLD_SIZE", "1")) |


def get_device() -> torch.device:
    """
    Get the CUDA device assigned to the current rank.
    """
    return torch.device("cuda", get_local_rank())


def barrier_if_distributed(*args, **kwargs):
    """
    Synchronize all processes if running in a distributed context; no-op otherwise.
    """
    if dist.is_initialized():
        return dist.barrier(*args, **kwargs)


def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)):
    """
    Common PyTorch initialization: enable TF32, configure cuDNN benchmarking, bind
    this process to its GPU, and initialize the NCCL process group.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.cuda.set_device(get_local_rank())
    dist.init_process_group(
        backend="nccl",
        rank=get_global_rank(),
        world_size=get_world_size(),
        timeout=timeout,
    )
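

# Note (not part of the original module): init_torch() leaves init_method at its
# default "env://" rendezvous, so MASTER_ADDR and MASTER_PORT must be present in the
# environment alongside RANK/LOCAL_RANK/WORLD_SIZE. A torchrun launch provides all of
# these, e.g. (illustrative command, the script name is a placeholder):
#
#   torchrun --nnodes=1 --nproc_per_node=4 train.py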


def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
    """
    Wrap a module in DistributedDataParallel, pinned to the current rank's GPU.
    """
    return DistributedDataParallel(
        module=module,
        device_ids=[get_local_rank()],
        output_device=get_local_rank(),
        **kwargs,
    )
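

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes this file
    # is launched with torchrun on a CUDA machine, e.g.
    #   torchrun --nproc_per_node=<num_gpus> this_file.py
    # and uses a toy linear layer purely for illustration.
    init_torch()
    model = convert_to_ddp(torch.nn.Linear(8, 8).to(get_device()))
    x = torch.randn(4, 8, device=get_device())
    model(x).sum().backward()  # DDP all-reduces gradients across ranks here
    barrier_if_distributed()
    dist.destroy_process_group()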