# NOTE(review): removed paste artifacts ("Spaces:" / "Runtime error" lines)
# that are not part of the module source.
| #!/usr/bin/env python3 | |
| # -*- coding:utf-8 -*- | |
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |
| import torch | |
| from torch import distributed as dist | |
| from torch import nn | |
| import pickle | |
| from collections import OrderedDict | |
| from .dist import _get_global_gloo_group, get_world_size | |
# Norm layers whose running statistics are synchronized on demand across
# ranks (see ``all_reduce_norm``) instead of at every forward pass.
ASYNC_NORM = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
)

# Explicit public API of this module.
__all__ = [
    "get_async_norm_states",
    "pyobj2tensor",
    "tensor2pyobj",
    "all_reduce",
    "all_reduce_norm",
]
def get_async_norm_states(module):
    """Collect the state of every async-norm layer inside *module*.

    Walks all named submodules and, for each one that is an instance of a
    type listed in ``ASYNC_NORM``, copies its ``state_dict`` entries into a
    single flat mapping.

    Args:
        module (nn.Module): model to scan.

    Returns:
        OrderedDict: keys of the form ``"<submodule_name>.<state_key>"``
        mapped to the corresponding state tensors.
    """
    collected = OrderedDict()
    for mod_name, submodule in module.named_modules():
        if not isinstance(submodule, ASYNC_NORM):
            continue
        for state_key, tensor in submodule.state_dict().items():
            collected[f"{mod_name}.{state_key}"] = tensor
    return collected
def pyobj2tensor(pyobj, device="cuda"):
    """Serialize a picklable Python object into a 1-D uint8 tensor.

    Args:
        pyobj: any picklable Python object.
        device (str): device the resulting tensor is moved to
            (default ``"cuda"``).

    Returns:
        torch.Tensor: uint8 tensor holding the pickled byte stream.
    """
    raw = pickle.dumps(pyobj)
    byte_values = list(raw)
    return torch.tensor(byte_values, dtype=torch.uint8).to(device=device)
def tensor2pyobj(tensor):
    """Deserialize a uint8 tensor back into the original Python object.

    Inverse of ``pyobj2tensor``: moves the tensor to CPU, reinterprets its
    contents as a byte string, and unpickles it.

    Args:
        tensor (torch.Tensor): uint8 tensor produced by ``pyobj2tensor``.

    Returns:
        The deserialized Python object.
    """
    raw_bytes = tensor.cpu().numpy().tobytes()
    return pickle.loads(raw_bytes)
| def _get_reduce_op(op_name): | |
| return { | |
| "sum": dist.ReduceOp.SUM, | |
| "mean": dist.ReduceOp.SUM, | |
| }[op_name.lower()] | |
def all_reduce(py_dict, op="sum", group=None):
    """
    Apply an all-reduce over every tensor value of a python dict.

    NOTE: make sure that every py_dict has the same keys and values are in the same shape.

    Args:
        py_dict (dict): dict of tensors to apply all reduce op on.
        op (str): operator, could be "sum" or "mean" (case-insensitive).
        group: process group to reduce within; defaults to the global gloo
            group from ``_get_global_gloo_group()``.

    Returns:
        OrderedDict: same keys (in rank-0's order) mapped to reduced tensors,
        or *py_dict* unchanged when the world size is 1.
    """
    # Normalize once so "MEAN"/"Mean" also trigger the division below.
    # Previously only _get_reduce_op lowercased, so op="MEAN" summed
    # without dividing — silently wrong statistics.
    op = op.lower()
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # Broadcast rank-0's key order so every rank concatenates tensors in
    # the same sequence before the collective call.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2pyobj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    # Flatten everything into one tensor so a single all_reduce suffices.
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
    if op == "mean":
        flatten_tensor /= world_size

    # Split the reduced buffer back into per-key tensors of original shape.
    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict(zip(py_key, split_tensors))
def all_reduce_norm(module):
    """
    Synchronize norm-layer statistics across devices.

    Gathers the state of every async-norm submodule of *module*,
    mean-reduces it over all ranks, and loads the averaged values back
    into the model (non-strictly, so non-norm keys are untouched).
    """
    averaged = all_reduce(get_async_norm_states(module), op="mean")
    module.load_state_dict(averaged, strict=False)