Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| from datetime import timedelta | |
def initialize_torch_distributed():
    """Initialise the default ``torch.distributed`` process group.

    The process rank and the world size are read from the ``RANK`` and
    ``WORLD_SIZE`` environment variables (defaulting to ``"0"`` and ``"1"``).
    When CUDA is available the NCCL backend is used and the current CUDA
    device is set to ``rank % torch.cuda.device_count()``; otherwise the Gloo
    backend is used. No ``init_method`` or ``store`` is passed to
    ``init_process_group``, so torch's default rendezvous applies.

    Returns:
        A ``(process_group, rank, world_size)`` tuple where ``process_group``
        is ``torch.distributed.group.WORLD``.

    Raises:
        ValueError: If ``RANK`` / ``WORLD_SIZE`` are not integers, if
            ``WORLD_SIZE`` is smaller than 1, or if ``RANK`` is outside
            ``[0, WORLD_SIZE)``.
        AssertionError: If CUDA is available and ``WORLD_SIZE`` exceeds the
            number of visible CUDA devices.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    # Fail fast on an inconsistent launch configuration instead of letting
    # the rendezvous below fail (or wait for its timeout) with an opaque error.
    if world_size < 1:
        raise ValueError(f"WORLD_SIZE must be >= 1, got {world_size}")
    if not 0 <= rank < world_size:
        raise ValueError(
            f"RANK must be in [0, {world_size}), got {rank}"
        )

    # Single source of truth for the timeout handed both to the NCCL options
    # and to init_process_group, so the two values cannot drift apart.
    timeout = timedelta(seconds=60)

    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        device_count = torch.cuda.device_count()
        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; the exception type and message are kept unchanged for
        # existing callers.
        if world_size > device_count:
            raise AssertionError("Each process is one gpu")

        # Set the device id.
        torch.cuda.set_device(rank % device_count)

        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timeout
    else:
        backend = "gloo"
        options = None

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timeout,
        pg_options=options,
    )

    return torch.distributed.group.WORLD, rank, world_size