Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
| import torch.nn as nn | |
class DictLoss(nn.Module):
    """L1 / L2 loss summed over the matching entries of two feature collections.

    Args:
        criterion (str): ``'l1'`` selects :class:`torch.nn.L1Loss`;
            ``'l2'`` or ``'mse'`` selects :class:`torch.nn.MSELoss`.

    Raises:
        ValueError: If ``criterion`` is not one of the recognized names.
    """

    def __init__(self, criterion='l1'):
        super().__init__()
        if criterion == 'l1':
            self.criterion = nn.L1Loss()
        elif criterion in ('l2', 'mse'):
            self.criterion = nn.MSELoss()
        else:
            raise ValueError('Criterion %s is not recognized' % criterion)

    def forward(self, fake, real):
        """Sum the criterion between ``fake`` and the detached ``real``.

        Args:
            fake (dict, list, tuple or tensor): Discriminator features of
                fake images. A dict (including subclasses such as
                ``OrderedDict``) is matched with ``real`` by key. A list or
                tuple (including subclasses) is matched by position. Anything
                else is treated as a single tensor.
            real (dict, list, tuple or tensor): Discriminator features of
                real images, with the same structure as ``fake``. Every entry
                is detached, so no gradient flows back through ``real``.

        Returns:
            loss (tensor): Sum of the per-entry criterion values. This is the
            int ``0`` when ``fake`` is an empty dict, list or tuple.
        """
        loss = 0
        if isinstance(fake, dict):
            # Pair by key; ``real`` must provide every key present in ``fake``.
            for key in fake:
                loss += self.criterion(fake[key], real[key].detach())
        elif isinstance(fake, (list, tuple)):
            # Pair by position; zip stops at the shorter of the two sequences.
            for f, r in zip(fake, real):
                loss += self.criterion(f, r.detach())
        else:
            loss += self.criterion(fake, real.detach())
        return loss