# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
import pickle
import random
import socket
import struct
import subprocess
import warnings
import tempfile
import uuid
from datetime import date
from pathlib import Path
from collections import OrderedDict
from typing import Any, Dict, Mapping

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


def is_master(args):
    return args.distributed_rank == 0
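

# Usage sketch (illustrative, not part of the original module): gate side
# effects such as checkpoint writes on the master rank.  `args.distributed_rank`
# is assumed to be populated by the surrounding training script.
def _example_save_on_master(args, model, path='checkpoint.pt'):
    if is_master(args):
        torch.save(model.state_dict(), path)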


def init_distributed_mode(rank, args):
    if "WORLD_SIZE" in os.environ:
        args.world_size = int(os.environ["WORLD_SIZE"])

    if args.launcher == 'spawn':  # single node with multiprocessing.spawn
        args.world_size = args.num_gpus
        args.rank = rank
        args.gpu = rank
    elif 'RANK' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    if args.world_size == 1:
        return

    if 'MASTER_ADDR' in os.environ:
        args.dist_url = 'tcp://{}:{}'.format(os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
    print(f'gpu={args.gpu}, rank={args.rank}, world_size={args.world_size}')

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
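

# Usage sketch (illustrative, not part of the original module): driving the
# init routine above from a single node with torch.multiprocessing.spawn.
# The argparse.Namespace fields (launcher, num_gpus, world_size, dist_url)
# are assumptions chosen to match the attributes init_distributed_mode reads.
def _example_worker(rank, args):
    init_distributed_mode(rank, args)
    # ... build the model, wrap it in DistributedDataParallel, train ...


def _example_spawn_launch():
    import argparse
    import torch.multiprocessing as mp

    args = argparse.Namespace(
        launcher='spawn',
        num_gpus=torch.cuda.device_count(),
        world_size=1,  # overwritten to num_gpus by the 'spawn' branch above
        dist_url='tcp://127.0.0.1:29500',
    )
    mp.spawn(_example_worker, args=(args,), nprocs=args.num_gpus)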


def gather_list_and_concat(tensor):
    gather_t = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gather_t, tensor)
    return torch.cat(gather_t)
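

# Usage sketch (illustrative): every rank contributes an equally shaped CUDA
# tensor and receives the concatenation across ranks.  Assumes the process
# group is already initialised.
def _example_gather_scores():
    scores = torch.rand(8, device='cuda')        # hypothetical per-rank scores
    all_scores = gather_list_and_concat(scores)  # shape: (8 * world_size,)
    return all_scores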


def get_rank():
    return dist.get_rank()


def get_world_size():
    return dist.get_world_size()


def get_default_group():
    return dist.group.WORLD


def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group (optional): group of the collective
        max_size (int, optional): maximum size of the data to be gathered
            across workers
    """
    rank = get_rank()
    world_size = get_world_size()

    buffer_size = max_size * world_size
    if not hasattr(all_gather_list, '_buffer') or \
            all_gather_list._buffer.numel() < buffer_size:
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    if torch.is_tensor(data):
        data = data.cpu()  # move tensors off the GPU before pickling
    enc = pickle.dumps(data)
    enc_size = len(enc)
    header_size = 4  # size of header that contains the length of the encoded data
    size = header_size + enc_size
    if size > max_size:
        raise ValueError('encoded data size ({}) exceeds max_size ({})'.format(size, max_size))

    header = struct.pack(">I", enc_size)
    cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
    start = rank * max_size
    buffer[start:start + size].copy_(cpu_buffer[:size])

    dist.all_reduce(buffer, group=group)
    buffer = buffer.cpu()
    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size:(i + 1) * max_size]
            enc_size, = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
            if enc_size > 0:
                result.append(pickle.loads(bytes(out_buffer[header_size:header_size + enc_size].tolist())))
        return result
    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
            'that the workers have fallen out of sync somehow. Workers can fall out of '
            'sync if one of them runs out of memory, or if there are other conditions '
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data. '
            'Try rerunning with --ddp-backend=no_c10d and see if that helps.'
        )
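

# Usage sketch (illustrative): gathering small picklable Python objects such as
# per-rank statistics.  The default max_size (16 KiB) bounds the pickled
# payload per rank; larger objects need a larger max_size.
def _example_gather_stats():
    stats = {'rank': get_rank(), 'n_samples': 123}  # hypothetical payload
    return all_gather_list(stats)                   # list with one entry per rank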


def all_reduce_dict(
    data: Mapping[str, Any],
    device,
    group=None,
) -> Dict[str, Any]:
    """
    AllReduce a dictionary of values across workers. We separately
    reduce items that are already on the device and items on CPU for
    better performance.

    Args:
        data (Mapping[str, Any]): dictionary of data to all-reduce, but
            cannot be a nested dictionary
        device (torch.device): device for the reduction
        group (optional): group of the collective
    """
    data_keys = list(data.keys())

    # We want to separately reduce items that are already on the
    # device and items on CPU for performance reasons.
    cpu_data = OrderedDict()
    device_data = OrderedDict()
    for k in data_keys:
        t = data[k]
        if not torch.is_tensor(t):
            cpu_data[k] = torch.tensor(t, dtype=torch.double)
        elif t.device.type != device.type:
            cpu_data[k] = t.to(dtype=torch.double)
        else:
            device_data[k] = t.to(dtype=torch.double)

    def _all_reduce_dict(data: OrderedDict):
        if len(data) == 0:
            return data
        buf = torch.stack(list(data.values())).to(device=device)
        dist.all_reduce(buf, group=group)
        return {k: buf[i] for i, k in enumerate(data)}

    cpu_data = _all_reduce_dict(cpu_data)
    device_data = _all_reduce_dict(device_data)

    def get_from_stack(key):
        if key in cpu_data:
            return cpu_data[key]
        elif key in device_data:
            return device_data[key]
        raise KeyError

    return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
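

# Usage sketch (illustrative): summing logging counters across workers.  The
# metric names and device below are assumptions for the example only.
def _example_reduce_metrics():
    device = torch.device('cuda', torch.cuda.current_device())
    metrics = {'loss': 0.42, 'n_correct': torch.tensor(17.0, device=device)}
    return all_reduce_dict(metrics, device=device)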


def get_shared_folder() -> Path:
    user = os.getenv("USER")
    if Path("/checkpoint/").is_dir():
        p = Path(f"/checkpoint/{user}/experiments")
    else:
        p = Path("/tmp/experiments")
    p.mkdir(parents=True, exist_ok=True)
    return p


def get_init_file():
    # Init file must not exist, but its parent dir must exist.
    os.makedirs(str(get_shared_folder()), exist_ok=True)
    init_file = Path(str(get_shared_folder()) + f"/{uuid.uuid4().hex}_init")
    if init_file.exists():
        os.remove(str(init_file))
    return init_file
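

# Usage sketch (illustrative): the unique init file can back a file:// rendezvous
# when no TCP master address is available.  The launcher should call this once
# and hand the same URL to every worker.
def _example_file_rendezvous_url():
    return get_init_file().as_uri()  # e.g. file:///.../experiments/<hex>_init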