import torch
import numpy as np

from torch_utils.ops import bias_act
from torch_utils import misc

def normalize_2nd_moment(x, dim=1, eps=1e-8):
    return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()

class FullyConnectedLayer_normal(torch.nn.Module):
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias        = True,         # Apply additive bias before the activation function?
        bias_init   = 0,            # Initial value for the additive bias.
    ):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, out_features, bias=bias)
        if bias:
            with torch.no_grad():
                self.fc.bias.fill_(bias_init)

    def forward(self, x):
        output = self.fc(x)
        return output

class MappingNetwork_normal(torch.nn.Module):
    def __init__(self,
        in_features,                        # Number of input features.
        int_dim,                            # Number of features in the hidden and output layers.
        num_layers            = 8,          # Number of mapping layers.
        mapping_normalization = False,      # Normalize the input to unit 2nd moment?
    ):
        super().__init__()
        layers = [torch.nn.Linear(in_features, int_dim), torch.nn.LeakyReLU(0.2)]
        for i in range(1, num_layers):
            layers.append(torch.nn.Linear(int_dim, int_dim))
            layers.append(torch.nn.LeakyReLU(0.2))
        self.net = torch.nn.Sequential(*layers)
        self.normalization = mapping_normalization

    def forward(self, x):
        if self.normalization:
            x = normalize_2nd_moment(x)
        output = self.net(x)
        return output

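# Usage sketch (illustrative only; the sizes below are assumptions, not values taken from
# this repository). MappingNetwork_normal is a plain MLP: one input Linear followed by
# num_layers-1 hidden Linear layers, each with LeakyReLU(0.2), so the output width is int_dim.
#
#   mapping = MappingNetwork_normal(in_features=512, int_dim=512, num_layers=8,
#                                   mapping_normalization=True)
#   w = mapping(torch.randn(4, 512))   # -> shape [4, 512]
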
class DecodingNetwork(torch.nn.Module):
    def __init__(self,
        in_features,                # Number of input features.
        out_dim,                    # Number of output features.
        num_layers  = 8,            # Number of mapping layers.
    ):
        super().__init__()
        layers = []
        for i in range(num_layers - 1):
            layers.append(torch.nn.Linear(in_features, in_features))
            layers.append(torch.nn.ReLU())
        layers.append(torch.nn.Linear(in_features, out_dim))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        x = torch.nn.functional.normalize(x, dim=1)
        output = self.net(x)
        return output

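# Usage sketch (illustrative; the sizes are assumptions): DecodingNetwork L2-normalizes its
# input along dim=1 before an MLP of num_layers Linear layers with ReLU between them,
# projecting in_features -> out_dim in the final layer.
#
#   decoder = DecodingNetwork(in_features=512, out_dim=10, num_layers=8)
#   logits = decoder(torch.randn(4, 512))   # -> shape [4, 10]
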
class FullyConnectedLayer(torch.nn.Module):
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain
        if self.activation == 'linear' and b is not None:
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
            x = bias_act.bias_act(x, b, act=self.activation)
        return x

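# Note (illustrative, with assumed sizes): FullyConnectedLayer stores its weight at unit scale
# and rescales it at runtime by weight_gain = lr_multiplier / sqrt(in_features), the
# equalized-learning-rate scheme used by the StyleGAN mapping layers; bias_act fuses the bias
# addition with the chosen activation.
#
#   fc = FullyConnectedLayer(in_features=512, out_features=512, activation='lrelu',
#                            lr_multiplier=0.01)
#   y = fc(torch.randn(4, 512))        # -> shape [4, 512]
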
class MappingNetwork(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
        normalization   = None,     # Normalize the input using normalize_2nd_moment?
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        self.normalization = normalization

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                if self.normalization:
                    x = normalize_2nd_moment(z.to(torch.float32))
                else:
                    x = z.to(torch.float32)
            if self.c_dim > 0:
                # Class conditioning is not supported in this variant; the embedding path
                # below is unreachable and kept only for reference.
                raise ValueError("This implementation does not need class index")
                misc.assert_shape(c, [None, self.c_dim])
                y = self.embed(c.to(torch.float32))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
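
# Minimal smoke test (a sketch; the dimensions below are assumptions, not values used by this
# repository). Running it requires the StyleGAN torch_utils package for bias_act and misc.
if __name__ == '__main__':
    mapping = MappingNetwork(z_dim=512, c_dim=0, w_dim=512, num_ws=14, normalization=True)
    z = torch.randn(2, 512)
    ws = mapping(z, truncation_psi=0.7)    # broadcast + truncation -> shape [2, 14, 512]
    print(ws.shape)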