import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class ModulationConvBlock(nn.Module):
    """StyleGAN2-style modulated convolution with weight demodulation."""

    def __init__(self, input_dim, output_dim, kernel_size, stride=1,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(ModulationConvBlock, self).__init__()
        self.in_c = input_dim
        self.out_c = output_dim
        self.ksize = kernel_size
        # The stride/padding/norm/activation arguments are accepted for interface
        # parity with the other blocks but are not used: stride is fixed to 1,
        # padding is always 'same', and the activation is a scaled LeakyReLU.
        self.stride = 1
        self.padding = kernel_size // 2
        self.eps = 1e-8
        # Equalized learning rate: weights are drawn from N(0, 1) and rescaled
        # at runtime by 1 / sqrt(fan_in).
        weight_shape = (output_dim, input_dim, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * input_dim
        wscale = 1.0 / np.sqrt(fan_in)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.wscale = wscale
        self.bias = nn.Parameter(torch.zeros(output_dim))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.activate_scale = np.sqrt(2.0)
    def forward(self, x, code):
        batch, in_channel, height, width = x.shape
        weight = self.weight * self.wscale
        # Rearrange (out_c, in_c, k, k) -> (1, k, k, in_c, out_c) so the style
        # code can modulate the input-channel axis per sample. (A plain .view()
        # would not reorder the axes and would scramble the layout.)
        _weight = weight.permute(2, 3, 1, 0).unsqueeze(0)
        _weight = _weight * code.view(batch, 1, 1, self.in_c, 1)
        # Demodulation: rescale each output filter to unit norm.
        _weight_norm = torch.sqrt(torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
        _weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
        # Fused modulation: fold the batch into the channel axis and run a single
        # grouped convolution, one group per sample.
        x = x.view(1, batch * self.in_c, height, width)
        weight = _weight.permute(1, 2, 3, 0, 4).reshape(
            self.ksize, self.ksize, self.in_c, batch * self.out_c)
        weight = weight.permute(3, 2, 0, 1)
        x = F.conv2d(x,
                     weight=weight,
                     bias=None,
                     stride=self.stride,
                     padding=self.padding,
                     groups=batch)
        x = x.view(batch, self.out_c, height, width)
        x = x + self.bias.view(1, -1, 1, 1)
        x = self.activate(x) * self.activate_scale
        return x
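

# Hedged usage sketch (illustrative, not from the original repo): a minimal smoke
# test of ModulationConvBlock, assuming the modulation code carries one entry per
# input channel. Shapes and sizes below are arbitrary examples.
def _demo_modulation_conv_block():
    block = ModulationConvBlock(input_dim=8, output_dim=16, kernel_size=3)
    x = torch.randn(2, 8, 32, 32)      # (batch, in_c, H, W)
    code = torch.randn(2, 8)           # per-sample, per-input-channel modulation
    y = block(x, code)
    assert y.shape == (2, 16, 32, 32)  # stride fixed to 1 with 'same' padding
    return y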
class AliasConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(AliasConvBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)
        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)
        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)
        # initialize convolution (spectral norm wraps the conv when norm == 'sn',
        # mirroring linearBlock below; the original branches were identical)
        if norm == 'sn':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x
class AliasResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
        super(AliasResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [AliasResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)


class AliasResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(AliasResBlock, self).__init__()
        model = []
        model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out
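

# Hedged usage sketch (illustrative, not from the original repo): stacking two
# AliasResBlock layers via AliasResBlocks. Each block uses 3x3 convolutions with
# matching padding, so channels and spatial size are preserved and the residual
# additions are shape-compatible.
def _demo_alias_res_blocks():
    blocks = AliasResBlocks(num_blocks=2, dim=64, norm='in',
                            activation='relu', pad_type='reflect')
    x = torch.randn(1, 64, 16, 16)
    y = blocks(x)
    assert y.shape == x.shape
    return y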
##################################################################################
# Sequential Models
##################################################################################
class ResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)
class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        self.model = []
        self.model += [linearBlock(input_dim, input_dim, norm=norm, activation=activ)]
        self.model += [linearBlock(input_dim, dim, norm=norm, activation=activ)]
        for i in range(n_blk - 2):
            self.model += [linearBlock(dim, dim, norm=norm, activation=activ)]
        self.model += [linearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
        self.model = nn.Sequential(*self.model)

    # def forward(self, style0, style1, a=0):
    #     return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
    #         style1.view(style1.size(0), -1)))

    def forward(self, style0, style1=None, a=0):
        # Interpolation is disabled: style1 is overwritten with style0, so the
        # output depends only on style0. The model[0:3] / model[3] split assumes
        # n_blk == 3 (four linear blocks in total).
        style1 = style0
        return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
            style1.view(style1.size(0), -1)))
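

# Hedged usage sketch (illustrative, not from the original repo): mapping a
# flattened style tensor through the MLP. Dimensions are arbitrary examples;
# n_blk is set to 3 because forward() splits the model at index 3.
def _demo_mlp_style_mapping():
    mlp = MLP(input_dim=256, output_dim=512, dim=256, n_blk=3,
              norm='none', activ='relu')
    style = torch.randn(4, 256)
    w = mlp(style)                      # equivalent to mlp(style, style, a=0)
    assert w.shape == (4, 512)
    return w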
##################################################################################
# Basic Blocks
##################################################################################
class ResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()
        model = []
        model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out
class ConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(ConvBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)
        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)
        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)
        # initialize convolution (spectral norm wraps the conv when norm == 'sn',
        # mirroring linearBlock below; the original branches were identical)
        if norm == 'sn':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x
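

# Hedged usage sketch (illustrative, not from the original repo): a ConvBlock with
# reflection padding and instance norm. The explicit pad layer means the conv
# itself runs with no padding, so a 7x7 kernel with padding=3 keeps spatial size.
def _demo_conv_block():
    conv = ConvBlock(3, 64, kernel_size=7, stride=1, padding=3,
                     norm='in', activation='relu', pad_type='reflect')
    x = torch.randn(1, 3, 64, 64)
    y = conv(x)
    assert y.shape == (1, 64, 64, 64)
    return y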
class linearBlock(nn.Module):
    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(linearBlock, self).__init__()
        use_bias = True
        # initialize fully connected layer
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)
        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
##################################################################################
# Normalization layers
##################################################################################
class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)
        # Apply instance norm
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'
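

# Hedged usage sketch (illustrative, not from the original repo): the affine
# parameters of AdaptiveInstanceNorm2d are expected to be assigned externally
# (e.g. decoded from a style code) before forward() is called; the assert above
# enforces this. Because F.batch_norm runs on a (1, b*c, H, W) view, the assigned
# weight and bias must each have b * c entries.
def _demo_adain_assignment():
    adain = AdaptiveInstanceNorm2d(num_features=64)
    x = torch.randn(2, 64, 8, 8)
    style = torch.randn(2, 2 * 64)                 # per-sample (scale, shift) pairs
    adain.weight = 1 + style[:, :64].reshape(-1)   # scales, flattened to b*c
    adain.bias = style[:, 64:].reshape(-1)         # shifts, flattened to b*c
    y = adain(x)
    assert y.shape == x.shape
    return y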
class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x
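

# Hedged usage sketch (illustrative, not from the original repo): this LayerNorm
# normalizes each sample over all remaining dimensions (channels and spatial
# positions together), then applies a per-channel affine transform when affine=True.
def _demo_layer_norm():
    ln = LayerNorm(num_features=32)
    x = torch.randn(4, 32, 8, 8)
    y = ln(x)
    assert y.shape == x.shape
    return y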
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks"
    by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida and the
    PyTorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]
        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
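

# Hedged usage sketch (illustrative, not from the original repo): wrapping a layer
# in SpectralNorm re-estimates the top singular value with one power iteration per
# forward pass and divides the weight by it before delegating to the wrapped module.
def _demo_spectral_norm():
    sn_conv = SpectralNorm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
    x = torch.randn(1, 3, 32, 32)
    y = sn_conv(x)
    assert y.shape == (1, 64, 16, 16)
    return y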