# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from torch.cuda._utils import _get_device_index


@MODULE_WRAPPERS.register_module()
class DistributedDataParallelWrapper(nn.Module):
    """A DistributedDataParallel wrapper for models in 3D mesh estimation task.

    In the 3D mesh estimation task, there is a need to wrap different modules
    in the models with separate DistributedDataParallel. Otherwise, it will
    cause errors for GAN training. More specifically, a GAN model usually has
    two sub-modules: generator and discriminator. If we wrap both of them in
    one standard DistributedDataParallel, it will cause errors during
    training, because when we update the parameters of the generator (or
    discriminator), the parameters of the discriminator (or generator) are
    not updated, which is not allowed for DistributedDataParallel. So we
    design this wrapper to separately wrap DistributedDataParallel for the
    generator and the discriminator.

    In this wrapper, we perform two operations:

    1. Wrap each module in the model with a separate
       MMDistributedDataParallel. Note that only modules with parameters
       will be wrapped.
    2. Do the scatter operation for ``forward``, ``train_step`` and
       ``val_step``.

    Note that the arguments of this wrapper are the same as those in
    ``torch.nn.parallel.distributed.DistributedDataParallel``.

    Args:
        module (nn.Module): Module that needs to be wrapped.
        device_ids (list[int | `torch.device`]): Same as that in
            ``torch.nn.parallel.distributed.DistributedDataParallel``.
        dim (int, optional): Same as that in the official scatter function in
            PyTorch. Defaults to 0.
        broadcast_buffers (bool): Same as that in
            ``torch.nn.parallel.distributed.DistributedDataParallel``.
            Defaults to False.
        find_unused_parameters (bool, optional): Same as that in
            ``torch.nn.parallel.distributed.DistributedDataParallel``.
            Traverse the autograd graph of all tensors contained in the
            returned value of the wrapped module's forward function.
            Defaults to False.
        kwargs (dict): Other arguments used in
            ``torch.nn.parallel.distributed.DistributedDataParallel``.
    """

    def __init__(self,
                 module,
                 device_ids,
                 dim=0,
                 broadcast_buffers=False,
                 find_unused_parameters=False,
                 **kwargs):
        super().__init__()
        assert len(device_ids) == 1, (
            'Currently, DistributedDataParallelWrapper only supports one '
            'single CUDA device for each process. '
            f'The length of device_ids must be 1, but got {len(device_ids)}.')
        self.module = module
        self.dim = dim
        self.to_ddp(
            device_ids=device_ids,
            dim=dim,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            **kwargs)
        self.output_device = _get_device_index(device_ids[0], True)

    def to_ddp(self, device_ids, dim, broadcast_buffers,
               find_unused_parameters, **kwargs):
        """Wrap models with separate MMDistributedDataParallel.

        It only wraps the modules with parameters.
        """
        for name, module in self.module._modules.items():
            if next(module.parameters(), None) is None:
                # The module holds no parameters at all: just move it to GPU.
                module = module.cuda()
            elif all(not p.requires_grad for p in module.parameters()):
                # All parameters are frozen: no gradient synchronization is
                # needed, so skip the DDP wrapping.
                module = module.cuda()
            else:
                # Trainable module: give it its own MMDistributedDataParallel
                # instance so it can be updated independently.
                module = MMDistributedDataParallel(
                    module.cuda(),
                    device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            self.module._modules[name] = module

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
            device_ids (list[int]): Device ids.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """Forward function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])

    def train_step(self, *inputs, **kwargs):
        """Train step function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output

    def val_step(self, *inputs, **kwargs):
        """Validation step function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output