# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import importlib
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import SGD, Adam, RMSprop, lr_scheduler

from imaginaire.optimizers import Fromage, Madam
from imaginaire.utils.distributed import get_rank, get_world_size
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.init_weight import weights_init, weights_rescale
from imaginaire.utils.model_average import ModelAverage


def set_random_seed(seed, by_rank=False):
    r"""Set random seeds for everything.

    Args:
        seed (int): Random seed.
        by_rank (bool): If True, offset the seed by the process rank so that
            each GPU/process uses a different seed.
    """
    if by_rank:
        seed += get_rank()
    print(f"Using random seed {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
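
# Illustrative usage sketch (not part of the original module): a typical
# multi-GPU run first seeds all ranks identically to get identical weights,
# then re-seeds with by_rank=True so stochastic inputs differ per GPU.
# Assuming 4 processes and seed=2:
#   set_random_seed(2, by_rank=False)  # every rank uses seed 2
#   set_random_seed(2, by_rank=True)   # rank k uses seed 2 + k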


def get_trainer(cfg, net_G, net_D=None,
                opt_G=None, opt_D=None,
                sch_G=None, sch_D=None,
                train_data_loader=None,
                val_data_loader=None):
    r"""Return the trainer object.

    Args:
        cfg (Config): Loaded config object.
        net_G (obj): Generator network object.
        net_D (obj): Discriminator network object.
        opt_G (obj): Generator optimizer object.
        opt_D (obj): Discriminator optimizer object.
        sch_G (obj): Generator optimizer scheduler object.
        sch_D (obj): Discriminator optimizer scheduler object.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.

    Returns:
        (obj): Trainer object.
    """
    trainer_lib = importlib.import_module(cfg.trainer.type)
    trainer = trainer_lib.Trainer(cfg, net_G, net_D,
                                  opt_G, opt_D,
                                  sch_G, sch_D,
                                  train_data_loader, val_data_loader)
    return trainer
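
# Usage sketch (illustrative only): cfg.trainer.type must name an importable
# module that defines a Trainer class, e.g. 'imaginaire.trainers.spade'
# (that module name is an assumption here). A typical call site:
#   trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
#                         train_data_loader, val_data_loader)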


def get_model_optimizer_and_scheduler(cfg, seed=0):
    r"""Return the networks, the optimizers, and the schedulers. We first set
    the random seed to a fixed value so that each GPU copy is initialized with
    the same network weights, then switch to per-GPU random seeds. After that,
    the generator is wrapped with a moving-average model if applicable, the
    optimizers are created, and the networks are wrapped for distributed data
    parallel training.

    Args:
        cfg (obj): Global configuration.
        seed (int): Random seed.

    Returns:
        (tuple):
          - net_G (obj): Generator network object.
          - net_D (obj): Discriminator network object.
          - opt_G (obj): Generator optimizer object.
          - opt_D (obj): Discriminator optimizer object.
          - sch_G (obj): Generator optimizer scheduler object.
          - sch_D (obj): Discriminator optimizer scheduler object.
    """
    # Set the same random seed on every rank so that each copy of the network
    # is initialized with exactly the same weights and other parameters.
    set_random_seed(seed, by_rank=False)
    # Construct networks.
    lib_G = importlib.import_module(cfg.gen.type)
    lib_D = importlib.import_module(cfg.dis.type)
    net_G = lib_G.Generator(cfg.gen, cfg.data)
    net_D = lib_D.Discriminator(cfg.dis, cfg.data)
    print('Initialize net_G and net_D weights using '
          'type: {} gain: {}'.format(cfg.trainer.init.type,
                                     cfg.trainer.init.gain))
    init_bias = getattr(cfg.trainer.init, 'bias', None)
    net_G.apply(weights_init(
        cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
    net_D.apply(weights_init(
        cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
    net_G.apply(weights_rescale())
    net_D.apply(weights_rescale())
    # for name, p in net_G.named_parameters():
    #     if 'modulation' in name and 'bias' in name:
    #         nn.init.constant_(p.data, 1.)
    net_G = net_G.to('cuda')
    net_D = net_D.to('cuda')
    # Different GPU copies of the same model will receive noises initialized
    # with different random seeds (if applicable) thanks to the
    # set_random_seed call below (GPU #K has random seed = seed + K).
    set_random_seed(seed, by_rank=True)
    print('net_G parameter count: {:,}'.format(_calculate_model_size(net_G)))
    print('net_D parameter count: {:,}'.format(_calculate_model_size(net_D)))
    # Optimizers.
    opt_G = get_optimizer(cfg.gen_opt, net_G)
    opt_D = get_optimizer(cfg.dis_opt, net_D)
    net_G, net_D, opt_G, opt_D = \
        wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D)
    # Schedulers.
    sch_G = get_scheduler(cfg.gen_opt, opt_G)
    sch_D = get_scheduler(cfg.dis_opt, opt_D)
    return net_G, net_D, opt_G, opt_D, sch_G, sch_D
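
# End-to-end sketch (illustrative; the config path and the use of
# imaginaire.config.Config are assumptions, not prescribed by this module):
#   from imaginaire.config import Config
#   cfg = Config('configs/your_project.yaml')  # hypothetical path
#   net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
#       get_model_optimizer_and_scheduler(cfg, seed=0)
#   trainer = get_trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
#                         train_data_loader, val_data_loader)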


def wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D):
    r"""Wrap the networks and the optimizers with DDP and (optionally) the
    model average wrapper.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network object.
        net_D (obj): Discriminator network object.
        opt_G (obj): Generator optimizer object.
        opt_D (obj): Discriminator optimizer object.

    Returns:
        (tuple):
          - net_G (obj): Generator network object.
          - net_D (obj): Discriminator network object.
          - opt_G (obj): Generator optimizer object.
          - opt_D (obj): Discriminator optimizer object.
    """
    # Apply the model average wrapper.
    if cfg.trainer.model_average_config.enabled:
        if hasattr(cfg.trainer.model_average_config, 'g_smooth_img'):
            # g_smooth_img specifies the half-life (in images) of the running
            # average of the generator weights.
            cfg.trainer.model_average_config.beta = \
                0.5 ** (cfg.data.train.batch_size * get_world_size() /
                        cfg.trainer.model_average_config.g_smooth_img)
            print(f"EMA Decay Factor: {cfg.trainer.model_average_config.beta}")
        net_G = ModelAverage(net_G, cfg.trainer.model_average_config.beta,
                             cfg.trainer.model_average_config.start_iteration,
                             cfg.trainer.model_average_config.remove_sn)
    if cfg.trainer.model_average_config.enabled:
        net_G_module = net_G.module
    else:
        net_G_module = net_G
    if hasattr(net_G_module, 'custom_init'):
        net_G_module.custom_init()
    net_G = _wrap_model(cfg, net_G)
    net_D = _wrap_model(cfg, net_D)
    return net_G, net_D, opt_G, opt_D
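
# Worked example of the EMA decay (illustrative numbers): with a per-GPU batch
# size of 4, world size 8, and g_smooth_img=10000, the decay factor is
#   beta = 0.5 ** (4 * 8 / 10000) ≈ 0.99778,
# i.e. the running average of the generator weights has a half-life of about
# 10k images regardless of how many GPUs contribute to each step.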


def _calculate_model_size(model):
    r"""Calculate the number of trainable parameters in a PyTorch network.

    Args:
        model (obj): PyTorch network.

    Returns:
        (int): Number of trainable parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
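
# Sanity-check example (illustrative): a torch.nn.Linear(10, 5) layer has
# 10 * 5 weights + 5 biases = 55 trainable parameters, so
#   _calculate_model_size(nn.Linear(10, 5))  # -> 55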


class WrappedModel(nn.Module):
    r"""Identity wrapper that exposes the underlying network as ``.module``,
    mirroring the interface of ``DistributedDataParallel``.
    """

    def __init__(self, module):
        super(WrappedModel, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        r"""PyTorch module forward function overload."""
        return self.module(*args, **kwargs)


def _wrap_model(cfg, model):
    r"""Wrap a model for distributed data parallel training.

    Args:
        cfg (obj): Global configuration.
        model (obj): PyTorch network model.

    Returns:
        (obj): Wrapped PyTorch network model.
    """
    if torch.distributed.is_available() and dist.is_initialized():
        # ddp = cfg.trainer.distributed_data_parallel
        find_unused_parameters = \
            cfg.trainer.distributed_data_parallel_params.find_unused_parameters
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[cfg.local_rank],
            output_device=cfg.local_rank,
            find_unused_parameters=find_unused_parameters,
            broadcast_buffers=False
        )
        # if ddp == 'pytorch':
        #     return torch.nn.parallel.DistributedDataParallel(
        #         model,
        #         device_ids=[cfg.local_rank],
        #         output_device=cfg.local_rank,
        #         find_unused_parameters=find_unused_parameters,
        #         broadcast_buffers=False)
        # else:
        #     delay_allreduce = cfg.trainer.delay_allreduce
        #     return apex.parallel.DistributedDataParallel(
        #         model, delay_allreduce=delay_allreduce)
    else:
        return WrappedModel(model)
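
# Behavior note with a usage sketch (illustrative): when torch.distributed is
# not initialized (e.g. single-GPU debugging), the model is wrapped in
# WrappedModel instead of DistributedDataParallel, so downstream code can
# always reach the raw network via `.module`:
#   wrapped = _wrap_model(cfg, net_G)
#   raw_net_G = wrapped.module  # works for both DDP and WrappedModel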


def get_scheduler(cfg_opt, opt):
    r"""Return the scheduler object.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        opt (obj): PyTorch optimizer object.

    Returns:
        (obj): Scheduler.
    """
    if cfg_opt.lr_policy.type == 'step':
        scheduler = lr_scheduler.StepLR(
            opt,
            step_size=cfg_opt.lr_policy.step_size,
            gamma=cfg_opt.lr_policy.gamma)
    elif cfg_opt.lr_policy.type == 'constant':
        scheduler = lr_scheduler.LambdaLR(opt, lambda x: 1)
    elif cfg_opt.lr_policy.type == 'linear':
        # Start linear decay here.
        decay_start = cfg_opt.lr_policy.decay_start
        # End linear decay here; training then continues at the lowest
        # learning rate until the end.
        decay_end = cfg_opt.lr_policy.decay_end
        # Lowest learning rate multiplier.
        decay_target = cfg_opt.lr_policy.decay_target

        def sch(x):
            return min(
                max(((x - decay_start) * decay_target + decay_end - x) /
                    (decay_end - decay_start),
                    decay_target), 1.
            )

        scheduler = lr_scheduler.LambdaLR(opt, lambda x: sch(x))
    else:
        raise NotImplementedError('Learning rate policy {} not implemented.'.
                                  format(cfg_opt.lr_policy.type))
    return scheduler
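
# Worked example of the 'linear' policy (illustrative numbers): with
# decay_start=100, decay_end=200 and decay_target=0.1, the multiplier sch(x)
# stays at 1.0 until epoch 100, falls linearly to 0.1 at epoch 200, and is
# clamped at 0.1 afterwards:
#   sch(50)  -> 1.0
#   sch(150) -> ((150 - 100) * 0.1 + 200 - 150) / 100 = 0.55
#   sch(250) -> 0.1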


def get_optimizer(cfg_opt, net):
    r"""Return the optimizer object.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        net (obj): PyTorch network object.

    Returns:
        (obj): PyTorch optimizer.
    """
    if hasattr(net, 'get_param_groups'):
        # Allow the network to use different hyper-parameters (e.g., learning
        # rate) for different parameters.
        params = net.get_param_groups(cfg_opt)
    else:
        params = net.parameters()
    return get_optimizer_for_params(cfg_opt, params)
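
# Sketch of the get_param_groups hook (illustrative; the method name is the
# only contract this module relies on, and MyGenerator/backbone/head below
# are hypothetical): a network can expose per-group hyper-parameters, e.g.
#   class MyGenerator(nn.Module):
#       def get_param_groups(self, cfg_opt):
#           return [{'params': self.backbone.parameters()},
#                   {'params': self.head.parameters(), 'lr': cfg_opt.lr * 0.1}]
# The returned list is passed directly to the optimizer constructor.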


def get_optimizer_for_params(cfg_opt, params):
    r"""Return the optimizer object.

    Args:
        cfg_opt (obj): Config for the specific optimization module (gen/dis).
        params (obj): Parameters to be optimized.

    Returns:
        (obj): Optimizer.
    """
    # Use fused optimizers when requested and apex is available.
    fused_opt = cfg_opt.fused_opt
    try:
        from apex.optimizers import FusedAdam
    except ImportError:
        fused_opt = False
    if cfg_opt.type == 'adam':
        if fused_opt:
            opt = FusedAdam(params,
                            lr=cfg_opt.lr, eps=cfg_opt.eps,
                            betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2))
        else:
            opt = Adam(params,
                       lr=cfg_opt.lr, eps=cfg_opt.eps,
                       betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2))
    elif cfg_opt.type == 'madam':
        g_bound = getattr(cfg_opt, 'g_bound', None)
        opt = Madam(params, lr=cfg_opt.lr,
                    scale=cfg_opt.scale, g_bound=g_bound)
    elif cfg_opt.type == 'fromage':
        opt = Fromage(params, lr=cfg_opt.lr)
    elif cfg_opt.type == 'rmsprop':
        opt = RMSprop(params, lr=cfg_opt.lr,
                      eps=cfg_opt.eps, weight_decay=cfg_opt.weight_decay)
    elif cfg_opt.type == 'sgd':
        if fused_opt:
            from apex.optimizers import FusedSGD
            opt = FusedSGD(params,
                           lr=cfg_opt.lr,
                           momentum=cfg_opt.momentum,
                           weight_decay=cfg_opt.weight_decay)
        else:
            opt = SGD(params,
                      lr=cfg_opt.lr,
                      momentum=cfg_opt.momentum,
                      weight_decay=cfg_opt.weight_decay)
    else:
        raise NotImplementedError(
            'Optimizer {} is not yet implemented.'.format(cfg_opt.type))
    return opt
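
# Minimal config sketch (illustrative values; the attribute names mirror the
# ones read above, but the YAML layout is an assumption): an 'adam' entry
# needs at least
#   gen_opt:
#     type: adam
#     lr: 0.0001
#     eps: 1.e-8
#     adam_beta1: 0.
#     adam_beta2: 0.999
#     fused_opt: False
#     lr_policy:
#       type: constant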