# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import json
import os
import time

import torch
import torchvision
import wandb
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from imaginaire.utils.distributed import is_master, master_only
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.io import save_pilimage_in_jpeg
from imaginaire.utils.meters import Meter
from imaginaire.utils.misc import to_cuda, to_device, requires_grad, to_channels_last
from imaginaire.utils.model_average import (calibrate_batch_norm_momentum,
                                            reset_batch_norm)
from imaginaire.utils.visualization import tensor2pilimage


class BaseTrainer(object):
    r"""Base trainer. We expect that all trainers inherit this class.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self,
                 cfg,
                 net_G,
                 net_D,
                 opt_G,
                 opt_D,
                 sch_G,
                 sch_D,
                 train_data_loader,
                 val_data_loader):
        super(BaseTrainer, self).__init__()
        print('Setup trainer.')

        # Initialize models and data loaders.
        self.cfg = cfg
        self.net_G = net_G
        if cfg.trainer.model_average_config.enabled:
            # Two wrappers (DDP + model average).
            self.net_G_module = self.net_G.module.module
        else:
            # One wrapper (DDP).
            self.net_G_module = self.net_G.module
        self.val_data_loader = val_data_loader
        self.is_inference = train_data_loader is None
        self.net_D = net_D
        self.opt_G = opt_G
        self.opt_D = opt_D
        self.sch_G = sch_G
        self.sch_D = sch_D
        self.train_data_loader = train_data_loader
        if self.cfg.trainer.channels_last:
            self.net_G = self.net_G.to(memory_format=torch.channels_last)
            self.net_D = self.net_D.to(memory_format=torch.channels_last)

        # Initialize amp.
        if self.cfg.trainer.amp_config.enabled:
            print("Using automatic mixed precision training.")
        self.scaler_G = GradScaler(**vars(self.cfg.trainer.amp_config))
        self.scaler_D = GradScaler(**vars(self.cfg.trainer.amp_config))
        # In order to check whether the discriminator/generator has
        # skipped the last parameter update due to gradient overflow.
        self.last_step_count_G = 0
        self.last_step_count_D = 0
        self.skipped_G = False
        self.skipped_D = False

        # Initialize data augmentation policy.
        self.aug_policy = cfg.trainer.aug_policy
        print("Augmentation policy: {}".format(self.aug_policy))

        # Initialize loss functions.
        # All loss names have weights. Some have criterion modules.
        # Mapping from loss names to criterion modules.
        self.criteria = torch.nn.ModuleDict()
        # Mapping from loss names to loss weights.
        self.weights = dict()
        self.losses = dict(gen_update=dict(), dis_update=dict())
        self.gen_losses = self.losses['gen_update']
        self.dis_losses = self.losses['dis_update']
        self._init_loss(cfg)
        for loss_name, loss_weight in self.weights.items():
            print("Loss {:<20} Weight {}".format(loss_name, loss_weight))
            if loss_name in self.criteria.keys() and \
                    self.criteria[loss_name] is not None:
                self.criteria[loss_name].to('cuda')

        if self.is_inference:
            # The initialization steps below can be skipped during inference.
            return

        # Initialize logging attributes.
        self.current_iteration = 0
        self.current_epoch = 0
        self.start_iteration_time = None
        self.start_epoch_time = None
        self.elapsed_iteration_time = 0
        self.time_iteration = None
        self.time_epoch = None
        self.best_fid = None
        if self.cfg.speed_benchmark:
            self.accu_gen_forw_iter_time = 0
            self.accu_gen_loss_iter_time = 0
            self.accu_gen_back_iter_time = 0
            self.accu_gen_step_iter_time = 0
            self.accu_gen_avg_iter_time = 0
            self.accu_dis_forw_iter_time = 0
            self.accu_dis_loss_iter_time = 0
            self.accu_dis_back_iter_time = 0
            self.accu_dis_step_iter_time = 0

        # Initialize tensorboard and hparams.
        self._init_tensorboard()
        self._init_hparams()

        # Initialize validation parameters.
        self.val_sample_size = getattr(cfg.trainer, 'val_sample_size', 50000)
        self.kid_num_subsets = getattr(cfg.trainer, 'kid_num_subsets', 10)
        self.kid_subset_size = self.val_sample_size // self.kid_num_subsets
        self.metrics_path = os.path.join(torch.hub.get_dir(), 'metrics')
        self.best_metrics = {}
        self.eval_networks = getattr(cfg.trainer, 'eval_network', ['clean_inception'])
        if self.cfg.metrics_iter is None:
            self.cfg.metrics_iter = self.cfg.snapshot_save_iter
        if self.cfg.metrics_epoch is None:
            self.cfg.metrics_epoch = self.cfg.snapshot_save_epoch

        # AWS credentials.
        if hasattr(cfg, 'aws_credentials_file'):
            with open(cfg.aws_credentials_file) as fin:
                self.credentials = json.load(fin)
        else:
            self.credentials = None

        if 'TORCH_HOME' not in os.environ:
            os.environ['TORCH_HOME'] = os.path.join(
                os.environ['HOME'], ".cache")

    def _init_tensorboard(self):
        r"""Initialize the tensorboard. Different algorithms might require
        different performance metrics. Hence, custom tensorboard
        initialization might be necessary.
        """
        # Logging frequency: self.cfg.logging_iter
        self.meters = {}

        # Logging frequency: self.cfg.snapshot_save_iter
        self.metric_meters = {}

        # Logging frequency: self.cfg.image_display_iter
        self.image_meter = Meter('images', reduce=False)

    def _init_hparams(self):
        r"""Initialize a dictionary of hyperparameters that we want to monitor
        in the HParams dashboard in TensorBoard.
        """
        self.hparam_dict = {}

    def _write_tensorboard(self):
        r"""Write values to tensorboard. By default, we will log the time used
        per iteration, time used per epoch, generator learning rate, and
        discriminator learning rate. We will log all the losses as well as
        custom meters.
        """
        # Logs that are shared by all models.
        self._write_to_meters({'time/iteration': self.time_iteration,
                               'time/epoch': self.time_epoch,
                               'optim/gen_lr': self.sch_G.get_last_lr()[0],
                               'optim/dis_lr': self.sch_D.get_last_lr()[0]},
                              self.meters,
                              reduce=False)
        # Logs for loss values. Different models have different losses.
        self._write_loss_meters()
        # Other custom logs.
        self._write_custom_meters()

    def _write_loss_meters(self):
        r"""Write all loss values to tensorboard."""
        for update, losses in self.losses.items():
            # update is 'gen_update' or 'dis_update'.
            assert update == 'gen_update' or update == 'dis_update'
            for loss_name, loss in losses.items():
                if loss is not None:
                    full_loss_name = update + '/' + loss_name
                    if full_loss_name not in self.meters.keys():
                        # Create a new meter if it doesn't exist.
                        self.meters[full_loss_name] = Meter(
                            full_loss_name, reduce=True)
                    self.meters[full_loss_name].write(loss.item())

    def _write_custom_meters(self):
        r"""Dummy member function to be overloaded by the child class.
        In the child class, you can write down whatever you want to track.
        """
        pass

    def _write_to_meters(self, data, meters, reduce=True):
        r"""Write values to meters."""
        if reduce or is_master():
            for key, value in data.items():
                if key not in meters:
                    meters[key] = Meter(key, reduce=reduce)
                meters[key].write(value)

    def _flush_meters(self, meters):
        r"""Flush all meters using the current iteration."""
        for meter in meters.values():
            meter.flush(self.current_iteration)

    def _pre_save_checkpoint(self):
        r"""Implement the things you want to do before saving a checkpoint.
        For example, you can compute the K-mean features (pix2pixHD) before
        saving the model weights to a checkpoint.
        """
        pass

    def save_checkpoint(self, current_epoch, current_iteration):
        r"""Save network weights, optimizer parameters, scheduler parameters
        to a checkpoint.
        """
        self._pre_save_checkpoint()
        _save_checkpoint(self.cfg,
                         self.net_G, self.net_D,
                         self.opt_G, self.opt_D,
                         self.sch_G, self.sch_D,
                         current_epoch, current_iteration)

    def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True):
        r"""Load network weights, optimizer parameters, scheduler parameters
        from a checkpoint.

        Args:
            cfg (obj): Global configuration.
            checkpoint_path (str): Path to the checkpoint.
            resume (bool or None): If not ``None``, will determine whether or
                not to load optimizers in addition to network weights.
            load_sch (bool): Whether to also load the scheduler states when
                resuming.
        """
        if os.path.exists(checkpoint_path):
            # If checkpoint_path exists, we will load its weights to
            # initialize our network.
            if resume is None:
                resume = False
        elif os.path.exists(os.path.join(cfg.logdir, 'latest_checkpoint.txt')):
            # This is for resuming the training from the previously saved
            # checkpoint.
            fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
            with open(fn, 'r') as f:
                line = f.read().splitlines()
            checkpoint_path = os.path.join(cfg.logdir, line[0].split(' ')[-1])
            if resume is None:
                resume = True
        else:
            # Checkpoint not found and not specified. We will train
            # everything from scratch.
            current_epoch = 0
            current_iteration = 0
            print('No checkpoint found.')
            resume = False
            return resume, current_epoch, current_iteration
        # Load checkpoint.
        checkpoint = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage)
        current_epoch = 0
        current_iteration = 0
        if resume:
            self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume)
            if not self.is_inference:
                self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume)
                if 'opt_G' in checkpoint:
                    current_epoch = checkpoint['current_epoch']
                    current_iteration = checkpoint['current_iteration']
                    self.opt_G.load_state_dict(checkpoint['opt_G'])
                    self.opt_D.load_state_dict(checkpoint['opt_D'])
                    if load_sch:
                        self.sch_G.load_state_dict(checkpoint['sch_G'])
                        self.sch_D.load_state_dict(checkpoint['sch_D'])
                    else:
                        if self.cfg.gen_opt.lr_policy.iteration_mode:
                            self.sch_G.last_epoch = current_iteration
                        else:
                            self.sch_G.last_epoch = current_epoch
                        if self.cfg.dis_opt.lr_policy.iteration_mode:
                            self.sch_D.last_epoch = current_iteration
                        else:
                            self.sch_D.last_epoch = current_epoch
                    print('Load from: {}'.format(checkpoint_path))
                else:
                    print('Load network weights only.')
        else:
            try:
                self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume)
                if 'net_D' in checkpoint:
                    self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume)
            except Exception:
                if self.cfg.trainer.model_average_config.enabled:
                    net_G_module = self.net_G.module.module
                else:
                    net_G_module = self.net_G.module
                if hasattr(net_G_module, 'load_pretrained_network'):
                    net_G_module.load_pretrained_network(self.net_G, checkpoint['net_G'])
                    print('Load generator weights only.')
                else:
                    raise ValueError('Checkpoint cannot be loaded.')

        print('Done with loading the checkpoint.')
        return resume, current_epoch, current_iteration

    def start_of_epoch(self, current_epoch):
        r"""Things to do before an epoch.

        Args:
            current_epoch (int): Current number of epoch.
        """
        self._start_of_epoch(current_epoch)
        self.current_epoch = current_epoch
        self.start_epoch_time = time.time()

    def start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current number of iteration.
        """
        data = self._start_of_iteration(data, current_iteration)
        data = to_cuda(data)
        if self.cfg.trainer.channels_last:
            data = to_channels_last(data)
        self.current_iteration = current_iteration
        if not self.is_inference:
            self.net_D.train()
        self.net_G.train()
        # torch.cuda.synchronize()
        self.start_iteration_time = time.time()
        return data

    def end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Things to do after an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current number of iteration.
        """
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        # Update the learning rate policy for the generator if operating in
        # the iteration mode.
        if self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        # Update the learning rate policy for the discriminator if operating
        # in the iteration mode.
        if self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()

        # Accumulate time.
        # torch.cuda.synchronize()
        self.elapsed_iteration_time += time.time() - self.start_iteration_time
        # Logging.
        if current_iteration % self.cfg.logging_iter == 0:
            ave_t = self.elapsed_iteration_time / self.cfg.logging_iter
            self.time_iteration = ave_t
            print('Iteration: {}, average iter time: '
                  '{:6f}.'.format(current_iteration, ave_t))
            self.elapsed_iteration_time = 0

            if self.cfg.speed_benchmark:
                # Below code block only needed when analyzing computation
                # bottleneck.
                print('\tGenerator FWD time {:6f}'.format(
                    self.accu_gen_forw_iter_time / self.cfg.logging_iter))
                print('\tGenerator LOS time {:6f}'.format(
                    self.accu_gen_loss_iter_time / self.cfg.logging_iter))
                print('\tGenerator BCK time {:6f}'.format(
                    self.accu_gen_back_iter_time / self.cfg.logging_iter))
                print('\tGenerator STP time {:6f}'.format(
                    self.accu_gen_step_iter_time / self.cfg.logging_iter))
                print('\tGenerator AVG time {:6f}'.format(
                    self.accu_gen_avg_iter_time / self.cfg.logging_iter))
                print('\tDiscriminator FWD time {:6f}'.format(
                    self.accu_dis_forw_iter_time / self.cfg.logging_iter))
                print('\tDiscriminator LOS time {:6f}'.format(
                    self.accu_dis_loss_iter_time / self.cfg.logging_iter))
                print('\tDiscriminator BCK time {:6f}'.format(
                    self.accu_dis_back_iter_time / self.cfg.logging_iter))
                print('\tDiscriminator STP time {:6f}'.format(
                    self.accu_dis_step_iter_time / self.cfg.logging_iter))
                print('{:6f}'.format(ave_t))

                self.accu_gen_forw_iter_time = 0
                self.accu_gen_loss_iter_time = 0
                self.accu_gen_back_iter_time = 0
                self.accu_gen_step_iter_time = 0
                self.accu_gen_avg_iter_time = 0
                self.accu_dis_forw_iter_time = 0
                self.accu_dis_loss_iter_time = 0
                self.accu_dis_back_iter_time = 0
                self.accu_dis_step_iter_time = 0

        self._end_of_iteration(data, current_epoch, current_iteration)

        # Save everything to the checkpoint.
        if current_iteration % self.cfg.snapshot_save_iter == 0:
            if current_iteration >= self.cfg.snapshot_save_start_iter:
                self.save_checkpoint(current_epoch, current_iteration)

        # Compute metrics.
        if current_iteration % self.cfg.metrics_iter == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.write_metrics()
        # Compute image to be saved.
        elif current_iteration % self.cfg.image_save_iter == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
        elif current_iteration % self.cfg.image_display_iter == 0:
            image_path = os.path.join(self.cfg.logdir, 'images', 'current.jpg')
            self.save_image(image_path, data)

        # Logging.
        self._write_tensorboard()
        if current_iteration % self.cfg.logging_iter == 0:
            # Write all logs to tensorboard.
            self._flush_meters(self.meters)

        from torch.distributed import barrier
        import torch.distributed as dist
        if dist.is_initialized():
            barrier()

    def end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Things to do after an epoch.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current number of iteration.
        """
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        # Update the learning rate policy for the generator if operating in
        # the epoch mode.
        if not self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        # Update the learning rate policy for the discriminator if operating
        # in the epoch mode.
        if not self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()
        elapsed_epoch_time = time.time() - self.start_epoch_time
        # Logging.
        print('Epoch: {}, total time: {:6f}.'.format(current_epoch,
                                                     elapsed_epoch_time))
        self.time_epoch = elapsed_epoch_time
        self._end_of_epoch(data, current_epoch, current_iteration)

        # Save everything to the checkpoint.
        if current_epoch % self.cfg.snapshot_save_epoch == 0:
            if current_epoch >= self.cfg.snapshot_save_start_epoch:
                self.save_checkpoint(current_epoch, current_iteration)

        # Compute metrics.
        if current_epoch % self.cfg.metrics_epoch == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.write_metrics()

    def pre_process(self, data):
        r"""Custom data pre-processing function. Utilize this function if you
        need to preprocess your data before sending it to the generator and
        discriminator.

        Args:
            data (dict): Data used for the current iteration.
        """
        # By default, return the data unchanged.
        return data

    def recalculate_batch_norm_statistics(self, data_loader, averaged=True):
        r"""Update the statistics in the moving average model.

        Args:
            data_loader (torch.utils.data.DataLoader): Data loader for
                estimating the statistics.
            averaged (bool): If ``True``, update the batch norm statistics of
                the EMA (averaged) generator; otherwise, update those of the
                regular generator.
        """
        if not self.cfg.trainer.model_average_config.enabled:
            return
        if averaged:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G_module
        model_average_iteration = \
            self.cfg.trainer.model_average_config.num_batch_norm_estimation_iterations
        if model_average_iteration == 0:
            return
        with torch.no_grad():
            # Accumulate bn stats.
            net_G.train()
            # Reset running stats.
            net_G.apply(reset_batch_norm)
            for cal_it, cal_data in enumerate(data_loader):
                if cal_it >= model_average_iteration:
                    print('Done with {} iterations of updating batch norm '
                          'statistics'.format(model_average_iteration))
                    break
                cal_data = to_device(cal_data, 'cuda')
                cal_data = self.pre_process(cal_data)
                # Averaging over all batches.
                net_G.apply(calibrate_batch_norm_momentum)
                net_G(cal_data)

    def save_image(self, path, data):
        r"""Compute visualization images and save them to the disk.

        Args:
            path (str): Location of the file.
            data (dict): Data used for the current iteration.
        """
        self.net_G.eval()
        vis_images = self._get_visualizations(data)
        if is_master() and vis_images is not None:
            vis_images = torch.cat(
                [img for img in vis_images if img is not None], dim=3).float()
            vis_images = (vis_images + 1) / 2
            print('Save output images to {}'.format(path))
            vis_images.clamp_(0, 1)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            image_grid = torchvision.utils.make_grid(
                vis_images, nrow=1, padding=0, normalize=False)
            if self.cfg.trainer.image_to_tensorboard:
                self.image_meter.write_image(image_grid, self.current_iteration)
            torchvision.utils.save_image(image_grid, path, nrow=1)
            wandb.log({os.path.splitext(os.path.basename(path))[0]: [wandb.Image(path)]})

    def write_metrics(self):
        r"""Write metrics to the tensorboard."""
        cur_fid = self._compute_fid()
        if cur_fid is not None:
            if self.best_fid is not None:
                self.best_fid = min(self.best_fid, cur_fid)
            else:
                self.best_fid = cur_fid
            metric_dict = {'FID': cur_fid, 'best_FID': self.best_fid}
            self._write_to_meters(metric_dict, self.metric_meters, reduce=False)
            self._flush_meters(self.metric_meters)

    def _get_save_path(self, subdir, ext):
        r"""Get the image save path.

        Args:
            subdir (str): Sub-directory under the main directory for saving
                the outputs.
            ext (str): Filename extension for the image (e.g., jpg, png, ...).
        Return:
            (str): image filename to be used to save the visualization results.
        """
        subdir_path = os.path.join(self.cfg.logdir, subdir)
        if not os.path.exists(subdir_path):
            os.makedirs(subdir_path, exist_ok=True)
        return os.path.join(
            subdir_path, 'epoch_{:05}_iteration_{:09}.{}'.format(
                self.current_epoch, self.current_iteration, ext))

    def _get_outputs(self, net_D_output, real=True):
        r"""Return output values. Note that when the GAN mode is relativistic,
        the difference between the real and fake outputs is computed before
        returning.

        Args:
            net_D_output (dict):
                real_outputs (tensor): Real output values.
                fake_outputs (tensor): Fake output values.
            real (bool): Return real or fake.
        """

        def _get_difference(a, b):
            r"""Get difference between two lists of tensors or two tensors.

            Args:
                a: list of tensors or tensor
                b: list of tensors or tensor
            """
            out = list()
            for x, y in zip(a, b):
                if isinstance(x, list):
                    res = _get_difference(x, y)
                else:
                    res = x - y
                out.append(res)
            return out

        if real:
            if self.cfg.trainer.gan_relativistic:
                return _get_difference(net_D_output['real_outputs'], net_D_output['fake_outputs'])
            else:
                return net_D_output['real_outputs']
        else:
            if self.cfg.trainer.gan_relativistic:
                return _get_difference(net_D_output['fake_outputs'], net_D_output['real_outputs'])
            else:
                return net_D_output['fake_outputs']

    def _start_of_epoch(self, current_epoch):
        r"""Operations to do before starting an epoch.

        Args:
            current_epoch (int): Current number of epoch.
        """
        pass

    def _start_of_iteration(self, data, current_iteration):
        r"""Operations to do before starting an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.
        Returns:
            (dict): Data used for the current iteration. They might be
                processed by the custom _start_of_iteration function.
        """
        return data

    def _end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Operations to do after an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current iteration number.
        """
        pass

    def _end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Operations to do after an epoch.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current iteration number.
        """
        pass

    def _get_visualizations(self, data):
        r"""Compute visualization outputs.

        Args:
            data (dict): Data used for the current iteration.
        """
        return None

    def _compute_fid(self):
        r"""FID computation function to be overloaded."""
        return None

    def _init_loss(self, cfg):
        r"""Every trainer should implement its own init loss function."""
        raise NotImplementedError

    def gen_update(self, data):
        r"""Update the generator.

        Args:
            data (dict): Data used for the current iteration.
        """
        update_finished = False
        while not update_finished:
            # Set requires_grad flags.
            requires_grad(self.net_G_module, True)
            requires_grad(self.net_D, False)

            # Compute the loss.
            self._time_before_forward()
            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
                total_loss = self.gen_forward(data)
            if total_loss is None:
                return

            # Zero-grad and backpropagate the loss.
            self.opt_G.zero_grad(set_to_none=True)
            self._time_before_backward()
            self.scaler_G.scale(total_loss).backward()

            # Optionally clip gradient norm.
            if hasattr(self.cfg.gen_opt, 'clip_grad_norm'):
                self.scaler_G.unscale_(self.opt_G)
                total_norm = torch.nn.utils.clip_grad_norm_(
                    self.net_G_module.parameters(),
                    self.cfg.gen_opt.clip_grad_norm
                )
                self.gen_grad_norm = total_norm
                if torch.isfinite(total_norm) and \
                        total_norm > self.cfg.gen_opt.clip_grad_norm:
                    # print(f"Gradient norm of the generator ({total_norm}) "
                    #       f"too large.")
                    if getattr(self.cfg.gen_opt, 'skip_grad', False):
                        print(f"Skip gradient update.")
                        self.opt_G.zero_grad(set_to_none=True)
                        self.scaler_G.step(self.opt_G)
                        self.scaler_G.update()
                        break
                    # else:
                    #     print(f"Clip gradient norm to "
                    #           f"{self.cfg.gen_opt.clip_grad_norm}.")

            # Perform an optimizer step.
            self._time_before_step()
            self.scaler_G.step(self.opt_G)
            self.scaler_G.update()
            # Whether the step above was skipped.
            if self.last_step_count_G == self.opt_G._step_count:
                print("Generator overflowed!")
                if not torch.isfinite(total_loss):
                    print("Generator loss is not finite. Skip this iteration!")
                    update_finished = True
            else:
                self.last_step_count_G = self.opt_G._step_count
                update_finished = True

        self._extra_gen_step(data)

        # Update model average.
        self._time_before_model_avg()
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G.module.update_average()

        self._detach_losses()
        self._time_before_leave_gen()

    def gen_forward(self, data):
        r"""Every trainer should implement its own generator forward."""
        raise NotImplementedError

    def _extra_gen_step(self, data):
        pass

    def dis_update(self, data):
        r"""Update the discriminator.

        Args:
            data (dict): Data used for the current iteration.
        """
        update_finished = False
        while not update_finished:
            # Set requires_grad flags.
            requires_grad(self.net_G_module, False)
            requires_grad(self.net_D, True)

            # Compute the loss.
            self._time_before_forward()
            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
                total_loss = self.dis_forward(data)
            if total_loss is None:
                return

            # Zero-grad and backpropagate the loss.
            self.opt_D.zero_grad(set_to_none=True)
            self._time_before_backward()
            self.scaler_D.scale(total_loss).backward()

            # Optionally clip gradient norm.
            if hasattr(self.cfg.dis_opt, 'clip_grad_norm'):
                self.scaler_D.unscale_(self.opt_D)
                total_norm = torch.nn.utils.clip_grad_norm_(
                    self.net_D.parameters(), self.cfg.dis_opt.clip_grad_norm
                )
                self.dis_grad_norm = total_norm
                if torch.isfinite(total_norm) and \
                        total_norm > self.cfg.dis_opt.clip_grad_norm:
                    print(f"Gradient norm of the discriminator ({total_norm}) "
                          f"too large.")
                    if getattr(self.cfg.dis_opt, 'skip_grad', False):
                        print(f"Skip gradient update.")
                        self.opt_D.zero_grad(set_to_none=True)
                        self.scaler_D.step(self.opt_D)
                        self.scaler_D.update()
                        continue
                    else:
                        print(f"Clip gradient norm to "
                              f"{self.cfg.dis_opt.clip_grad_norm}.")

            # Perform an optimizer step.
            self._time_before_step()
            self.scaler_D.step(self.opt_D)
            self.scaler_D.update()
            # Whether the step above was skipped.
            if self.last_step_count_D == self.opt_D._step_count:
                print("Discriminator overflowed!")
                if not torch.isfinite(total_loss):
                    print("Discriminator loss is not finite. "
                          "Skip this iteration!")
                    update_finished = True
            else:
                self.last_step_count_D = self.opt_D._step_count
                update_finished = True

        self._extra_dis_step(data)

        self._detach_losses()
        self._time_before_leave_dis()

    def dis_forward(self, data):
        r"""Every trainer should implement its own discriminator forward."""
        raise NotImplementedError

    def _extra_dis_step(self, data):
        pass

    def test(self, data_loader, output_dir, inference_args):
        r"""Compute the result images for a batch of input data and save the
        results in the specified folder.

        Args:
            data_loader (torch.utils.data.DataLoader): PyTorch dataloader.
            output_dir (str): Target location for saving the output images.
            inference_args (obj): Arguments forwarded to the generator's
                inference function.
        """
        if self.cfg.trainer.model_average_config.enabled:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G.module
        net_G.eval()

        print('# of samples %d' % len(data_loader))
        for it, data in enumerate(tqdm(data_loader)):
            data = self.start_of_iteration(data, current_iteration=-1)
            with torch.no_grad():
                output_images, file_names = \
                    net_G.inference(data, **vars(inference_args))
            for output_image, file_name in zip(output_images, file_names):
                fullname = os.path.join(output_dir, file_name + '.jpg')
                output_image = tensor2pilimage(output_image.clamp_(-1, 1),
                                               minus1to1_normalized=True)
                save_pilimage_in_jpeg(fullname, output_image)

    def _get_total_loss(self, gen_forward):
        r"""Return the total loss to be backpropagated.

        Args:
            gen_forward (bool): If ``True``, backpropagates the generator loss,
                otherwise the discriminator loss.
        """
        losses = self.gen_losses if gen_forward else self.dis_losses
        total_loss = torch.tensor(0., device=torch.device('cuda'))
        # Iterates over all possible losses.
        for loss_name in self.weights:
            # If it is for the current model (gen/dis).
            if loss_name in losses:
                # Multiply it with the corresponding weight
                # and add it to the total loss.
                total_loss += losses[loss_name] * self.weights[loss_name]
        losses['total'] = total_loss  # logging purpose
        return total_loss

    def _detach_losses(self):
        r"""Detach all logging variables to prevent potential memory leak."""
        for loss_name in self.gen_losses:
            self.gen_losses[loss_name] = self.gen_losses[loss_name].detach()
        for loss_name in self.dis_losses:
            self.dis_losses[loss_name] = self.dis_losses[loss_name].detach()

    def _time_before_forward(self):
        r"""Record time before applying forward."""
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.forw_time = time.time()

    def _time_before_loss(self):
        r"""Record time before computing loss."""
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.loss_time = time.time()

    def _time_before_backward(self):
        r"""Record time before applying backward."""
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.back_time = time.time()

    def _time_before_step(self):
        r"""Record time before updating the weights."""
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.step_time = time.time()

    def _time_before_model_avg(self):
        r"""Record time before applying model average."""
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.avg_time = time.time()

    def _time_before_leave_gen(self):
        r"""Record forward, backward, loss, and model average time for the
        generator update.
        """
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            end_time = time.time()
            self.accu_gen_forw_iter_time += self.loss_time - self.forw_time
            self.accu_gen_loss_iter_time += self.back_time - self.loss_time
            self.accu_gen_back_iter_time += self.step_time - self.back_time
            self.accu_gen_step_iter_time += self.avg_time - self.step_time
            self.accu_gen_avg_iter_time += end_time - self.avg_time

    def _time_before_leave_dis(self):
        r"""Record forward, backward, loss time for the discriminator update."""
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            end_time = time.time()
            self.accu_dis_forw_iter_time += self.loss_time - self.forw_time
            self.accu_dis_loss_iter_time += self.back_time - self.loss_time
            self.accu_dis_back_iter_time += self.step_time - self.back_time
            self.accu_dis_step_iter_time += end_time - self.step_time
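

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original imaginaire file): a minimal
# subclass showing the three hooks every concrete trainer must provide
# (_init_loss, gen_forward, dis_forward). The data/output dictionary keys
# ('images', 'fake_images'), the call signatures of net_G/net_D, and the
# cfg.trainer.loss_weight attributes below are assumptions made purely for
# illustration; real trainers in this repository define their own versions.
class _ExampleTrainer(BaseTrainer):
    def _init_loss(self, cfg):
        # Every loss name needs an entry in self.weights; criterion modules
        # in self.criteria are optional.
        self.criteria['L1'] = torch.nn.L1Loss()
        self.weights['L1'] = getattr(cfg.trainer.loss_weight, 'l1', 10.0)
        self.weights['GAN'] = getattr(cfg.trainer.loss_weight, 'gan', 1.0)

    def gen_forward(self, data):
        net_G_output = self.net_G(data)
        net_D_output = self.net_D(data, net_G_output)
        fake_scores = self._get_outputs(net_D_output, real=False)
        self.gen_losses['GAN'] = -fake_scores.mean()
        self.gen_losses['L1'] = self.criteria['L1'](
            net_G_output['fake_images'], data['images'])
        return self._get_total_loss(gen_forward=True)

    def dis_forward(self, data):
        with torch.no_grad():
            net_G_output = self.net_G(data)
        net_D_output = self.net_D(data, net_G_output)
        real_scores = self._get_outputs(net_D_output, real=True)
        fake_scores = self._get_outputs(net_D_output, real=False)
        # Hinge adversarial loss on the (assumed) single-tensor outputs.
        self.dis_losses['GAN'] = (
            torch.nn.functional.relu(1.0 - real_scores).mean() +
            torch.nn.functional.relu(1.0 + fake_scores).mean())
        return self._get_total_loss(gen_forward=False)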


def _save_checkpoint(cfg,
                     net_G, net_D,
                     opt_G, opt_D,
                     sch_G, sch_D,
                     current_epoch, current_iteration):
    r"""Save network weights, optimizer parameters, scheduler parameters
    in the checkpoint.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        current_epoch (int): Current epoch.
        current_iteration (int): Current iteration.
    """
    latest_checkpoint_path = 'epoch_{:05}_iteration_{:09}_checkpoint.pt'.format(
        current_epoch, current_iteration)
    save_path = os.path.join(cfg.logdir, latest_checkpoint_path)
    torch.save(
        {
            'net_G': net_G.state_dict(),
            'net_D': net_D.state_dict(),
            'opt_G': opt_G.state_dict(),
            'opt_D': opt_D.state_dict(),
            'sch_G': sch_G.state_dict(),
            'sch_D': sch_D.state_dict(),
            'current_epoch': current_epoch,
            'current_iteration': current_iteration,
        },
        save_path,
    )
    fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
    with open(fn, 'wt') as f:
        f.write('latest_checkpoint: %s' % latest_checkpoint_path)
    print('Save checkpoint to {}'.format(save_path))
    return save_path
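

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original imaginaire file): how an outer
# training script is expected to drive the hooks defined in BaseTrainer. The
# real driver lives in the repository's training script; the epoch/iteration
# bookkeeping below is an assumption kept deliberately minimal.
def _example_training_loop(trainer, train_data_loader, start_epoch,
                           start_iteration, max_epoch):
    current_iteration = start_iteration
    for current_epoch in range(start_epoch, max_epoch):
        trainer.start_of_epoch(current_epoch)
        for data in train_data_loader:
            data = trainer.start_of_iteration(data, current_iteration)
            # Discriminator update first, then generator update.
            trainer.dis_update(data)
            trainer.gen_update(data)
            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
        trainer.end_of_epoch(data, current_epoch, current_iteration)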