Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| from functools import partial | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import Upsample as NearestUpsample | |
| from imaginaire.layers import Conv2dBlock, Res2dBlock | |
| from imaginaire.utils.data import (get_paired_input_image_channel_number, | |
| get_paired_input_label_channel_number) | |
| from imaginaire.utils.distributed import master_only_print as print | |
class Generator(nn.Module):
    r"""Pix2pixHD coarse-to-fine generator constructor.

    The input label map is repeatedly downsampled into a pyramid. The global
    generator runs on the coarsest level and each local enhancer then refines
    the result using the next finer level of the pyramid.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        # pix2pixHD has a global generator.
        global_gen_cfg = gen_cfg.global_generator
        num_filters_global = getattr(global_gen_cfg, 'num_filters', 64)
        # Optionally, it can have several local enhancers. They are useful
        # for generating high resolution images.
        local_gen_cfg = gen_cfg.local_enhancer
        self.num_local_enhancers = num_local_enhancers = \
            getattr(local_gen_cfg, 'num_enhancers', 1)
        # By default, pix2pixHD using instance normalization.
        activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
                                       'instance')
        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
                                         None)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  activation_norm_params=activation_norm_params,
                                  nonlinearity='relu')
        base_res_block = partial(Res2dBlock,
                                 padding_mode=padding_mode,
                                 weight_norm_type=weight_norm_type,
                                 activation_norm_type=activation_norm_type,
                                 activation_norm_params=activation_norm_params,
                                 nonlinearity='relu', order='CNACN')
        # Know what is the number of available segmentation labels.
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        self.concat_features = False
        # Check whether label input contains specific type of data (e.g.
        # instance_maps). Only the last entry of input_labels is inspected.
        self.contain_instance_map = \
            data_cfg.input_labels[-1] == 'instance_maps'
        # The feature encoder is only useful when the instance map is provided.
        if hasattr(gen_cfg, 'enc') and self.contain_instance_map:
            num_feat_channels = getattr(gen_cfg.enc, 'num_feat_channels', 0)
            if num_feat_channels > 0:
                # Encoded features are concatenated to the label channels.
                num_input_channels += num_feat_channels
                self.concat_features = True
                self.encoder = Encoder(gen_cfg.enc, data_cfg)
        # Global generator model.
        global_model = GlobalGenerator(global_gen_cfg, data_cfg,
                                       num_input_channels, padding_mode,
                                       base_conv_block, base_res_block)
        if num_local_enhancers == 0:
            self.global_model = global_model
        else:
            # Get rid of the last layer: the enhancers consume the output of
            # the layer before it. Iterating the Sequential keeps layer order,
            # so the remaining layers keep their indices 0..N-2.
            self.global_model = nn.Sequential(
                *list(global_model.model)[:-1])
        # Local enhancer model.
        for n in range(num_local_enhancers):
            num_filters = num_filters_global // (2 ** (n + 1))
            # Only the last enhancer produces the final image.
            output_img = (n == num_local_enhancers - 1)
            setattr(self, 'enhancer_%d' % n,
                    LocalEnhancer(local_gen_cfg, data_cfg,
                                  num_input_channels, num_filters,
                                  padding_mode, base_conv_block,
                                  base_res_block, output_img))
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
                                       count_include_pad=False)

    def forward(self, data, random_style=False):
        r"""Coarse-to-fine generator forward.

        Args:
            data (dict) : Dictionary of input data. Reads ``data['label']``
                and, when features are concatenated, ``data['images']`` and
                ``data['instance_maps']``.
            random_style (bool): Always set to false for the pix2pixHD model.
        Returns:
            output (dict) : Dictionary of output data. Holds ``fake_images``
                and, when features are concatenated, ``feature_maps``.
        """
        label = data['label']
        output = dict()
        if self.concat_features:
            features = self.encoder(data['images'], data['instance_maps'])
            label = torch.cat([label, features], dim=1)
            output['feature_maps'] = features

        # Create input pyramid: index 0 is the full-resolution input, the
        # last entry is the coarsest one.
        input_downsampled = [label]
        for _ in range(self.num_local_enhancers):
            input_downsampled.append(self.downsample(input_downsampled[-1]))

        # Output at coarsest level.
        x = self.global_model(input_downsampled[-1])

        # Coarse-to-fine: build up one layer at a time.
        for n in range(self.num_local_enhancers):
            input_n = input_downsampled[self.num_local_enhancers - n - 1]
            enhancer = getattr(self, 'enhancer_%d' % n)
            x = enhancer(x, input_n)

        output['fake_images'] = x
        return output

    def load_pretrained_network(self, pretrained_dict):
        r"""Load a pretrained network.

        Every parameter of this model is looked up in ``pretrained_dict``
        under a ``module.``-prefixed key and copied only when the sizes
        match. Parameters without a match keep their current values and
        their two-level name prefix is printed.

        Args:
            pretrained_dict (dict): State dictionary of the pretrained
                network, with keys prefixed by ``module.``.
        """
        model_dict = self.state_dict()
        print('Pretrained network has fewer layers; The following are '
              'not initialized:')
        not_initialized = set()
        for k, v in model_dict.items():
            candidates = (
                # Global generator stored as a module with a ``.model``
                # attribute, loaded into the stripped nn.Sequential here.
                'module.' + k.replace('global_model.', 'global_model.model.'),
                # Same layout on both sides. Without this fallback a model
                # built with zero enhancers looks for
                # ``global_model.model.model.*`` and never finds anything.
                'module.' + k,
            )
            for kp in candidates:
                if kp in pretrained_dict and \
                        v.size() == pretrained_dict[kp].size():
                    model_dict[k] = pretrained_dict[kp]
                    break
            else:
                not_initialized.add('.'.join(k.split('.')[:2]))
        print(sorted(not_initialized))
        self.load_state_dict(model_dict)

    def inference(self, data, **kwargs):
        r"""Generator inference.

        Args:
            data (dict) : Dictionary of input data.
        Returns:
            fake_images (tensor): Output fake images.
            file_names (str): Data file name, read from
                ``data['key']['seg_maps'][0]``.
        """
        output = self.forward(data, **kwargs)
        return output['fake_images'], data['key']['seg_maps'][0]
class LocalEnhancer(nn.Module):
    r"""Local enhancer constructor. These are sub-networks that are useful
    when aiming to produce high-resolution outputs.

    The fine-resolution input is encoded by ``model_downsample`` (one
    stride-2 convolution), summed with the coarse output of the previous
    stage and decoded by ``model_upsample`` (residual blocks followed by a
    2x upsampling).

    Args:
        gen_cfg (obj): local generator definition part of the yaml config
            file.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of segmentation labels.
        num_filters (int): Number of filters for the first layer.
        padding_mode (str): zero | reflect | ...
        base_conv_block (obj): Conv block with preset attributes.
        base_res_block (obj): Residual block with preset attributes.
        output_img (bool): Output is image or feature map.
    """

    def __init__(self, gen_cfg, data_cfg, num_input_channels, num_filters,
                 padding_mode, base_conv_block, base_res_block,
                 output_img=False):
        super().__init__()
        res_block_count = getattr(gen_cfg, 'num_res_blocks', 3)
        image_channels = get_paired_input_image_channel_number(data_cfg)
        inner_filters = num_filters * 2

        # Encoder half: a 7x7 stem followed by one stride-2 convolution.
        encoder_layers = [
            base_conv_block(num_input_channels, num_filters, 7, padding=3),
            base_conv_block(num_filters, inner_filters, 3, stride=2,
                            padding=1),
        ]

        # Decoder half: residual blocks at the reduced resolution ...
        decoder_layers = [
            base_res_block(inner_filters, inner_filters, 3, padding=1)
            for _ in range(res_block_count)
        ]
        # ... then back up to the resolution of the fine input.
        decoder_layers.append(NearestUpsample(scale_factor=2))
        decoder_layers.append(
            base_conv_block(inner_filters, num_filters, 3, padding=1))
        # The tanh image head is appended only when an image is requested.
        if output_img:
            decoder_layers.append(
                Conv2dBlock(num_filters, image_channels, 7,
                            padding=3, padding_mode=padding_mode,
                            nonlinearity='tanh'))

        self.model_downsample = nn.Sequential(*encoder_layers)
        self.model_upsample = nn.Sequential(*decoder_layers)

    def forward(self, output_coarse, input_fine):
        r"""Local enhancer forward.

        Args:
            output_coarse (4D tensor) : Coarse output from previous layer.
            input_fine (4D tensor) : Fine input from current layer.
        Returns:
            output (4D tensor) : Refined output.
        """
        # Bring the fine input to the coarse resolution before fusing.
        fine_features = self.model_downsample(input_fine)
        fused = fine_features + output_coarse
        return self.model_upsample(fused)
class GlobalGenerator(nn.Module):
    r"""Coarse generator constructor. This is the main generator in the
    pix2pixHD architecture.

    The network is a single ``nn.Sequential``: a 7x7 stem, stride-2
    downsampling convolutions, residual blocks, upsample-plus-convolution
    stages and a final 7x7 tanh convolution.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of segmentation labels.
        padding_mode (str): zero | reflect | ...
        base_conv_block (obj): Conv block with preset attributes.
        base_res_block (obj): Residual block with preset attributes.
    """

    def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
                 base_conv_block, base_res_block):
        super().__init__()
        image_channels = get_paired_input_image_channel_number(data_cfg)
        num_filters = getattr(gen_cfg, 'num_filters', 64)
        num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)

        # Channel width at every resolution level, finest first.
        widths = [num_filters * (2 ** level)
                  for level in range(num_downsamples + 1)]

        # First layer.
        layers = [base_conv_block(num_input_channels, num_filters,
                                  kernel_size=7, padding=3)]
        # Downsample: every step doubles the channel count.
        for level in range(num_downsamples):
            narrow = widths[level]
            layers.append(
                base_conv_block(narrow, narrow * 2, 3, padding=1, stride=2))
        # ResNet blocks at the coarsest resolution.
        bottleneck = widths[num_downsamples]
        for _ in range(num_res_blocks):
            layers.append(base_res_block(bottleneck, bottleneck, 3,
                                         padding=1))
        # Upsample: mirror of the downsampling path.
        for level in range(num_downsamples - 1, -1, -1):
            narrow = widths[level]
            layers.append(NearestUpsample(scale_factor=2))
            layers.append(base_conv_block(narrow * 2, narrow, 3, padding=1))
        # Image head.
        layers.append(Conv2dBlock(num_filters, image_channels, 7, padding=3,
                                  padding_mode=padding_mode,
                                  nonlinearity='tanh'))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        r"""Coarse-to-fine generator forward.

        Args:
            input (4D tensor) : Input semantic representations.
        Returns:
            output (4D tensor) : Synthesized image by generator.
        """
        synthesized = self.model(input)
        return synthesized
class Encoder(nn.Module):
    r"""Encoder for getting region-wise features for style control.

    An encoder-decoder network maps the image to a feature map, which is
    then average-pooled inside every instance region.

    Args:
        enc_cfg (obj): Encoder definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file
    """

    def __init__(self, enc_cfg, data_cfg):
        super(Encoder, self).__init__()
        label_nc = get_paired_input_label_channel_number(data_cfg)
        # Resolve the feature width once so the buffers and the output layer
        # always agree, including when the config omits the attribute.
        self.num_feat_channels = getattr(enc_cfg, 'num_feat_channels', 3)
        n_clusters = getattr(enc_cfg, 'num_clusters', 10)
        # One zero-initialised (n_clusters, num_feat_channels) buffer per
        # label channel.
        # NOTE(review): presumably filled with clustered features by code
        # outside this class; nothing here writes to them.
        for i in range(label_nc):
            self.register_buffer(
                'cluster_%d' % i,
                torch.zeros(n_clusters, self.num_feat_channels,
                            dtype=torch.float32))
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        num_filters = getattr(enc_cfg, 'num_filters', 64)
        num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
        weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
        activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
                                       'instance')
        padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='relu')
        model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2 ** i)
            model += [base_conv_block(ch, ch * 2, 3, stride=2, padding=1)]
        # Upsample.
        for i in reversed(range(num_downsamples)):
            ch = num_filters * (2 ** i)
            model += [NearestUpsample(scale_factor=2),
                      base_conv_block(ch * 2, ch, 3, padding=1)]
        model += [Conv2dBlock(num_filters, self.num_feat_channels, 7,
                              padding=3, padding_mode=padding_mode,
                              nonlinearity='tanh')]
        self.model = nn.Sequential(*model)

    def forward(self, input, instance_map):
        r"""Extracting region-wise features

        Args:
            input (4D tensor): Real RGB images.
            instance_map (4D tensor): Instance label mask.
        Returns:
            outputs_mean (4D tensor): Instance-wise average-pooled
                feature maps.
        """
        outputs = self.model(input)
        # Instance-wise average pooling.
        outputs_mean = torch.zeros_like(outputs)
        # Integer view of the labels, used only to enumerate the ids.
        inst_ids = instance_map.cpu().numpy().astype(int)
        for b in range(input.size(0)):
            # Basic slicing returns views, so writes to ``pooled`` land in
            # ``outputs_mean``.
            feats = outputs[b, :self.num_feat_channels]
            pooled = outputs_mean[b, :self.num_feat_channels]
            # Only visit the labels that occur in this sample; a label from
            # another sample would select nothing and average to NaN.
            for i in np.unique(inst_ids[b]):
                # assumes instance_map has a single channel - TODO confirm
                mask = instance_map[b, 0] == int(i)
                if not mask.any():
                    continue
                # (channels, pixels) slice of this instance; one reduction
                # covers every feature channel.
                region = feats[:, mask]
                pooled[:, mask] = \
                    region.mean(dim=1, keepdim=True).expand_as(region)
        return outputs_mean