Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class GANLoss(nn.Module):
    """Binary cross-entropy GAN loss computed on raw discriminator logits.

    A target tensor shaped like the prediction is filled with
    ``target_real_label`` or ``target_fake_label`` and compared to the
    prediction with ``F.binary_cross_entropy_with_logits`` (default mean
    reduction, so a single prediction yields a 0-d tensor).

    Args:
        target_real_label: Fill value of the target for real samples.
        target_fake_label: Fill value of the target for fake samples.
        tensor: Stored as ``self.Tensor``; not used by this class.
        opt: Stored as ``self.opt``; not used by this class.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor, opt=None):
        super().__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Not used by this class; presumably kept so external code that
        # reads these attributes keeps working -- confirm before removing.
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.opt = opt

    def get_target_tensor(self, input, target_is_real):
        """Return a tensor like ``input`` filled with the real/fake label.

        Args:
            input: Prediction tensor whose shape/dtype/device the target copies.
            target_is_real: Truthy selects ``self.real_label``, otherwise
                ``self.fake_label``.
        """
        label = self.real_label if target_is_real else self.fake_label
        # full_like allocates a fresh tensor with requires_grad=False, so the
        # target never participates in autograd.
        return torch.full_like(input, label)

    def get_zero_tensor(self, input):
        """Return a detached all-zeros tensor shaped like ``input``."""
        return torch.zeros_like(input).detach()

    def loss(self, inputs, target_is_real, for_discriminator=True):
        """Mean BCE-with-logits between ``inputs`` and the target labels.

        ``for_discriminator`` is accepted for interface compatibility but
        does not affect the result.
        """
        target_tensor = self.get_target_tensor(inputs, target_is_real)
        return F.binary_cross_entropy_with_logits(inputs, target_tensor)

    def forward(self, inputs, target_is_real, for_discriminator=True):
        """Compute the GAN loss for a single prediction or a list of them.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        calling the module runs PyTorch's registered hooks.

        Args:
            inputs: A tensor of logits, or a list of predictions (e.g. from a
                multiscale discriminator). A list element that is itself a
                list contributes only its last entry.
            target_is_real: Whether the target label is the real label.
            for_discriminator: Passed through to ``loss``; currently unused.

        Returns:
            For a tensor input, the 0-d mean loss. For a list input, the
            average of the per-element losses; because each per-element loss
            is viewed as ``(1, -1)`` and averaged over dim 1, the result has
            shape ``(1,)``. An empty list raises ``ZeroDivisionError``.
        """
        if not isinstance(inputs, list):
            return self.loss(inputs, target_is_real, for_discriminator)

        total = 0
        for pred_i in inputs:
            if isinstance(pred_i, list):
                # Keep only the final output of this discriminator.
                pred_i = pred_i[-1]
            loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
            bs = 1 if loss_tensor.dim() == 0 else loss_tensor.size(0)
            total += torch.mean(loss_tensor.view(bs, -1), dim=1)
        return total / len(inputs)