# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from collections import namedtuple

import torch
from torch import nn, distributed as dist
import torchvision.models as tv
from torch.distributed import barrier

from imaginaire.utils.distributed import is_local_master


def get_lpips_model():
    if dist.is_initialized() and not is_local_master():
        # Only the local master downloads the pretrained weights; every other
        # process waits here and then loads them from the cache.
        barrier()
    model = LPIPSNet().cuda()
    if dist.is_initialized() and is_local_master():
        # Download finished; release the processes waiting at the barrier above.
        barrier()
    return model
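
# Usage sketch (assumes torch.distributed.init_process_group has already been
# called when running distributed): every rank ends up with its own CUDA copy
# of the metric, but only local masters hit the download path.
#
#   lpips = get_lpips_model()
#   d = lpips(images_a, images_b)  # one perceptual distance per batch element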


# Learned perceptual network, modified from
# https://github.com/richzhang/PerceptualSimilarity
def normalize_tensor(in_feat, eps=1e-5):
    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True) + eps)
    return in_feat / (norm_factor + eps)
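
# Example (a minimal sketch): normalize_tensor rescales each spatial location
# to (approximately) unit L2 norm across channels, so the distances below
# compare feature directions rather than magnitudes.
#
#   feat = torch.randn(2, 64, 8, 8)
#   unit = normalize_tensor(feat)
#   torch.sqrt(torch.sum(unit ** 2, dim=1))  # ~1.0 everywhere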


class NetLinLayer(nn.Module):
    """A per-channel weighting layer that holds the learnt LPIPS weights."""

    def __init__(self, dim):
        super(NetLinLayer, self).__init__()
        self.weight = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, inp):
        return self.weight * inp


class ScalingLayer(nn.Module):
    """Rescales inputs in [-1, 1] to the statistics VGG16 was trained on."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale
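
# Note on the constants above: for inputs in [-1, 1] this is exactly the usual
# ImageNet normalization, since shift = 2 * mean - 1 and scale = 2 * std with
# mean = (0.485, 0.456, 0.406) and std = (0.229, 0.224, 0.225):
#
#   (x - shift) / scale == ((x + 1) / 2 - mean) / std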


class LPIPSNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = LPNet()

    def forward(self, fake_images, fake_images_another, align_corners=True):
        features, shapes = self._forward_single(fake_images)
        features_another, _ = self._forward_single(fake_images_another)
        result = 0
        for i, g_feat in enumerate(features):
            # Squared distance between the weighted features of layer i,
            # averaged over its (flattened) spatial dimensions.
            cur_diff = torch.sum((g_feat - features_another[i]) ** 2, dim=1) / (shapes[i] ** 2)
            result += cur_diff
        return result

    def _forward_single(self, images):
        return self.model(torch.clamp(images, 0, 1))
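
# Example (a minimal sketch; downloads the pretrained VGG16 and LPIPS weights
# on first use): inputs are expected in [0, 1], and the result is one scalar
# distance per batch element.
#
#   net = LPIPSNet().eval()
#   a, b = torch.rand(4, 3, 256, 256), torch.rand(4, 3, 256, 256)
#   with torch.no_grad():
#       d = net(a, b)   # shape (4,); net(a, a) is all zeros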


class LPNet(nn.Module):
    def __init__(self):
        super(LPNet, self).__init__()
        self.scaling_layer = ScalingLayer()
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.L = 5
        dims = [64, 128, 256, 512, 512]
        self.lins = nn.ModuleList([NetLinLayer(dims[i]) for i in range(self.L)])
        weights = torch.hub.load_state_dict_from_url(
            'https://github.com/niopeng/CAM-Net/raw/main/code/models/weights/v0.1/vgg.pth'
        )
        for i in range(self.L):
            # The published LPIPS weights multiply the squared feature
            # difference; storing their square root here lets LPIPSNet.forward
            # recover w * (f0 - f1) ** 2 when it squares the difference of the
            # weighted features.
            self.lins[i].weight.data = torch.sqrt(weights["lin%d.model.1.weight" % i])

    def forward(self, in0, avg=False):
        # Map [0, 1] inputs to [-1, 1], then to VGG16's expected statistics.
        in0 = 2 * in0 - 1
        in0_input = self.scaling_layer(in0)
        outs0 = self.net.forward(in0_input)
        feats0 = {}
        shapes = []
        res = []
        for kk in range(self.L):
            feats0[kk] = normalize_tensor(outs0[kk])
        if avg:
            res = [self.lins[kk](feats0[kk]).mean([2, 3], keepdim=False) for kk in range(self.L)]
        else:
            for kk in range(self.L):
                cur_res = self.lins[kk](feats0[kk])
                shapes.append(cur_res.shape[-1])
                res.append(cur_res.reshape(cur_res.shape[0], -1))
        return res, shapes
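
# Output contract: with avg=False (the default used by LPIPSNet), forward
# returns five flattened, weighted feature tensors of shape (batch, C*H*W)
# plus their spatial widths, which LPIPSNet uses to average the squared
# distances; with avg=True it returns spatially averaged per-channel responses
# and an empty shapes list.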


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # Split the VGG16 feature stack at the five activations LPIPS taps:
        # relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3.
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h = self.slice1(x)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out
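

if __name__ == "__main__":
    # Smoke test (a sketch; needs internet access on first run to fetch the
    # torchvision VGG16 weights and the LPIPS linear weights).
    net = LPIPSNet().eval()
    a = torch.rand(2, 3, 64, 64)
    b = torch.rand(2, 3, 64, 64)
    with torch.no_grad():
        print("d(a, b):", net(a, b))  # two positive values
        print("d(a, a):", net(a, a))  # zeros: identical inputs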