# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import warnings
from types import SimpleNamespace

import torch
from torch import nn
from torch.nn import Upsample as NearestUpsample

from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
from imaginaire.generators.unit import ContentEncoder


class Generator(nn.Module):
    r"""Improved MUNIT generator.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.autoencoder_a = AutoEncoder(**vars(gen_cfg))
        self.autoencoder_b = AutoEncoder(**vars(gen_cfg))

    def forward(self, data, random_style=True, image_recon=True,
                latent_recon=True, cycle_recon=True,
                within_latent_recon=False):
        r"""In MUNIT's forward pass, it generates a content code and a style
        code from images in both domains. It then performs a within-domain
        reconstruction step and a cross-domain translation step.
        In within-domain reconstruction, it reconstructs an image using the
        content and style from the same image and optionally encodes the image
        back to the latent space.
        In cross-domain translation, it generates a translated image by mixing
        the content and style from images in different domains, and optionally
        encodes the image back to the latent space.

        Args:
            data (dict): Training data at the current iteration.
              - images_a (tensor): Images from domain A.
              - images_b (tensor): Images from domain B.
            random_style (bool): If ``True``, samples the style code from the
                prior distribution, otherwise uses the style code encoded from
                the input images in the other domain.
            image_recon (bool): If ``True``, also returns reconstructed images.
            latent_recon (bool): If ``True``, also returns reconstructed latent
                code during cross-domain translation.
            cycle_recon (bool): If ``True``, also returns cycle
                reconstructed images.
            within_latent_recon (bool): If ``True``, also returns reconstructed
                latent code during within-domain reconstruction.
        """
        images_a = data['images_a']
        images_b = data['images_b']
        net_G_output = dict()

        # Encode input images into content and style code.
        content_a, style_a = self.autoencoder_a.encode(images_a)
        content_b, style_b = self.autoencoder_b.encode(images_b)

        # Decode (within domain).
        if image_recon:
            images_aa = self.autoencoder_a.decode(content_a, style_a)
            images_bb = self.autoencoder_b.decode(content_b, style_b)
            net_G_output.update(dict(images_aa=images_aa, images_bb=images_bb))

        # Decode (cross domain).
        if random_style:  # Use randomly sampled style code.
            style_a_rand = torch.randn_like(style_a)
            style_b_rand = torch.randn_like(style_b)
        else:  # Use style code encoded from the other domain.
            style_a_rand = style_a
            style_b_rand = style_b
        images_ba = self.autoencoder_a.decode(content_b, style_a_rand)
        images_ab = self.autoencoder_b.decode(content_a, style_b_rand)

        # Encode translated images into content and style code.
        if latent_recon or cycle_recon:
            content_ba, style_ba = self.autoencoder_a.encode(images_ba)
            content_ab, style_ab = self.autoencoder_b.encode(images_ab)
            net_G_output.update(dict(content_ba=content_ba, style_ba=style_ba,
                                     content_ab=content_ab, style_ab=style_ab))

        # Encode reconstructed images into content and style code.
        if image_recon and within_latent_recon:
            content_aa, style_aa = self.autoencoder_a.encode(images_aa)
            content_bb, style_bb = self.autoencoder_b.encode(images_bb)
            net_G_output.update(dict(content_aa=content_aa, style_aa=style_aa,
                                     content_bb=content_bb, style_bb=style_bb))

        # Cycle reconstruction.
        if cycle_recon:
            images_aba = self.autoencoder_a.decode(content_ab, style_a)
            images_bab = self.autoencoder_b.decode(content_ba, style_b)
            net_G_output.update(
                dict(images_aba=images_aba, images_bab=images_bab))

        # Required outputs.
        net_G_output.update(dict(content_a=content_a, content_b=content_b,
                                 style_a=style_a, style_b=style_b,
                                 style_a_rand=style_a_rand,
                                 style_b_rand=style_b_rand,
                                 images_ba=images_ba, images_ab=images_ab))
        return net_G_output
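
    # Usage sketch for the forward pass (hedged: assumes the default
    # architecture settings and 256x256 inputs; `gen_cfg` and `data_cfg`
    # are hypothetical config namespaces):
    #
    #   net_G = Generator(gen_cfg, data_cfg)
    #   data = dict(images_a=torch.randn(1, 3, 256, 256),
    #               images_b=torch.randn(1, 3, 256, 256))
    #   out = net_G(data)
    #   out['images_ab']   # A -> B translation, (1, 3, 256, 256)
    #   out['content_a']   # content code, (1, 256, 64, 64) after two
    #                      # 2x2 downsamplings
    #   out['style_a']     # style code, (1, latent_dim, 1, 1)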

    def inference(self, data, a2b=True, random_style=True):
        r"""MUNIT inference.

        Args:
            data (dict): Input data at the current iteration.
              - images_a (tensor): Images from domain A.
              - images_b (tensor): Images from domain B.
            a2b (bool): If ``True``, translates images from domain A to B,
                otherwise from B to A.
            random_style (bool): If ``True``, samples the style code from the
                prior distribution, otherwise uses the style code encoded from
                the input images in the other domain.
        """
        if a2b:
            input_key = 'images_a'
            content_encode = self.autoencoder_a.content_encoder
            style_encode = self.autoencoder_b.style_encoder
            decode = self.autoencoder_b.decode
        else:
            input_key = 'images_b'
            content_encode = self.autoencoder_b.content_encoder
            style_encode = self.autoencoder_a.style_encoder
            decode = self.autoencoder_a.decode

        content_images = data[input_key]
        content = content_encode(content_images)
        if random_style:
            style_channels = self.autoencoder_a.style_channels
            # Sample the style code on the same device as the content code
            # rather than assuming CUDA is available.
            style = torch.randn(content.size(0), style_channels, 1, 1,
                                device=content.device)
            file_names = data['key'][input_key]['filename']
        else:
            style_key = 'images_b' if a2b else 'images_a'
            assert style_key in data.keys(), \
                "{} must be provided when 'random_style' " \
                "is set to False".format(style_key)
            style_images = data[style_key]
            style = style_encode(style_images)
            file_names = \
                [content_name + '_style_' + style_name
                 for content_name, style_name in
                 zip(data['key'][input_key]['filename'],
                     data['key'][style_key]['filename'])]

        output_images = decode(content, style)
        return output_images, file_names
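
    # Inference sketch (hedged: the nested `data['key']` bookkeeping follows
    # the imaginaire dataloader convention visible above; the filenames are
    # hypothetical):
    #
    #   data = dict(images_a=images,
    #               key=dict(images_a=dict(filename=['0001'])))
    #   with torch.no_grad():
    #       outputs, names = net_G.inference(data, a2b=True,
    #                                        random_style=True)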


class AutoEncoder(nn.Module):
    r"""Improved MUNIT autoencoder.

    Args:
        num_filters (int): Base filter numbers.
        max_num_filters (int): Maximum number of filters in the encoder.
        num_filters_mlp (int): Base filter number in the MLP module.
        latent_dim (int): Dimension of the style code.
        num_res_blocks (int): Number of residual blocks at the end of the
            content encoder.
        num_mlp_blocks (int): Number of layers in the MLP module.
        num_downsamples_style (int): Number of times we reduce
            resolution by 2x2 for the style image.
        num_downsamples_content (int): Number of times we reduce
            resolution by 2x2 for the content image.
        num_image_channels (int): Number of input image channels.
        content_norm_type (str): Type of activation normalization in the
            content encoder.
        style_norm_type (str): Type of activation normalization in the
            style encoder.
        decoder_norm_type (str): Type of activation normalization in the
            decoder.
        weight_norm_type (str): Type of weight normalization.
        decoder_norm_params (obj): Parameters of activation normalization in
            the decoder. If not ``None``, ``decoder_norm_params.__dict__``
            will be used as keyword arguments when initializing activation
            normalization.
        output_nonlinearity (str): Type of nonlinearity before final output,
            ``'tanh'`` or ``'none'``.
        pre_act (bool): If ``True``, uses pre-activation residual blocks.
        apply_noise (bool): If ``True``, injects Gaussian noise in the decoder.
    """

    def __init__(self,
                 num_filters=64,
                 max_num_filters=256,
                 num_filters_mlp=256,
                 latent_dim=8,
                 num_res_blocks=4,
                 num_mlp_blocks=2,
                 num_downsamples_style=4,
                 num_downsamples_content=2,
                 num_image_channels=3,
                 content_norm_type='instance',
                 style_norm_type='',
                 decoder_norm_type='instance',
                 weight_norm_type='',
                 decoder_norm_params=SimpleNamespace(affine=False),
                 output_nonlinearity='',
                 pre_act=False,
                 apply_noise=False,
                 **kwargs):
        super().__init__()
        for key in kwargs:
            if key != 'type':
                warnings.warn(
                    "Generator argument '{}' is not used.".format(key))
        self.style_encoder = StyleEncoder(num_downsamples_style,
                                          num_image_channels,
                                          num_filters,
                                          latent_dim,
                                          'reflect',
                                          style_norm_type,
                                          weight_norm_type,
                                          'relu')
        self.content_encoder = ContentEncoder(num_downsamples_content,
                                              num_res_blocks,
                                              num_image_channels,
                                              num_filters,
                                              max_num_filters,
                                              'reflect',
                                              content_norm_type,
                                              weight_norm_type,
                                              'relu',
                                              pre_act)
        self.decoder = Decoder(num_downsamples_content,
                               num_res_blocks,
                               self.content_encoder.output_dim,
                               num_image_channels,
                               num_filters_mlp,
                               'reflect',
                               decoder_norm_type,
                               decoder_norm_params,
                               weight_norm_type,
                               'relu',
                               output_nonlinearity,
                               pre_act,
                               apply_noise)
        self.mlp = MLP(latent_dim,
                       num_filters_mlp,
                       num_filters_mlp,
                       num_mlp_blocks,
                       'none',
                       'relu')
        self.style_channels = latent_dim

    def forward(self, images):
        r"""Reconstruct an image.

        Args:
            images (Tensor): Input images.
        Returns:
            images_recon (Tensor): Reconstructed images.
        """
        content, style = self.encode(images)
        images_recon = self.decode(content, style)
        return images_recon

    def encode(self, images):
        r"""Encode an image to content and style code.

        Args:
            images (Tensor): Input images.
        Returns:
            (tuple):
              - content (Tensor): Content code.
              - style (Tensor): Style code.
        """
        style = self.style_encoder(images)
        content = self.content_encoder(images)
        return content, style

    def decode(self, content, style):
        r"""Decode content and style code to an image.

        Args:
            content (Tensor): Content code.
            style (Tensor): Style code.
        Returns:
            images (Tensor): Output images.
        """
        style = self.mlp(style)
        images = self.decoder(content, style)
        return images
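
    # Shape walkthrough (hedged: assumes the default settings above and a
    # 256x256 RGB input):
    #
    #   ae = AutoEncoder()
    #   content, style = ae.encode(x)      # content: (B, 256, 64, 64)
    #                                      # style:   (B, 8, 1, 1)
    #   recon = ae.decode(content, style)  # (B, 3, 256, 256); the MLP maps
    #                                      # the 8-dim style code to the
    #                                      # 256-dim AdaIN conditioning vector.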


class StyleEncoder(nn.Module):
    r"""MUNIT style encoder.

    Args:
        num_downsamples (int): Number of times we reduce
            resolution by 2x2.
        num_image_channels (int): Number of input image channels.
        num_filters (int): Base filter numbers.
        style_channels (int): Dimension of the style code.
        padding_mode (string): Type of padding.
        activation_norm_type (str): Type of activation normalization.
        weight_norm_type (str): Type of weight normalization.
        nonlinearity (str): Type of nonlinear activation function.
    """

    def __init__(self, num_downsamples, num_image_channels, num_filters,
                 style_channels, padding_mode, activation_norm_type,
                 weight_norm_type, nonlinearity):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True)
        model = []
        # 7x7 stem convolution at full resolution.
        model += [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3,
                              **conv_params)]
        # Two stride-2 convolutions that double the channel count.
        for i in range(2):
            model += [Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1,
                                  **conv_params)]
            num_filters *= 2
        # Remaining stride-2 convolutions at constant width.
        for i in range(num_downsamples - 2):
            model += [Conv2dBlock(num_filters, num_filters, 4, 2, 1,
                                  **conv_params)]
        # Global average pooling and 1x1 projection to the style code.
        model += [nn.AdaptiveAvgPool2d(1)]
        model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)]
        self.model = nn.Sequential(*model)
        self.output_dim = num_filters

    def forward(self, x):
        r"""
        Args:
            x (tensor): Input image.
        """
        return self.model(x)
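
    # The adaptive average pooling makes the style code
    # (B, style_channels, 1, 1) regardless of input resolution.
    # A quick check (hedged sketch):
    #
    #   enc = StyleEncoder(4, 3, 64, 8, 'reflect', '', '', 'relu')
    #   enc(torch.randn(1, 3, 256, 256)).shape  # torch.Size([1, 8, 1, 1])
    #   enc(torch.randn(1, 3, 512, 512)).shape  # torch.Size([1, 8, 1, 1])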


class Decoder(nn.Module):
    r"""Improved MUNIT decoder. The network consists of

    - $(num_res_blocks) residual blocks.
    - $(num_upsamples) upsampling layers, each followed by a convolutional
      block.
    - output layer.

    Args:
        num_upsamples (int): Number of times we increase resolution by 2x2.
        num_res_blocks (int): Number of residual blocks.
        num_filters (int): Base filter numbers.
        num_image_channels (int): Number of input image channels.
        style_channels (int): Dimension of the style code.
        padding_mode (string): Type of padding.
        activation_norm_type (str): Type of activation normalization.
        activation_norm_params (obj): Parameters of activation normalization.
            If not ``None``, ``activation_norm_params.__dict__`` will be used
            as keyword arguments when initializing activation normalization.
        weight_norm_type (str): Type of weight normalization.
        nonlinearity (str): Type of nonlinear activation function.
        output_nonlinearity (str): Type of nonlinearity before final output,
            ``'tanh'`` or ``'none'``.
        pre_act (bool): If ``True``, uses pre-activation residual blocks.
        apply_noise (bool): If ``True``, injects Gaussian noise.
    """

    def __init__(self,
                 num_upsamples,
                 num_res_blocks,
                 num_filters,
                 num_image_channels,
                 style_channels,
                 padding_mode,
                 activation_norm_type,
                 activation_norm_params,
                 weight_norm_type,
                 nonlinearity,
                 output_nonlinearity,
                 pre_act=False,
                 apply_noise=False):
        super().__init__()
        adain_params = SimpleNamespace(
            activation_norm_type=activation_norm_type,
            activation_norm_params=activation_norm_params,
            cond_dims=style_channels)
        conv_params = dict(padding_mode=padding_mode,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True,
                           apply_noise=apply_noise,
                           weight_norm_type=weight_norm_type,
                           activation_norm_type='adaptive',
                           activation_norm_params=adain_params)
        # The order of operations in residual blocks.
        order = 'pre_act' if pre_act else 'CNACNA'
        # Residual blocks with AdaIN.
        self.decoder = nn.ModuleList()
        for _ in range(num_res_blocks):
            self.decoder += [Res2dBlock(num_filters, num_filters,
                                        **conv_params,
                                        order=order)]
        # Convolutional blocks with upsampling.
        for i in range(num_upsamples):
            self.decoder += [NearestUpsample(scale_factor=2)]
            self.decoder += [Conv2dBlock(num_filters, num_filters // 2,
                                         5, 1, 2, **conv_params)]
            num_filters //= 2
        self.decoder += [Conv2dBlock(num_filters, num_image_channels, 7, 1, 3,
                                     nonlinearity=output_nonlinearity,
                                     padding_mode=padding_mode)]

    def forward(self, x, style):
        r"""
        Args:
            x (tensor): Content embedding of the content image.
            style (tensor): Style embedding of the style image.
        """
        for block in self.decoder:
            if getattr(block, 'conditional', False):
                x = block(x, style)
            else:
                x = block(x)
        return x
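
    # Note on the dispatch above (hedged): blocks built with adaptive
    # normalization (the Res2dBlocks and the upsampling Conv2dBlocks here)
    # are assumed to expose ``conditional=True`` and take the style vector
    # as a second argument; NearestUpsample and the final output convolution
    # are unconditional and are called with the activations only.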


class MLP(nn.Module):
    r"""The multi-layer perceptron (MLP) that maps the Gaussian style code to
    a feature vector that is given as the conditional input to AdaIN.

    Args:
        input_dim (int): Number of channels in the input tensor.
        output_dim (int): Number of channels in the output tensor.
        latent_dim (int): Number of channels in the latent features.
        num_layers (int): Number of layers in the MLP.
        norm (str): Type of activation normalization.
        nonlinearity (str): Type of nonlinear activation function.
    """

    def __init__(self, input_dim, output_dim, latent_dim, num_layers,
                 norm, nonlinearity):
        super().__init__()
        model = []
        model += [LinearBlock(input_dim, latent_dim,
                              activation_norm_type=norm,
                              nonlinearity=nonlinearity)]
        for i in range(num_layers - 2):
            model += [LinearBlock(latent_dim, latent_dim,
                                  activation_norm_type=norm,
                                  nonlinearity=nonlinearity)]
        model += [LinearBlock(latent_dim, output_dim,
                              activation_norm_type=norm,
                              nonlinearity=nonlinearity)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        r"""
        Args:
            x (tensor): Input style code, of shape (B, input_dim, 1, 1) or
                (B, input_dim); it is flattened to (B, input_dim) before the
                linear layers.
        """
        return self.model(x.view(x.size(0), -1))
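

# A hedged smoke test: builds the autoencoder with its default settings and
# round-trips a random batch. Assumes the imaginaire package (providing the
# layers imported above) is installed.
if __name__ == '__main__':
    ae = AutoEncoder()
    x = torch.randn(2, 3, 256, 256)
    content, style = ae.encode(x)
    recon = ae.decode(content, style)
    assert recon.shape == x.shape
    print('content:', tuple(content.shape),  # (2, 256, 64, 64)
          'style:', tuple(style.shape))      # (2, 8, 1, 1)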