Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import torch | |
| from imaginaire.evaluation import compute_fid | |
| from imaginaire.losses import (GANLoss, GaussianKLLoss, | |
| PerceptualLoss) | |
| from imaginaire.trainers.base import BaseTrainer | |
| from imaginaire.utils.misc import random_shift | |
| from imaginaire.utils.distributed import master_only_print as print | |
| from imaginaire.utils.diff_aug import apply_diff_aug | |
class Trainer(BaseTrainer):
    r"""Reimplementation of the MUNIT (https://arxiv.org/abs/1804.04732)
    algorithm.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                         train_data_loader, val_data_loader)
        # When True, the within-domain reconstructions (images_aa /
        # images_bb) are also fed to the discriminator as fakes.
        self.gan_recon = getattr(cfg.trainer, 'gan_recon', False)
        # Lowest FID observed so far for each domain; None until the first
        # call to write_metrics.
        self.best_fid_a = None
        self.best_fid_b = None

    def _init_loss(self, cfg):
        r"""Initialize loss terms. In MUNIT, we have several loss terms
        including the GAN loss, the image reconstruction loss, the content
        reconstruction loss, the style reconstruction loss, the cycle
        reconstruction loss. We also have an optional perceptual loss. A user
        can choose to have gradient penalty or consistency regularization too.

        Args:
            cfg (obj): Global configuration.
        """
        self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode)
        self.criteria['kl'] = GaussianKLLoss()
        self.criteria['image_recon'] = torch.nn.L1Loss()
        if getattr(cfg.trainer.loss_weight, 'perceptual', 0) > 0:
            self.criteria['perceptual'] = \
                PerceptualLoss(network=cfg.trainer.perceptual_mode,
                               layers=cfg.trainer.perceptual_layers)
        # Only strictly positive weights are registered. Membership in
        # self.weights is what switches the optional loss terms on in
        # gen_forward / dis_forward.
        for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items():
            if loss_weight > 0:
                self.weights[loss_name] = loss_weight

    def gen_forward(self, data):
        r"""Compute the loss for MUNIT generator.

        Args:
            data (dict): Training data at the current iteration. Must contain
                'images_a' and 'images_b'.
        Returns:
            The total weighted generator loss from ``_get_total_loss``.
        """
        cycle_recon = 'cycle_recon' in self.weights
        image_recon = 'image_recon' in self.weights
        perceptual = 'perceptual' in self.weights
        within_latent_recon = 'style_recon_within' in self.weights or \
            'content_recon_within' in self.weights

        net_G_output = self.net_G(data,
                                  image_recon=image_recon,
                                  cycle_recon=cycle_recon,
                                  within_latent_recon=within_latent_recon)

        # Differentiable augmentation.
        keys = ['images_ab', 'images_ba']
        if self.gan_recon:
            keys += ['images_aa', 'images_bb']
        net_D_output = self.net_D(data,
                                  apply_diff_aug(
                                      net_G_output, keys, self.aug_policy),
                                  real=False,
                                  gan_recon=self.gan_recon)

        self._time_before_loss()

        gan = self.criteria['gan']
        # GAN loss: the generator wants every fake to be classified as real.
        if self.gan_recon:
            self.gen_losses['gan_a'] = \
                0.5 * (gan(net_D_output['out_ba'], True, dis_update=False) +
                       gan(net_D_output['out_aa'], True, dis_update=False))
            self.gen_losses['gan_b'] = \
                0.5 * (gan(net_D_output['out_ab'], True, dis_update=False) +
                       gan(net_D_output['out_bb'], True, dis_update=False))
        else:
            self.gen_losses['gan_a'] = gan(
                net_D_output['out_ba'], True, dis_update=False)
            self.gen_losses['gan_b'] = gan(
                net_D_output['out_ab'], True, dis_update=False)
        self.gen_losses['gan'] = \
            self.gen_losses['gan_a'] + self.gen_losses['gan_b']

        # Perceptual loss between each translation and its source image.
        if perceptual:
            self.gen_losses['perceptual_a'] = \
                self.criteria['perceptual'](net_G_output['images_ab'],
                                            data['images_a'])
            self.gen_losses['perceptual_b'] = \
                self.criteria['perceptual'](net_G_output['images_ba'],
                                            data['images_b'])
            self.gen_losses['perceptual'] = \
                self.gen_losses['perceptual_a'] + \
                self.gen_losses['perceptual_b']

        # Image reconstruction loss
        if image_recon:
            self.gen_losses['image_recon'] = \
                self.criteria['image_recon'](net_G_output['images_aa'],
                                             data['images_a']) + \
                self.criteria['image_recon'](net_G_output['images_bb'],
                                             data['images_b'])

        # Style reconstruction loss: the style re-encoded from a translated
        # image should match the sampled style used to produce it.
        self.gen_losses['style_recon_a'] = torch.abs(
            net_G_output['style_ba'] -
            net_G_output['style_a_rand']).mean()
        self.gen_losses['style_recon_b'] = torch.abs(
            net_G_output['style_ab'] -
            net_G_output['style_b_rand']).mean()
        self.gen_losses['style_recon'] = \
            self.gen_losses['style_recon_a'] + self.gen_losses['style_recon_b']

        if within_latent_recon:
            # Targets are detached so only the re-encoding branch is pulled
            # toward the original codes.
            self.gen_losses['style_recon_aa'] = torch.abs(
                net_G_output['style_aa'] -
                net_G_output['style_a'].detach()).mean()
            self.gen_losses['style_recon_bb'] = torch.abs(
                net_G_output['style_bb'] -
                net_G_output['style_b'].detach()).mean()
            self.gen_losses['style_recon_within'] = \
                self.gen_losses['style_recon_aa'] + \
                self.gen_losses['style_recon_bb']

        # Content reconstruction loss (targets detached, as above).
        self.gen_losses['content_recon_a'] = torch.abs(
            net_G_output['content_ab'] -
            net_G_output['content_a'].detach()).mean()
        self.gen_losses['content_recon_b'] = torch.abs(
            net_G_output['content_ba'] -
            net_G_output['content_b'].detach()).mean()
        self.gen_losses['content_recon'] = \
            self.gen_losses['content_recon_a'] + \
            self.gen_losses['content_recon_b']

        if within_latent_recon:
            self.gen_losses['content_recon_aa'] = torch.abs(
                net_G_output['content_aa'] -
                net_G_output['content_a'].detach()).mean()
            self.gen_losses['content_recon_bb'] = torch.abs(
                net_G_output['content_bb'] -
                net_G_output['content_b'].detach()).mean()
            self.gen_losses['content_recon_within'] = \
                self.gen_losses['content_recon_aa'] + \
                self.gen_losses['content_recon_bb']

        # KL loss on the encoded styles.
        self.gen_losses['kl'] = \
            self.criteria['kl'](net_G_output['style_a']) + \
            self.criteria['kl'](net_G_output['style_b'])

        # Cycle reconstruction loss (L1, a -> b -> a and b -> a -> b).
        if cycle_recon:
            self.gen_losses['cycle_recon'] = \
                torch.abs(net_G_output['images_aba'] -
                          data['images_a']).mean() + \
                torch.abs(net_G_output['images_bab'] -
                          data['images_b']).mean()

        # Compute total loss
        total_loss = self._get_total_loss(gen_forward=True)
        return total_loss

    def dis_forward(self, data):
        r"""Compute the loss for MUNIT discriminator.

        Args:
            data (dict): Training data at the current iteration. Must contain
                'images_a' and 'images_b'.
        Returns:
            The total weighted discriminator loss from ``_get_total_loss``.
        """
        with torch.no_grad():
            net_G_output = self.net_G(data,
                                      image_recon=self.gan_recon,
                                      latent_recon=False,
                                      cycle_recon=False,
                                      within_latent_recon=False)
        # Re-enable gradients on the (no_grad) fakes so that gradient-based
        # regularizers on the discriminator inputs can be computed.
        net_G_output['images_ba'].requires_grad = True
        net_G_output['images_ab'].requires_grad = True

        # Differentiable augmentation.
        keys_fake = ['images_ab', 'images_ba']
        if self.gan_recon:
            keys_fake += ['images_aa', 'images_bb']
        keys_real = ['images_a', 'images_b']

        net_D_output = self.net_D(
            apply_diff_aug(data, keys_real, self.aug_policy),
            apply_diff_aug(net_G_output, keys_fake, self.aug_policy),
            gan_recon=self.gan_recon)

        self._time_before_loss()

        gan = self.criteria['gan']
        # GAN loss.
        if self.gan_recon:
            # gen_forward trains the generator to make images_aa / images_bb
            # look real, so the discriminator must also learn to reject them.
            # Mirror the generator's 0.5 averaging over the two fake sources.
            fake_a_loss = 0.5 * (gan(net_D_output['out_ba'], False) +
                                 gan(net_D_output['out_aa'], False))
            fake_b_loss = 0.5 * (gan(net_D_output['out_ab'], False) +
                                 gan(net_D_output['out_bb'], False))
        else:
            fake_a_loss = gan(net_D_output['out_ba'], False)
            fake_b_loss = gan(net_D_output['out_ab'], False)
        self.dis_losses['gan_a'] = \
            gan(net_D_output['out_a'], True) + fake_a_loss
        self.dis_losses['gan_b'] = \
            gan(net_D_output['out_b'], True) + fake_b_loss
        self.dis_losses['gan'] = \
            self.dis_losses['gan_a'] + self.dis_losses['gan_b']

        # Consistency regularization: discriminator features should be
        # stable under a horizontal flip plus random shift of its inputs.
        # The accumulator lives on the data's device instead of a hard-coded
        # CUDA device.
        consistency_reg = torch.tensor(0., device=data['images_a'].device)
        if 'consistency_reg' in self.weights:
            data_aug, net_G_output_aug = {}, {}
            data_aug['images_a'] = random_shift(data['images_a'].flip(-1))
            data_aug['images_b'] = random_shift(data['images_b'].flip(-1))
            net_G_output_aug['images_ab'] = \
                random_shift(net_G_output['images_ab'].flip(-1))
            net_G_output_aug['images_ba'] = \
                random_shift(net_G_output['images_ba'].flip(-1))
            net_D_output_aug = self.net_D(data_aug, net_G_output_aug)
            feature_names = ['fea_ba', 'fea_ab',
                             'fea_a', 'fea_b']
            for feature_name in feature_names:
                consistency_reg = consistency_reg + \
                    torch.pow(net_D_output_aug[feature_name] -
                              net_D_output[feature_name], 2).mean()
        self.dis_losses['consistency_reg'] = consistency_reg

        # Compute total loss
        total_loss = self._get_total_loss(gen_forward=False)
        return total_loss

    def _get_net_G_for_evaluation(self):
        r"""Return the generator to use for evaluation: the averaged model
        when model averaging is enabled, otherwise ``self.net_G``."""
        if self.cfg.trainer.model_average_config.enabled:
            return self.net_G.module.averaged_model
        return self.net_G

    def _get_visualizations(self, data):
        r"""Compute visualization image.

        Args:
            data (dict): The current batch.
        Returns:
            list: Inputs, within-domain reconstructions, translations with
            encoded and with random styles, and cycle reconstructions.
        """
        net_G_for_evaluation = self._get_net_G_for_evaluation()
        with torch.no_grad():
            net_G_output = net_G_for_evaluation(data, random_style=False)
            net_G_output_random = net_G_for_evaluation(data)
            vis_images = [data['images_a'],
                          data['images_b'],
                          net_G_output['images_aa'],
                          net_G_output['images_bb'],
                          net_G_output['images_ab'],
                          net_G_output_random['images_ab'],
                          net_G_output['images_ba'],
                          net_G_output_random['images_ba'],
                          net_G_output['images_aba'],
                          net_G_output['images_bab']]
        return vis_images

    @staticmethod
    def _update_best(best, current):
        r"""Return the lower of ``best`` and ``current``; ``current`` when no
        best value has been recorded yet (``best is None``)."""
        return current if best is None else min(best, current)

    def write_metrics(self):
        r"""Compute metrics and save them to tensorboard"""
        cur_fid_a, cur_fid_b = self._compute_fid()
        self.best_fid_a = self._update_best(self.best_fid_a, cur_fid_a)
        self.best_fid_b = self._update_best(self.best_fid_b, cur_fid_b)
        self._write_to_meters({'FID_a': cur_fid_a,
                               'best_FID_a': self.best_fid_a,
                               'FID_b': cur_fid_b,
                               'best_FID_b': self.best_fid_b},
                              self.metric_meters)
        self._flush_meters(self.metric_meters)

    def _compute_fid(self):
        r"""Compute FID for both domains.

        Domain a is scored on the b->a translations and domain b on the
        a->b translations.

        Returns:
            tuple: (fid_value_a, fid_value_b).
        """
        # NOTE(review): net_G is left in eval mode here; presumably the base
        # trainer switches it back to train mode — confirm.
        self.net_G.eval()
        net_G_for_evaluation = self._get_net_G_for_evaluation()
        fid_a_path = self._get_save_path('fid_a', 'npy')
        fid_b_path = self._get_save_path('fid_b', 'npy')
        fid_value_a = compute_fid(fid_a_path, self.val_data_loader,
                                  net_G_for_evaluation, 'images_a', 'images_ba')
        fid_value_b = compute_fid(fid_b_path, self.val_data_loader,
                                  net_G_for_evaluation, 'images_b', 'images_ab')
        print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format(
            self.current_epoch, self.current_iteration,
            fid_value_a, fid_value_b))
        return fid_value_a, fid_value_b