# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from functools import partial
from types import SimpleNamespace

import torch
from torch import nn

from imaginaire.layers import (Conv2dBlock, LinearBlock, Res2dBlock,
                               UpRes2dBlock)


class Generator(nn.Module):
    r"""Generator of the improved FUNIT baseline in the COCO-FUNIT paper."""

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.generator = FUNITTranslator(**vars(gen_cfg))

    def forward(self, data):
        r"""In FUNIT's forward pass, we compute a content embedding and a
        style code from the content image, and a style code from the style
        image. Combining the content code with the style code of the content
        image reconstructs the input image; combining it with the style code
        of the style image produces the translation output.

        Args:
            data (dict): Training data at the current iteration.
        """
        content_a = self.generator.content_encoder(data['images_content'])
        style_a = self.generator.style_encoder(data['images_content'])
        style_b = self.generator.style_encoder(data['images_style'])
        images_trans = self.generator.decode(content_a, style_b)
        images_recon = self.generator.decode(content_a, style_a)
        net_G_output = dict(images_trans=images_trans,
                            images_recon=images_recon)
        return net_G_output
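
    # Training-time usage sketch (hypothetical: ``gen_cfg`` stands for the
    # generator section of an imaginaire config whose fields match the
    # FUNITTranslator arguments; shapes are illustrative):
    #
    #   net_G = Generator(gen_cfg, data_cfg)
    #   data = {'images_content': torch.randn(1, 3, 128, 128),
    #           'images_style': torch.randn(1, 3, 128, 128)}
    #   out = net_G(data)  # dict with 'images_trans' and 'images_recon'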

    def inference(self, data, keep_original_size=True):
        r"""COCO-FUNIT inference.

        Args:
            data (dict): Input data at the current iteration.
                - images_content (tensor): Content images.
                - images_style (tensor): Style images.
            keep_original_size (bool): If ``True``, the output image is
                resized to the original content image size.
        """
        content_a = self.generator.content_encoder(data['images_content'])
        style_b = self.generator.style_encoder(data['images_style'])
        output_images = self.generator.decode(content_a, style_b)
        if keep_original_size:
            height = data['original_h_w'][0][0]
            width = data['original_h_w'][0][1]
            # F.interpolate defaults to 'nearest' interpolation.
            output_images = torch.nn.functional.interpolate(
                output_images, size=[height, width])
        file_names = data['key']['images_content'][0]
        return output_images, file_names
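
# Inference-time input sketch (the field layout is inferred from the indexing
# above and is an assumption; file names and sizes are illustrative):
#
#   data = {'images_content': content_tensor,
#           'images_style': style_tensor,
#           'original_h_w': [[height, width]],
#           'key': {'images_content': ['0001.jpg']}}
#   images, names = net_G.inference(data)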


class FUNITTranslator(nn.Module):
    r"""The FUNIT translator network.

    Args:
        num_filters (int): Base number of filters.
        num_filters_mlp (int): Base number of filters in the MLP module.
        style_dims (int): Dimension of the style code.
        num_res_blocks (int): Number of residual blocks at the end of the
            content encoder.
        num_mlp_blocks (int): Number of layers in the MLP module.
        num_downsamples_content (int): Number of times the content image is
            downsampled by a factor of 2.
        num_downsamples_style (int): Number of times the style image is
            downsampled by a factor of 2.
        num_image_channels (int): Number of input image channels.
        weight_norm_type (str): Type of weight normalization:
            ``'none'``, ``'spectral'``, or ``'weight'``.
    """

    def __init__(self,
                 num_filters=64,
                 num_filters_mlp=256,
                 style_dims=64,
                 num_res_blocks=2,
                 num_mlp_blocks=3,
                 num_downsamples_style=4,
                 num_downsamples_content=2,
                 num_image_channels=3,
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        self.style_encoder = StyleEncoder(num_downsamples_style,
                                          num_image_channels,
                                          num_filters,
                                          style_dims,
                                          'reflect',
                                          'none',
                                          weight_norm_type,
                                          'relu')
        self.content_encoder = ContentEncoder(num_downsamples_content,
                                              num_res_blocks,
                                              num_image_channels,
                                              num_filters,
                                              'reflect',
                                              'instance',
                                              weight_norm_type,
                                              'relu')
        self.decoder = Decoder(self.content_encoder.output_dim,
                               num_filters_mlp,
                               num_image_channels,
                               num_downsamples_content,
                               'reflect',
                               weight_norm_type,
                               'relu')
        self.mlp = MLP(style_dims,
                       num_filters_mlp,
                       num_filters_mlp,
                       num_mlp_blocks,
                       'none',
                       'relu')

    def forward(self, images):
        r"""Reconstruct the input image by combining the computed content
        and style codes.

        Args:
            images (tensor): Input image tensor.
        """
        # Reconstruct an image.
        content, style = self.encode(images)
        images_recon = self.decode(content, style)
        return images_recon

    def encode(self, images):
        r"""Encode images to get their content and style codes.

        Args:
            images (tensor): Input image tensor.
        """
        style = self.style_encoder(images)
        content = self.content_encoder(images)
        return content, style

    def decode(self, content, style):
        r"""Generate images by combining their content and style codes.

        Args:
            content (tensor): Content code tensor.
            style (tensor): Style code tensor.
        """
        style = self.mlp(style)
        images = self.decoder(content, style)
        return images
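
# Round-trip sketch (a minimal illustration, assuming imaginaire is
# installed; the 128x128 resolution is illustrative):
#
#   net = FUNITTranslator()
#   x = torch.randn(1, 3, 128, 128)
#   content, style = net.encode(x)
#   recon = net.decode(content, style)  # equivalent to net(x)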


class Decoder(nn.Module):
    r"""Improved FUNIT decoder.

    Args:
        num_enc_output_channels (int): Number of content feature channels.
        style_channels (int): Dimension of the style code.
        num_image_channels (int): Number of image channels.
        num_upsamples (int): Number of times the resolution is upsampled by
            an upsampling residual block.
        padding_type (str): Padding mode.
        weight_norm_type (str): Type of weight normalization.
        nonlinearity (str): Nonlinearity.
    """

    def __init__(self,
                 num_enc_output_channels,
                 style_channels,
                 num_image_channels=3,
                 num_upsamples=4,
                 padding_type='reflect',
                 weight_norm_type='none',
                 nonlinearity='relu'):
        super().__init__()
        adain_params = SimpleNamespace(
            activation_norm_type='instance',
            activation_norm_params=SimpleNamespace(affine=False),
            cond_dims=style_channels)
        base_res_block = partial(Res2dBlock,
                                 kernel_size=3,
                                 padding=1,
                                 padding_mode=padding_type,
                                 nonlinearity=nonlinearity,
                                 activation_norm_type='adaptive',
                                 activation_norm_params=adain_params,
                                 weight_norm_type=weight_norm_type,
                                 learn_shortcut=False)
        base_up_res_block = partial(UpRes2dBlock,
                                    kernel_size=5,
                                    padding=2,
                                    padding_mode=padding_type,
                                    weight_norm_type=weight_norm_type,
                                    activation_norm_type='adaptive',
                                    activation_norm_params=adain_params,
                                    skip_activation_norm='instance',
                                    skip_nonlinearity=nonlinearity,
                                    nonlinearity=nonlinearity,
                                    hidden_channels_equal_out_channels=True,
                                    learn_shortcut=True)
        dims = num_enc_output_channels
        # Residual blocks with AdaIN.
        self.decoder = nn.ModuleList()
        self.decoder += [base_res_block(dims, dims)]
        self.decoder += [base_res_block(dims, dims)]
        # Upsampling residual blocks, halving the channel count each time.
        for _ in range(num_upsamples):
            self.decoder += [base_up_res_block(dims, dims // 2)]
            dims = dims // 2
        # Final convolution producing the output image in [-1, 1].
        self.decoder += [Conv2dBlock(dims,
                                     num_image_channels,
                                     kernel_size=7,
                                     stride=1,
                                     padding=3,
                                     padding_mode='reflect',
                                     nonlinearity='tanh')]

    def forward(self, x, style):
        r"""
        Args:
            x (tensor): Content embedding of the content image.
            style (tensor): Style embedding of the style image.
        """
        for block in self.decoder:
            # Blocks built with adaptive activation norm are conditional and
            # consume the style vector; the plain output conv does not.
            if getattr(block, 'conditional', False):
                x = block(x, style)
            else:
                x = block(x)
        return x
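
# Decoder usage sketch (a hypothetical, standalone illustration; the sizes
# follow the FUNITTranslator defaults above, where the content encoder
# outputs 256 channels at 1/4 resolution and the MLP emits 256 dimensions):
#
#   decoder = Decoder(num_enc_output_channels=256, style_channels=256,
#                     num_upsamples=2)
#   content = torch.randn(1, 256, 32, 32)
#   adain_style = torch.randn(1, 256)       # output of the MLP below
#   image = decoder(content, adain_style)   # (1, 3, 128, 128)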


class StyleEncoder(nn.Module):
    r"""Improved FUNIT style encoder. This is basically the same as the
    original FUNIT style encoder.

    Args:
        num_downsamples (int): Number of times the resolution is
            downsampled by a factor of 2.
        image_channels (int): Number of input image channels.
        num_filters (int): Base number of filters.
        style_channels (int): Style code dimension.
        padding_mode (str): Padding mode.
        activation_norm_type (str): Type of activation normalization.
        weight_norm_type (str): Type of weight normalization:
            ``'none'``, ``'spectral'``, or ``'weight'``.
        nonlinearity (str): Nonlinearity.
    """

    def __init__(self,
                 num_downsamples,
                 image_channels,
                 num_filters,
                 style_channels,
                 padding_mode,
                 activation_norm_type,
                 weight_norm_type,
                 nonlinearity):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True)
        model = []
        model += [Conv2dBlock(image_channels, num_filters, 7, 1, 3,
                              **conv_params)]
        # Two downsampling convolutions that double the channel count.
        for _ in range(2):
            model += [Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1,
                                  **conv_params)]
            num_filters *= 2
        # Remaining downsampling convolutions keep the channel count fixed.
        for _ in range(num_downsamples - 2):
            model += [Conv2dBlock(num_filters, num_filters, 4, 2, 1,
                                  **conv_params)]
        # Global average pooling and a 1x1 projection to the style code.
        model += [nn.AdaptiveAvgPool2d(1)]
        model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)]
        self.model = nn.Sequential(*model)
        self.output_dim = num_filters

    def forward(self, x):
        r"""
        Args:
            x (tensor): Input image.
        """
        return self.model(x)
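
# Shape sketch for the StyleEncoder with the FUNITTranslator defaults
# (num_downsamples=4, num_filters=64, style_channels=64); the 128x128 input
# size is illustrative:
#
#   (N, 3, 128, 128) --7x7 conv-->            (N, 64, 128, 128)
#   --two stride-2 convs, channels doubled--> (N, 256, 32, 32)
#   --two stride-2 convs, channels fixed-->   (N, 256, 8, 8)
#   --avg pool + 1x1 conv-->                  (N, 64, 1, 1)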


class ContentEncoder(nn.Module):
    r"""Improved FUNIT content encoder. This is basically the same as the
    original FUNIT content encoder.

    Args:
        num_downsamples (int): Number of times the resolution is
            downsampled by a factor of 2.
        num_res_blocks (int): Number of residual blocks appended after the
            downsampling modules.
        image_channels (int): Number of input image channels.
        num_filters (int): Base number of filters.
        padding_mode (str): Padding mode.
        activation_norm_type (str): Type of activation normalization.
        weight_norm_type (str): Type of weight normalization:
            ``'none'``, ``'spectral'``, or ``'weight'``.
        nonlinearity (str): Nonlinearity.
    """

    def __init__(self,
                 num_downsamples,
                 num_res_blocks,
                 image_channels,
                 num_filters,
                 padding_mode,
                 activation_norm_type,
                 weight_norm_type,
                 nonlinearity):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True,
                           order='CNACNA')
        model = []
        model += [Conv2dBlock(image_channels, num_filters, 7, 1, 3,
                              **conv_params)]
        dims = num_filters
        # Downsampling convolutions that double the channel count.
        for _ in range(num_downsamples):
            model += [Conv2dBlock(dims, dims * 2, 4, 2, 1, **conv_params)]
            dims *= 2
        # Residual blocks at the final resolution.
        for _ in range(num_res_blocks):
            model += [Res2dBlock(dims, dims, learn_shortcut=False,
                                 **conv_params)]
        self.model = nn.Sequential(*model)
        self.output_dim = dims

    def forward(self, x):
        r"""
        Args:
            x (tensor): Input image.
        """
        return self.model(x)
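
# Shape sketch for the ContentEncoder with the FUNITTranslator defaults
# (num_downsamples=2, num_res_blocks=2, num_filters=64); the 128x128 input
# size is illustrative:
#
#   (N, 3, 128, 128) --7x7 conv-->            (N, 64, 128, 128)
#   --two stride-2 convs, channels doubled--> (N, 256, 32, 32)
#   --two residual blocks-->                  (N, 256, 32, 32)  # output_dim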


class MLP(nn.Module):
    r"""Improved FUNIT style decoder: an MLP that maps the style code to the
    AdaIN parameters of the decoder.

    Args:
        input_dim (int): Input dimension (style code dimension).
        output_dim (int): Output dimension (to be fed into the AdaIN
            layer).
        latent_dim (int): Latent dimension.
        num_layers (int): Number of layers in the MLP.
        activation_norm_type (str): Type of activation normalization.
        nonlinearity (str): Nonlinearity type.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 latent_dim,
                 num_layers,
                 activation_norm_type,
                 nonlinearity):
        super().__init__()
        model = []
        model += [LinearBlock(input_dim, latent_dim,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity)]
        # Changed from num_layers - 2 to num_layers - 3.
        for _ in range(num_layers - 3):
            model += [LinearBlock(latent_dim, latent_dim,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity=nonlinearity)]
        model += [LinearBlock(latent_dim, output_dim,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        r"""
        Args:
            x (tensor): Input tensor.
        """
        # Flatten (N, C, 1, 1) style codes to (N, C) before the linear layers.
        return self.model(x.view(x.size(0), -1))
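

if __name__ == '__main__':
    # Minimal smoke test (a sketch, assuming imaginaire is installed; the
    # 128x128 resolution is illustrative). Reconstructs a random image and
    # translates it with the style of a second random image.
    net = FUNITTranslator()
    content_image = torch.randn(1, 3, 128, 128)
    style_image = torch.randn(1, 3, 128, 128)
    content, style_a = net.encode(content_image)
    _, style_b = net.encode(style_image)
    recon = net.decode(content, style_a)  # reconstruction
    trans = net.decode(content, style_b)  # translation
    print(recon.shape, trans.shape)       # both torch.Size([1, 3, 128, 128])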