# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from types import SimpleNamespace

import torch
from torch import nn

from .misc import ApplyNoise
from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur


class ViT2dBlock(nn.Module):
    r"""An abstract wrapper class that wraps a torch convolution or linear
    layer with normalization and nonlinearity.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias, padding_mode,
                 weight_norm_type, weight_norm_params,
                 activation_norm_type, activation_norm_params,
                 nonlinearity, inplace_nonlinearity,
                 apply_noise, blur, order, input_dim, clamp,
                 blur_kernel=(1, 3, 3, 1), output_scale=None,
                 init_gain=1.0):
        super().__init__()
        from .nonlinearity import get_nonlinearity_layer
        from .weight_norm import get_weight_norm_layer
        from .activation_norm import get_activation_norm_layer
        self.weight_norm_type = weight_norm_type
        self.stride = stride
        self.clamp = clamp
        self.init_gain = init_gain

        # Nonlinearity layer.
        if 'fused' in nonlinearity:
            # Fusing nonlinearity with bias.
            lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
            conv_before_nonlinearity = order.find('C') < order.find('A')
            if conv_before_nonlinearity:
                assert bias
                bias = False
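                # With a fused nonlinearity the bias is applied inside the
                # activation layer, so the convolution's own bias is disabled.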
            channel = out_channels if conv_before_nonlinearity else in_channels
            nonlinearity_layer = get_nonlinearity_layer(
                nonlinearity, inplace=inplace_nonlinearity,
                num_channels=channel, lr_mul=lr_mul)
        else:
            nonlinearity_layer = get_nonlinearity_layer(
                nonlinearity, inplace=inplace_nonlinearity)

        # Noise injection layer.
        if apply_noise:
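            # Inject noise right after the convolution: 'G' (noise) is placed
            # immediately after 'C' (conv) in the order string.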
            order = order.replace('C', 'CG')
            noise_layer = ApplyNoise()
        else:
            noise_layer = None

        # Convolutional layer.
        if blur:
            if stride == 2:
                # Blur - Conv - Noise - Activate.
                p = (len(blur_kernel) - 2) + (kernel_size - 1)
                pad0, pad1 = (p + 1) // 2, p // 2
                padding = 0
                blur_layer = Blur(
                    blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
                )
                order = order.replace('C', 'BC')
            elif stride == 0.5:
                # Conv - Blur - Noise - Activate.
                padding = 0
                p = (len(blur_kernel) - 2) - (kernel_size - 1)
                pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
                blur_layer = Blur(
                    blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
                )
                order = order.replace('C', 'CB')
            elif stride == 1:
                # No blur for now.
                blur_layer = nn.Identity()
            else:
                raise NotImplementedError
        else:
            blur_layer = nn.Identity()

        if weight_norm_params is None:
            weight_norm_params = SimpleNamespace()
        weight_norm = get_weight_norm_layer(
            weight_norm_type, **vars(weight_norm_params))
        conv_layer = weight_norm(self._get_conv_layer(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, input_dim))

        # Normalization layer.
        conv_before_norm = order.find('C') < order.find('N')
        norm_channels = out_channels if conv_before_norm else in_channels
        if activation_norm_params is None:
            activation_norm_params = SimpleNamespace()
        activation_norm_layer = get_activation_norm_layer(
            norm_channels,
            activation_norm_type,
            input_dim,
            **vars(activation_norm_params))

        # Mapping from operation names to layers.
        mappings = {'C': {'conv': conv_layer},
                    'N': {'norm': activation_norm_layer},
                    'A': {'nonlinearity': nonlinearity_layer}}
        mappings.update({'B': {'blur': blur_layer}})
        mappings.update({'G': {'noise': noise_layer}})
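        # Each character of ``order`` (e.g. 'CNA') selects one of the entries
        # above, so the string controls which layers are built and in what
        # sequence they are applied.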

        # All layers in order.
        self.layers = nn.ModuleDict()
        for op in order:
            if list(mappings[op].values())[0] is not None:
                self.layers.update(mappings[op])

        # Whether this block expects conditional inputs.
        self.conditional = \
            getattr(conv_layer, 'conditional', False) or \
            getattr(activation_norm_layer, 'conditional', False)

        if output_scale is not None:
            self.output_scale = nn.Parameter(torch.tensor(output_scale))
        else:
            self.register_parameter('output_scale', None)

    def forward(self, x, *cond_inputs, **kw_cond_inputs):
        r"""
        Args:
            x (tensor): Input tensor.
            cond_inputs (list of tensors): Conditional input tensors.
            kw_cond_inputs (dict): Keyword conditional inputs.
        """
        for key, layer in self.layers.items():
            if getattr(layer, 'conditional', False):
                # Layers that require conditional inputs.
                x = layer(x, *cond_inputs, **kw_cond_inputs)
            else:
                x = layer(x)
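            # Clamp the convolution output from above when a clamp value is
            # given.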
            if self.clamp is not None and isinstance(layer, nn.Conv2d):
                x.clamp_(max=self.clamp)
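            # Scale the convolution output by the learnable output_scale,
            # if one was provided.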
            if key == 'conv':
                if self.output_scale is not None:
                    x = x * self.output_scale
        return x

    def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
                        padding, dilation, groups, bias, padding_mode,
                        input_dim):
        # Returns the convolutional (or, for input_dim == 0, linear) layer.
        if input_dim == 0:
            layer = nn.Linear(in_channels, out_channels, bias)
        else:
            if stride < 1:  # Fractionally-strided convolution.
                padding_mode = 'zeros'
                assert padding == 0
                layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
                stride = round(1 / stride)
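                # e.g. a requested stride of 0.5 becomes a transposed
                # convolution with stride 2.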
            else:
                layer_type = getattr(nn, f'Conv{input_dim}d')
            layer = layer_type(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation=dilation, groups=groups, bias=bias,
                padding_mode=padding_mode
            )
        return layer

    def __repr__(self):
        main_str = self._get_name() + '('
        child_lines = []
        for name, layer in self.layers.items():
            mod_str = repr(layer)
            if name == 'conv' and self.weight_norm_type != 'none' and \
                    self.weight_norm_type != '':
                mod_str = mod_str[:-1] + \
                    ', weight_norm={}'.format(self.weight_norm_type) + ')'
            if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
                mod_str = mod_str[:-1] + \
                    ', lr_mul={}'.format(layer.base_lr_mul) + ')'
            mod_str = self._addindent(mod_str, 2)
            child_lines.append(mod_str)
        if len(child_lines) == 1:
            main_str += child_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(child_lines) + '\n'
        main_str += ')'
        return main_str

    @staticmethod
    def _addindent(s_, numSpaces):
        # Indent every line after the first, mirroring torch's module repr
        # helper; declared static because it is called as self._addindent()
        # but takes no self argument.
        s = s_.split('\n')
        # Don't do anything for single-line strings.
        if len(s) == 1:
            return s_
        first = s.pop(0)
        s = [(numSpaces * ' ') + line for line in s]
        s = '\n'.join(s)
        s = first + '\n' + s
        return s
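

# A minimal usage sketch, not part of the original module. It assumes the
# imaginaire package and its compiled upfirdn2d extension are importable, and
# the argument values below ('none' norms, 'relu' nonlinearity, order 'CNA')
# are illustrative guesses rather than required settings. Run it with
# ``python -m`` from the package root so the relative imports above resolve.
if __name__ == '__main__':
    # Build a plain 3x3 conv block: conv -> (no norm) -> ReLU.
    block = ViT2dBlock(
        in_channels=64, out_channels=64, kernel_size=3, stride=1,
        padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros',
        weight_norm_type='none', weight_norm_params=None,
        activation_norm_type='none', activation_norm_params=None,
        nonlinearity='relu', inplace_nonlinearity=False,
        apply_noise=False, blur=False, order='CNA', input_dim=2, clamp=None)
    out = block(torch.randn(1, 64, 32, 32))
    print(block)
    print(out.shape)  # Expected: torch.Size([1, 64, 32, 32])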