from typing import Optional

import torch.distributed as dist

def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Rank of this process, or 0 when torch.distributed is not in use."""
    return dist.get_rank(group) if dist.is_available() and dist.is_initialized() else 0

def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """Number of processes, or 1 when torch.distributed is not in use."""
    return dist.get_world_size(group) if dist.is_available() and dist.is_initialized() else 1

def barrier(group: Optional[dist.ProcessGroup] = None) -> None:
    """Synchronize all processes; a no-op when torch.distributed is not in use."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier(group)

class rank_gate:
    '''
    Execute the wrapped function (or guarded block) on rank 0 first, then on
    all other ranks. Usable as a decorator or as a context manager. Useful
    when caches may need to be populated once in a distributed environment.
    '''
    def __init__(self, func=None):
        self.func = func

    def __call__(self, *args, **kwargs):
        rank = get_rank()
        # Rank 0 runs before the barrier; every other rank waits at the
        # barrier and runs only after rank 0 has finished.
        if rank == 0:
            result = self.func(*args, **kwargs)
        barrier()
        if rank > 0:
            result = self.func(*args, **kwargs)
        return result
    def __enter__(self):
        # Non-zero ranks block here until rank 0 leaves its `with` block.
        if get_rank() > 0:
            barrier()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Rank 0 releases the waiting ranks once its block has completed.
        if get_rank() == 0:
            barrier()
        return False