import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
# Warning: spectral norm could be buggy
# under eval mode and multi-GPU inference.
# A workaround is sticking to single-GPU inference and train mode.
from torch.nn.utils import spectral_norm


class SPADE(nn.Module):
    """Spatially-Adaptive Normalization: applies a parameter-free normalization,
    then modulates the result with per-pixel scale (gamma) and bias (beta)
    predicted from a semantic segmentation map."""

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'syncbatch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128 if norm_nc > 128 else norm_nc

        pw = ks // 2
        self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)

    def forward(self, x, segmap):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * gamma + beta

        return out
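

# --- Illustrative usage sketch (not part of the original module) ---
# SPADE normalizes a feature map, then rescales it per pixel with gamma/beta
# predicted from the segmentation map. The config string 'spadeinstance3x3'
# and all tensor shapes below are assumptions chosen for demonstration only.
def _spade_usage_example():
    spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=3)
    feat = torch.randn(1, 64, 32, 32)      # features to be modulated
    segmap = torch.randn(1, 3, 256, 256)   # resized internally to feat's spatial size
    out = spade(feat, segmap)
    assert out.shape == feat.shape         # modulation preserves the feature shape
    return out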


class SPADEResnetBlock(nn.Module):
    """
    ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
    it takes the segmentation map as input, learns the skip connection if necessary,
    and applies normalization first and then convolution.

    This architecture seems to be a standard residual-block design for unconditional or
    class-conditional GANs. The code was inspired by https://github.com/LMescheder/GAN_stability.
    """

    def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if 'spectral' in norm_g:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = norm_g.replace('spectral', '')
        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)
        dx = self.conv_0(self.act(self.norm_0(x, seg)))
        dx = self.conv_1(self.act(self.norm_1(dx, seg)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def act(self, x):
        return F.leaky_relu(x, 2e-1)
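

# --- Illustrative usage sketch (not part of the original module) ---
# A SPADE residual block applies normalization -> activation -> convolution twice
# and adds a (possibly learned) shortcut. All channel counts and shapes below are
# assumptions chosen for demonstration; 'spectralspadeinstance3x3' avoids sync-batch norm.
def _spade_resblock_example():
    block = SPADEResnetBlock(fin=64, fout=32, norm_g='spectralspadeinstance3x3', semantic_nc=3)
    x = torch.randn(1, 64, 32, 32)
    seg = torch.randn(1, 3, 32, 32)
    out = block(x, seg)                    # channel count changes 64 -> 32 via the learned shortcut
    assert out.shape == (1, 32, 32, 32)
    return out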


class BaseNetwork(nn.Module):
    """A base class for HiFaceGAN architectures with custom weight initialization."""

    def init_weights(self, init_type='normal', gain=0.02):

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)

    def forward(self, x):
        pass
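

# --- Illustrative usage sketch (not part of the original module) ---
# Any subclass of BaseNetwork can re-initialize its Conv/Linear/BatchNorm weights
# after construction. The tiny network below is a hypothetical stand-in defined
# only to demonstrate the call; it is not part of HiFaceGAN.
def _init_weights_example():

    class TinyNet(BaseNetwork):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, padding=1)

        def forward(self, x):
            return self.conv(x)

    net = TinyNet()
    net.init_weights(init_type='xavier', gain=0.02)  # conv weight gets xavier-normal init
    return net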


def lip2d(x, logit, kernel=3, stride=2, padding=1):
    # Local Importance-based Pooling: a weighted average pooling where the
    # per-pixel weights are exp(logit), i.e. out = avg_pool(x * w) / avg_pool(w).
    weight = logit.exp()
    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)


class SoftGate(nn.Module):
    COEFF = 12.0

    def forward(self, x):
        return torch.sigmoid(x).mul(self.COEFF)


class SimplifiedLIP(nn.Module):

    def __init__(self, channels):
        super(SimplifiedLIP, self).__init__()
        self.logit = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
            SoftGate())

    def init_layer(self):
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        frac = lip2d(x, self.logit(x))
        return frac
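

# --- Illustrative usage sketch (not part of the original module) ---
# SimplifiedLIP predicts a logit map from the input and uses it as importance
# weights for lip2d, halving the spatial resolution (stride 2). The channel
# count and input size below are assumptions for demonstration only.
def _lip_pooling_example():
    pool = SimplifiedLIP(channels=16)
    pool.init_layer()                      # zeroed logits make this behave like plain average pooling at start
    x = torch.randn(1, 16, 64, 64)
    y = pool(x)
    assert y.shape == (1, 16, 32, 32)      # 2x spatial downsampling
    return y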


class LIPEncoder(BaseNetwork):
    """Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)."""

    def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16
        # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
        kw = 3
        pw = (kw - 1) // 2

        model = [
            nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model += [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
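

# --- Illustrative usage sketch (not part of the original module) ---
# Each of the n_2xdown stages halves the spatial size via SimplifiedLIP and
# doubles the channel ratio (capped at max_ratio=16). The arguments below are
# assumptions for demonstration; sw/sh are only stored, not used in forward().
def _lip_encoder_example():
    enc = LIPEncoder(input_nc=3, ngf=16, sw=16, sh=16, n_2xdown=3)
    x = torch.randn(1, 3, 128, 128)
    feat = enc(x)
    assert feat.shape == (1, 16 * 8, 16, 16)  # 3 downsamplings: 128 -> 16, channel ratio 1 -> 8
    return feat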


def get_nonspade_norm_layer(norm_type='instance'):

    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # without the 'spectral' prefix, the whole string names the norm type
            # (avoids an undefined-name error for e.g. norm_type='instance')
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            # norm_layer = SynchronizedBatchNorm2d(
            #     get_out_channel(layer), affine=True)
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError(f'normalization layer {subnorm_type} is not recognized')

        return nn.Sequential(layer, norm_layer)

    print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
    return add_norm_layer
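

# --- Illustrative usage sketch (not part of the original module) ---
# get_nonspade_norm_layer returns a callable that wraps a layer with spectral norm
# and/or follows it with a parameter-free normalization. The 'spectralinstance'
# string and the conv shape below are assumptions for demonstration only.
def _nonspade_norm_example():
    add_norm = get_nonspade_norm_layer('spectralinstance')
    layer = add_norm(nn.Conv2d(3, 64, kernel_size=3, padding=1))
    # 'layer' is now nn.Sequential(spectral-normed conv with its bias removed, InstanceNorm2d)
    return layer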