# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from imaginaire.utils.distributed import master_only_print as print


def fuse_math_min_mean_pos(x):
    r"""Fuse operation min mean for hinge loss computation of positive
    samples"""
    minval = torch.min(x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


def fuse_math_min_mean_neg(x):
    r"""Fuse operation min mean for hinge loss computation of negative
    samples"""
    minval = torch.min(-x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss
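
# For reference: these helpers implement the two sides of the hinge loss,
# mean(max(0, 1 - x)) for real samples and mean(max(0, 1 + x)) for fake
# samples, written in fused min/mean form. For example, with logits
# x = [2.0, 0.5], min(x - 1, 0) = [0.0, -0.5], so the positive-sample
# loss is -mean([0.0, -0.5]) = 0.25.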


class GANLoss(nn.Module):
    r"""GAN loss constructor.

    Args:
        gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``,
            ``'non_saturated'``, ``'wasserstein'``, or ``'softplus'``.
        target_real_label (float): The desired output label for real images.
        target_fake_label (float): The desired output label for fake images.
        decay_k (float): The decay factor per epoch for top-k training.
        min_k (float): The minimum fraction of samples to select.
        separate_topk (bool): If ``True``, selects top-k for each sample
            separately, otherwise selects top-k among all samples.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 decay_k=1., min_k=1., separate_topk=False):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.gan_mode = gan_mode
        self.decay_k = decay_k
        self.min_k = min_k
        self.separate_topk = separate_topk
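        # ``k`` is registered as a buffer (rather than a plain float) so the
        # current top-k fraction is saved in checkpoints and moves across
        # devices together with the module.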
        self.register_buffer('k', torch.tensor(1.0))
        print('GAN mode: %s' % gan_mode)

    def forward(self, dis_output, t_real, dis_update=True, reduce=True):
        r"""GAN loss computation.

        Args:
            dis_output (tensor or list of tensors): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target,
                otherwise uses the fake label as target.
            dis_update (bool): If ``True``, the loss is used to update the
                discriminator, otherwise the generator.
            reduce (bool): If ``True``, when a list of discriminator outputs
                is provided, returns the average of all losses; otherwise
                returns a list of losses.
        Returns:
            loss (tensor): Loss value.
        """
        if isinstance(dis_output, list):
            # For multi-scale discriminators, the loss is first averaged
            # within each scale (over batch and spatial locations) and then
            # averaged across scales, so that the gradient is not dominated
            # by the discriminator with the most output values (the highest
            # resolution one).
            losses = []
            for dis_output_i in dis_output:
                assert isinstance(dis_output_i, torch.Tensor)
                losses.append(self.loss(dis_output_i, t_real, dis_update))
            if reduce:
                return torch.mean(torch.stack(losses))
            else:
                return losses
        else:
            return self.loss(dis_output, t_real, dis_update)

    def loss(self, dis_output, t_real, dis_update=True):
        r"""GAN loss computation.

        Args:
            dis_output (tensor): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target,
                otherwise uses the fake label as target.
            dis_update (bool): If ``True``, the loss is used to update the
                discriminator, otherwise the generator.
        Returns:
            loss (tensor): Loss value.
        """
        if not dis_update:
            assert t_real, \
                "The target should be real when updating the generator."
        if not dis_update and self.k < 1:
            r"""
            Use top-k training:
            "Top-k Training of GANs: Improving GAN Performance by Throwing
            Away Bad Samples"
            Here, each sample may have multiple discriminator output values
            (patch discriminator). We can either select top-k for each sample
            separately (when ``self.separate_topk=True``), or collect values
            from all samples and then select top-k (default, when
            ``self.separate_topk=False``).
            """
            if self.separate_topk:
                dis_output = dis_output.view(dis_output.size(0), -1)
            else:
                dis_output = dis_output.view(-1)
            k = math.ceil(self.k * dis_output.size(-1))
            dis_output, _ = torch.topk(dis_output, k)
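            # For example, with ``self.k == 0.5`` and 8 output values per
            # sample, only the ceil(0.5 * 8) = 4 largest logits contribute
            # to the generator loss computed below.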
        if self.gan_mode == 'non_saturated':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = F.binary_cross_entropy_with_logits(dis_output,
                                                      target_tensor)
        elif self.gan_mode == 'least_square':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = 0.5 * F.mse_loss(dis_output, target_tensor)
        elif self.gan_mode == 'hinge':
            if dis_update:
                if t_real:
                    loss = fuse_math_min_mean_pos(dis_output)
                else:
                    loss = fuse_math_min_mean_neg(dis_output)
            else:
                loss = -torch.mean(dis_output)
        elif self.gan_mode == 'wasserstein':
            if t_real:
                loss = -torch.mean(dis_output)
            else:
                loss = torch.mean(dis_output)
        elif self.gan_mode == 'softplus':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = F.binary_cross_entropy_with_logits(dis_output,
                                                      target_tensor)
        else:
            raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode))
        return loss

    def get_target_tensor(self, dis_output, t_real):
        r"""Return the target vector for the binary cross entropy loss
        computation.

        Args:
            dis_output (tensor): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target,
                otherwise uses the fake label as target.
        Returns:
            target (tensor): Target tensor vector.
        """
        if t_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = dis_output.new_tensor(self.real_label)
            return self.real_label_tensor.expand_as(dis_output)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = dis_output.new_tensor(self.fake_label)
            return self.fake_label_tensor.expand_as(dis_output)

    def topk_anneal(self):
        r"""Anneal k after each epoch."""
        if self.decay_k < 1:
            # noinspection PyAttributeOutsideInit
            self.k.fill_(max(self.decay_k * self.k, self.min_k))
            print("Top-k training: update k to {}.".format(self.k))