#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Code is based on
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Megvii, Inc. and its affiliates.

import sys
from datetime import timedelta

from loguru import logger

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import yolox.utils.dist as comm

__all__ = ["launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    """
    Find an available port on the current machine / node.
    """
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    backend="nccl",
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of GPUs (processes) to launch per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine (one per machine)
        backend (str): name of the torch.distributed backend to use, e.g. "nccl"
        dist_url (str): url to connect to for distributed training, including protocol,
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        args (tuple): arguments passed to main_func
        timeout (timedelta): timeout of the distributed workers
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes
        if dist_url == "auto":
            assert (
                num_machines == 1
            ), 'dist_url="auto" is not supported for multi-machine jobs.'
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        start_method = "spawn"
        # args[1] is expected to expose a `cache` attribute
        # (e.g. the parsed argparse namespace of the training script).
        cache = vars(args[1]).get("cache", False)

        # To use numpy memmap for caching images into RAM, we have to use the fork method
        if cache:
            assert sys.platform != "win32", (
                "Windows does not support the fork start method; "
                "do not add --cache to your training command."
            )
            start_method = "fork"
        mp.start_processes(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                backend,
                dist_url,
                args,
                timeout,  # forward the timeout so it is not silently ignored
            ),
            daemon=False,
            start_method=start_method,
        )
    else:
        main_func(*args)
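
# Example usage (illustrative sketch, not part of the original module): the entry
# point and argument names below are hypothetical placeholders.
#
#     def example_main(exp, parsed_args):
#         ...  # build the model / dataloaders and run the training loop
#
#     launch(
#         example_main,
#         num_gpus_per_machine=2,
#         num_machines=1,
#         machine_rank=0,
#         dist_url="auto",
#         args=(exp, parsed_args),
#     )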


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    backend,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    assert (
        torch.cuda.is_available()
    ), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    logger.info("Rank {} initialization finished.".format(global_rank))
    try:
        dist.init_process_group(
            backend=backend,
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception:
        logger.error("Process group URL: {}".format(dist_url))
        raise

    # Setup the local process group (which contains ranks within the same machine)
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    main_func(*args)
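
# Illustrative note (not part of the original module): with num_machines=2 and
# num_gpus_per_machine=4, the worker on machine_rank=1 with local_rank=2 gets
# global_rank = 1 * 4 + 2 = 6, and the local process group created above for
# that machine contains global ranks [4, 5, 6, 7].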