# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch import nn

from imaginaire.evaluation import compute_fid
from imaginaire.losses import GANLoss, PerceptualLoss  # GaussianKLLoss
from imaginaire.trainers.base import BaseTrainer


class Trainer(BaseTrainer):
    r"""Reimplementation of the UNIT (https://arxiv.org/abs/1703.00848)
    algorithm.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """
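
    # A minimal usage sketch, given for illustration only. The cfg, network,
    # optimizer, scheduler, data-loader objects and the `batch` dict are
    # assumed to have been built elsewhere (e.g. by imaginaire's training
    # setup code) and are not defined in this file.
    #
    #   trainer = Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
    #                     train_loader, val_loader)
    #   gen_loss = trainer.gen_forward(batch)   # generator-side total loss
    #   dis_loss = trainer.dis_forward(batch)   # discriminator-side total loss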

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                         train_data_loader, val_data_loader)
        self.best_fid_a = None
        self.best_fid_b = None

    def _init_loss(self, cfg):
        r"""Initialize loss terms. In UNIT, we have several loss terms,
        including the GAN loss, the image reconstruction loss, the cycle
        reconstruction loss, and the Gaussian KL loss. We also have an
        optional perceptual loss. A user can choose to add a gradient
        penalty loss as well.

        Args:
            cfg (obj): Global configuration.
        """
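        # The per-term weights read below live under cfg.trainer.loss_weight.
        # A sketch of what that block might look like in a YAML config; the
        # values are hypothetical and not taken from this repository:
        #
        #   trainer:
        #       loss_weight:
        #           gan: 1.0
        #           image_recon: 10.0
        #           cycle_recon: 10.0
        #           perceptual: 0.0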
        self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode)
        # self.criteria['gaussian_kl'] = GaussianKLLoss()
        self.criteria['image_recon'] = nn.L1Loss()
        self.criteria['cycle_recon'] = nn.L1Loss()
        if getattr(cfg.trainer.loss_weight, 'perceptual', 0) > 0:
            self.criteria['perceptual'] = \
                PerceptualLoss(network=cfg.trainer.perceptual_mode,
                               layers=cfg.trainer.perceptual_layers)
        for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items():
            if loss_weight > 0:
                self.weights[loss_name] = loss_weight

    def gen_forward(self, data):
        r"""Compute the loss for UNIT generator.

        Args:
            data (dict): Training data at the current iteration.
        """
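        # Naming convention for generator outputs, inferred from how they are
        # compared against the inputs below: 'images_ab' is domain A translated
        # to domain B, 'images_aa' is the within-domain reconstruction of A,
        # and 'images_aba' is the A -> B -> A cycle reconstruction (and
        # likewise with the roles of A and B swapped).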
        cycle_recon = 'cycle_recon' in self.weights
        perceptual = 'perceptual' in self.weights
        net_G_output = self.net_G(data, cycle_recon=cycle_recon)
        net_D_output = self.net_D(data, net_G_output, real=False)

        self._time_before_loss()

        # GAN loss
        self.gen_losses['gan_a'] = self.criteria['gan'](
            net_D_output['out_ba'], True, dis_update=False)
        self.gen_losses['gan_b'] = self.criteria['gan'](
            net_D_output['out_ab'], True, dis_update=False)
        self.gen_losses['gan'] = \
            self.gen_losses['gan_a'] + self.gen_losses['gan_b']

        # Perceptual loss
        if perceptual:
            self.gen_losses['perceptual_a'] = \
                self.criteria['perceptual'](net_G_output['images_ab'],
                                            data['images_a'])
            self.gen_losses['perceptual_b'] = \
                self.criteria['perceptual'](net_G_output['images_ba'],
                                            data['images_b'])
            self.gen_losses['perceptual'] = \
                self.gen_losses['perceptual_a'] + \
                self.gen_losses['perceptual_b']

        # Image reconstruction loss
        self.gen_losses['image_recon'] = \
            self.criteria['image_recon'](net_G_output['images_aa'],
                                         data['images_a']) + \
            self.criteria['image_recon'](net_G_output['images_bb'],
                                         data['images_b'])

        """
        # KL loss
        self.gen_losses['gaussian_kl'] = \
            self.criteria['gaussian_kl'](net_G_output['content_mu_a']) + \
            self.criteria['gaussian_kl'](net_G_output['content_mu_b']) + \
            self.criteria['gaussian_kl'](net_G_output['content_mu_a_recon']) + \
            self.criteria['gaussian_kl'](net_G_output['content_mu_b_recon'])
        """

        # Cycle reconstruction loss
        if cycle_recon:
            self.gen_losses['cycle_recon_aba'] = \
                self.criteria['cycle_recon'](net_G_output['images_aba'],
                                             data['images_a'])
            self.gen_losses['cycle_recon_bab'] = \
                self.criteria['cycle_recon'](net_G_output['images_bab'],
                                             data['images_b'])
            self.gen_losses['cycle_recon'] = \
                self.gen_losses['cycle_recon_aba'] + \
                self.gen_losses['cycle_recon_bab']

        # Compute total loss
        total_loss = self._get_total_loss(gen_forward=True)
        return total_loss

    def dis_forward(self, data):
        r"""Compute the loss for UNIT discriminator.

        Args:
            data (dict): Training data at the current iteration.
        """
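        # The generator is run without gradients and with the reconstruction
        # branches disabled, since only the discriminator is updated here.
        # requires_grad is then re-enabled on the translated images,
        # presumably so that gradient-penalty style regularizers can
        # differentiate the discriminator output with respect to its fake
        # inputs.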
        with torch.no_grad():
            net_G_output = self.net_G(data, image_recon=False,
                                      cycle_recon=False)
        net_G_output['images_ba'].requires_grad = True
        net_G_output['images_ab'].requires_grad = True
        net_D_output = self.net_D(data, net_G_output)

        self._time_before_loss()

        # GAN loss.
        self.dis_losses['gan_a'] = \
            self.criteria['gan'](net_D_output['out_a'], True) + \
            self.criteria['gan'](net_D_output['out_ba'], False)
        self.dis_losses['gan_b'] = \
            self.criteria['gan'](net_D_output['out_b'], True) + \
            self.criteria['gan'](net_D_output['out_ab'], False)
        self.dis_losses['gan'] = \
            self.dis_losses['gan_a'] + self.dis_losses['gan_b']

        # Compute total loss
        total_loss = self._get_total_loss(gen_forward=False)
        return total_loss

    def _get_visualizations(self, data):
        r"""Compute visualization image.

        Args:
            data (dict): The current batch.
        """
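        # Images are returned in a fixed order: real A, real B, the two
        # within-domain reconstructions (aa, bb), the two cross-domain
        # translations (ab, ba), and the two cycle reconstructions (aba, bab).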
        if self.cfg.trainer.model_average_config.enabled:
            net_G_for_evaluation = self.net_G.module.averaged_model
        else:
            net_G_for_evaluation = self.net_G
        with torch.no_grad():
            net_G_output = net_G_for_evaluation(data)
            vis_images = [data['images_a'],
                          data['images_b'],
                          net_G_output['images_aa'],
                          net_G_output['images_bb'],
                          net_G_output['images_ab'],
                          net_G_output['images_ba'],
                          net_G_output['images_aba'],
                          net_G_output['images_bab']]
            return vis_images

    def write_metrics(self):
        r"""Compute FID metrics and write them to TensorBoard."""
        cur_fid_a, cur_fid_b = self._compute_fid()
        if self.best_fid_a is not None:
            self.best_fid_a = min(self.best_fid_a, cur_fid_a)
        else:
            self.best_fid_a = cur_fid_a
        if self.best_fid_b is not None:
            self.best_fid_b = min(self.best_fid_b, cur_fid_b)
        else:
            self.best_fid_b = cur_fid_b
        self._write_to_meters({'FID_a': cur_fid_a,
                               'best_FID_a': self.best_fid_a,
                               'FID_b': cur_fid_b,
                               'best_FID_b': self.best_fid_b},
                              self.metric_meters)
        self._flush_meters(self.metric_meters)

    def _compute_fid(self):
        r"""Compute FID for both domains."""
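        # Reading of the calls below (an inference from the call site, not
        # from compute_fid's documentation): the .npy path caches statistics
        # for the real validation images, 'images_a'/'images_b' name the real
        # image keys, and 'images_ba'/'images_ab' name the generated images
        # used for the comparison, i.e. FID for domain A is measured between
        # real A images and B-to-A translations.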
        self.net_G.eval()
        if self.cfg.trainer.model_average_config.enabled:
            net_G_for_evaluation = self.net_G.module.averaged_model
        else:
            net_G_for_evaluation = self.net_G
        fid_a_path = self._get_save_path('fid_a', 'npy')
        fid_b_path = self._get_save_path('fid_b', 'npy')
        fid_value_a = compute_fid(fid_a_path, self.val_data_loader,
                                  net_G_for_evaluation, 'images_a', 'images_ba')
        fid_value_b = compute_fid(fid_b_path, self.val_data_loader,
                                  net_G_for_evaluation, 'images_b', 'images_ab')
        print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format(
            self.current_epoch, self.current_iteration,
            fid_value_a, fid_value_b))
        return fid_value_a, fid_value_b