Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| import torch | |
| import torch.nn.functional as F | |
class TVLoss(nn.Module):
    """Squared total-variation (TV) loss for 4-D tensors.

    For an input ``x`` of shape ``(B, C, H, W)`` (presumably an NCHW image
    batch — confirm against callers) the loss is::

        weight * 2 * (sum(dh**2) / (C*(H-1)*W) + sum(dw**2) / (C*H*(W-1))) / B

    where ``dh``/``dw`` are differences between neighbouring elements along
    dim 2 and dim 3. A direction with no neighbour pairs (e.g. ``H == 1`` or
    ``W == 1``) contributes 0 instead of ``0/0 = NaN``, and an empty batch
    yields 0.

    Args:
        TVLoss_weight: Scalar multiplier applied to the loss (default 1).
    """

    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Return the weighted TV loss of ``x`` as a 0-dim tensor.

        Args:
            x: 4-D tensor; dim 0 is treated as the batch, and dims 2 and 3
                are the axes along which neighbour differences are taken.
        """
        batch_size = x.size(0)
        # Differences between neighbours along dim 2 and along dim 3.
        diff_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        diff_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        count_h = self._tensor_size(diff_h)
        count_w = self._tensor_size(diff_w)
        h_tv = diff_h.pow(2).sum()
        w_tv = diff_w.pow(2).sum()
        # When a count is 0 the matching difference tensor is empty and its
        # sum is exactly 0; keep that zero (same dtype/device/autograd graph)
        # instead of dividing and producing 0/0 = NaN.
        h_term = h_tv / count_h if count_h > 0 else h_tv
        w_term = w_tv / count_w if count_w > 0 else w_tv
        # max(..., 1) guards the empty batch, where both terms are already 0.
        return self.TVLoss_weight * 2 * (h_term + w_term) / max(batch_size, 1)

    def _tensor_size(self, t):
        """Return the per-item element count ``size(1) * size(2) * size(3)``."""
        return t.size(1) * t.size(2) * t.size(3)
class hinge_loss(nn.Module):
    """Hinge loss over a pair of score tensors (presumably discriminator outputs).

    Computes ``mean(relu(1 - dis_real)) + mean(relu(1 + dis_fake))``: real
    scores are penalised below +1 and fake scores are penalised above -1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, dis_fake, dis_real):
        """Return the summed hinge penalties as a 0-dim tensor.

        Args:
            dis_fake: Scores for fake samples.
            dis_real: Scores for real samples.
        """
        real_term = F.relu(1. - dis_real).mean()
        fake_term = F.relu(1. + dis_fake).mean()
        return real_term + fake_term
class hinge_loss_G(nn.Module):
    """Hinge-style loss on fake-sample scores (presumably the generator side).

    The loss is the negated mean of ``dis_fake``, so minimising it pushes
    those scores upward.
    """

    def __init__(self):
        super().__init__()

    def forward(self, dis_fake):
        """Return ``-mean(dis_fake)`` as a 0-dim tensor."""
        mean_score = dis_fake.mean()
        return mean_score.neg()