# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import math
import types

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Upsample as NearestUpsample

from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
from imaginaire.utils.data import (get_crop_h_w,
                                   get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print


class Generator(nn.Module):
    r"""SPADE generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super(Generator, self).__init__()
        print('SPADE generator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = getattr(gen_cfg, 'image_channels', None)
        if image_channels is None:
            image_channels = get_paired_input_image_channel_number(data_cfg)
        num_labels = getattr(gen_cfg, 'num_labels', None)
        if num_labels is None:
            # Calculate number of channels in the input label when not
            # specified.
            num_labels = get_paired_input_label_channel_number(data_cfg)
        crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
        # Build the generator
        out_image_small_side_size = crop_w if crop_w < crop_h else crop_h
        num_filters = getattr(gen_cfg, 'num_filters', 128)
        kernel_size = getattr(gen_cfg, 'kernel_size', 3)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
        cond_dims = 0
        # Check whether we use the style code.
        style_dims = getattr(gen_cfg, 'style_dims', None)
        self.style_dims = style_dims
        if style_dims is not None:
            print('\tStyle code dimensions: %d' % style_dims)
            cond_dims += style_dims
            self.use_style = True
        else:
            self.use_style = False
        # Check whether we use the attribute code.
        if hasattr(gen_cfg, 'attribute_dims'):
            self.use_attribute = True
            self.attribute_dims = gen_cfg.attribute_dims
            cond_dims += gen_cfg.attribute_dims
        else:
            self.use_attribute = False
        if not self.use_style and not self.use_attribute:
            self.use_style_encoder = False
        else:
            self.use_style_encoder = True
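        # ``use_style_encoder`` turns on the conditional (adaptive) norm
        # pathway in SPADEGenerator; the style and attribute codes are
        # concatenated into one ``cond_dims``-dimensional conditioning
        # vector.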
        print('\tBase filter number: %d' % num_filters)
        print('\tConvolution kernel size: %d' % kernel_size)
        print('\tWeight norm type: %s' % weight_norm_type)
        skip_activation_norm = \
            getattr(gen_cfg, 'skip_activation_norm', True)
        activation_norm_params = getattr(gen_cfg,
                                         'activation_norm_params', None)
        if activation_norm_params is None:
            activation_norm_params = types.SimpleNamespace()
        if not hasattr(activation_norm_params, 'num_filters'):
            setattr(activation_norm_params, 'num_filters', 128)
        if not hasattr(activation_norm_params, 'kernel_size'):
            setattr(activation_norm_params, 'kernel_size', 3)
        if not hasattr(activation_norm_params, 'activation_norm_type'):
            setattr(activation_norm_params, 'activation_norm_type',
                    'sync_batch')
        if not hasattr(activation_norm_params, 'separate_projection'):
            setattr(activation_norm_params, 'separate_projection', False)
        if not hasattr(activation_norm_params, 'activation_norm_params'):
            activation_norm_params.activation_norm_params = \
                types.SimpleNamespace()
            activation_norm_params.activation_norm_params.affine = True
        setattr(activation_norm_params, 'cond_dims', num_labels)
        if not hasattr(activation_norm_params, 'weight_norm_type'):
            setattr(activation_norm_params, 'weight_norm_type',
                    weight_norm_type)
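        # ``activation_norm_params`` configures the spatially adaptive
        # (SPADE) norm layers inside the residual blocks; the conditioning
        # input is the label map, hence cond_dims = num_labels.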
        global_adaptive_norm_type = getattr(gen_cfg,
                                            'global_adaptive_norm_type',
                                            'sync_batch')
        use_posenc_in_input_layer = getattr(gen_cfg,
                                            'use_posenc_in_input_layer',
                                            True)
        output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0)
        print(activation_norm_params)
        self.spade_generator = SPADEGenerator(num_labels,
                                              out_image_small_side_size,
                                              image_channels,
                                              num_filters,
                                              kernel_size,
                                              cond_dims,
                                              activation_norm_params,
                                              weight_norm_type,
                                              global_adaptive_norm_type,
                                              skip_activation_norm,
                                              use_posenc_in_input_layer,
                                              self.use_style_encoder,
                                              output_multiplier)
        if self.use_style:
            # Build the encoder.
            style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
            if style_enc_cfg is None:
                style_enc_cfg = types.SimpleNamespace()
            if not hasattr(style_enc_cfg, 'num_filters'):
                setattr(style_enc_cfg, 'num_filters', 128)
            if not hasattr(style_enc_cfg, 'kernel_size'):
                setattr(style_enc_cfg, 'kernel_size', 3)
            if not hasattr(style_enc_cfg, 'weight_norm_type'):
                setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type)
            setattr(style_enc_cfg, 'input_image_channels', image_channels)
            setattr(style_enc_cfg, 'style_dims', style_dims)
            self.style_encoder = StyleEncoder(style_enc_cfg)
        self.z = None
        print('Done with the SPADE generator initialization.')

    def forward(self, data, random_style=False):
        r"""SPADE Generator forward.

        Args:
            data (dict):
                - images (N x C1 x H x W tensor): Ground truth images.
                - label (N x C2 x H x W tensor): Semantic representations.
                - z (N x style_dims tensor): Gaussian random noise.
            random_style (bool): Whether to sample a random style vector.
        Returns:
            (dict):
                - fake_images (N x 3 x H x W tensor): Fake images.
                - mu (N x style_dims tensor): Mean vectors.
                - logvar (N x style_dims tensor): Log-variance vectors.
        """
        if self.use_style_encoder:
            if random_style:
                bs = data['label'].size(0)
                z = torch.randn(
                    bs, self.style_dims, dtype=torch.float32).cuda()
                if data['label'].dtype == torch.float16:
                    z = z.half()
                mu = None
                logvar = None
            else:
                mu, logvar, z = self.style_encoder(data['images'])
            if self.use_attribute:
                data['z'] = torch.cat(
                    (z, data['attributes'].squeeze(1)), dim=1)
            else:
                data['z'] = z
        output = self.spade_generator(data)
        if self.use_style_encoder:
            output['mu'] = mu
            output['logvar'] = logvar
        return output

    def inference(self,
                  data,
                  random_style=False,
                  use_fixed_random_style=False,
                  keep_original_size=False):
        r"""Compute the result images for a batch of input data.

        Args:
            data (dict):
                - images (N x C1 x H x W tensor): Ground truth images.
                - label (N x C2 x H x W tensor): Semantic representations.
                - z (N x style_dims tensor): Gaussian random noise.
            random_style (bool): Whether to sample a random style vector.
            use_fixed_random_style (bool): Sample a random style once and
                reuse it for all the remaining inference calls.
            keep_original_size (bool): Keep the original size of the input.
        Returns:
            (tuple):
                - output_images (N x 3 x H x W tensor): Fake images.
                - file_names (list): File names of the corresponding inputs.
        """
        self.eval()
        self.spade_generator.eval()
        if self.use_style_encoder:
            if random_style:
                if self.z is None or not use_fixed_random_style:
                    bs = data['label'].size(0)
                    z = torch.randn(
                        bs, self.style_dims, dtype=torch.float32).to('cuda')
                    if data['label'].dtype == torch.float16:
                        z = z.half()
                    self.z = z
                else:
                    z = self.z
            else:
                mu, logvar, z = self.style_encoder(data['images'])
            data['z'] = z
        output = self.spade_generator(data)
        output_images = output['fake_images']
        if keep_original_size:
            height = data['original_h_w'][0][0]
            width = data['original_h_w'][0][1]
            output_images = torch.nn.functional.interpolate(
                output_images, size=[height, width])
        # Pick file names from a segmentation-map key; an edge-map key, if
        # present, takes precedence. Fall back to None when neither exists.
        file_names = None
        for key in data['key'].keys():
            if 'segmaps' in key or 'seg_maps' in key:
                file_names = data['key'][key][0]
                break
        for key in data['key'].keys():
            if 'edgemaps' in key or 'edge_maps' in key:
                file_names = data['key'][key][0]
                break
        return output_images, file_names
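
# Illustrative usage sketch (added for exposition; shapes follow the
# docstrings above, and the concrete batch size, label count, and crop size
# are assumptions, not values from any shipped config):
#
#     net = Generator(gen_cfg, data_cfg).cuda()
#     data = {'images': torch.randn(2, 3, 256, 256).cuda(),  # N x C1 x H x W
#             'label': torch.randn(2, 35, 256, 256).cuda()}  # N x C2 x H x W
#     out = net(data)                       # training-style forward pass
#     fake = out['fake_images']             # N x 3 x 256 x 256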


class SPADEGenerator(nn.Module):
    r"""SPADE image generator constructor.

    Args:
        num_labels (int): Number of different labels.
        out_image_small_side_size (int): min(width, height)
        image_channels (int): Number of channels of the output image.
        num_filters (int): Base filter number.
        kernel_size (int): Convolution kernel size.
        style_dims (int): Dimensions of the style code.
        activation_norm_params (obj): Spatially adaptive normalization
            parameters.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, or ``'weight'``.
        global_adaptive_norm_type (str): Type of the global adaptive
            normalization (e.g. conditional batch norm or adaptive instance
            norm).
        skip_activation_norm (bool): If ``True``, applies activation norm to
            the shortcut connection in residual blocks.
        use_posenc_in_input_layer (bool): If ``True``, concatenates a 2D
            positional encoding to the input label map.
        use_style_encoder (bool): Whether a style code conditions the global
            adaptive norm layers.
        output_multiplier (float): A positive number multiplied to the output
            before the final tanh.
    """

    def __init__(self,
                 num_labels,
                 out_image_small_side_size,
                 image_channels,
                 num_filters,
                 kernel_size,
                 style_dims,
                 activation_norm_params,
                 weight_norm_type,
                 global_adaptive_norm_type,
                 skip_activation_norm,
                 use_posenc_in_input_layer,
                 use_style_encoder,
                 output_multiplier):
        super(SPADEGenerator, self).__init__()
        self.output_multiplier = output_multiplier
        self.use_style_encoder = use_style_encoder
        self.use_posenc_in_input_layer = use_posenc_in_input_layer
        self.out_image_small_side_size = out_image_small_side_size
        self.num_filters = num_filters
        padding = int(np.ceil((kernel_size - 1.0) / 2))
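        # 'same'-style padding: with stride 1 and an odd kernel size, the
        # spatial dimensions are preserved.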
        nonlinearity = 'leakyrelu'
        activation_norm_type = 'spatially_adaptive'
        base_res2d_block = \
            functools.partial(Res2dBlock,
                              kernel_size=kernel_size,
                              padding=padding,
                              bias=[True, True, False],
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              activation_norm_params=activation_norm_params,
                              skip_activation_norm=skip_activation_norm,
                              nonlinearity=nonlinearity,
                              order='NACNAC')
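        # order='NACNAC': each residual branch applies
        # norm -> activation -> conv twice (imaginaire's layer-order string).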
        if self.use_style_encoder:
            self.fc_0 = LinearBlock(style_dims, 2 * style_dims,
                                    weight_norm_type=weight_norm_type,
                                    nonlinearity='relu',
                                    order='CAN')
            self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims,
                                    weight_norm_type=weight_norm_type,
                                    nonlinearity='relu',
                                    order='CAN')
            adaptive_norm_params = types.SimpleNamespace()
            if not hasattr(adaptive_norm_params, 'cond_dims'):
                setattr(adaptive_norm_params, 'cond_dims', 2 * style_dims)
            if not hasattr(adaptive_norm_params, 'activation_norm_type'):
                setattr(adaptive_norm_params, 'activation_norm_type',
                        global_adaptive_norm_type)
            if not hasattr(adaptive_norm_params, 'weight_norm_type'):
                setattr(adaptive_norm_params, 'weight_norm_type',
                        activation_norm_params.weight_norm_type)
            if not hasattr(adaptive_norm_params, 'separate_projection'):
                setattr(adaptive_norm_params, 'separate_projection',
                        activation_norm_params.separate_projection)
            adaptive_norm_params.activation_norm_params = \
                types.SimpleNamespace()
            setattr(adaptive_norm_params.activation_norm_params, 'affine',
                    activation_norm_params.activation_norm_params.affine)
            base_cbn2d_block = \
                functools.partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  bias=True,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type='adaptive',
                                  activation_norm_params=adaptive_norm_params,
                                  nonlinearity=nonlinearity,
                                  order='NAC')
        else:
            base_conv2d_block = \
                functools.partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  bias=True,
                                  weight_norm_type=weight_norm_type,
                                  nonlinearity=nonlinearity,
                                  order='NAC')
        in_num_labels = num_labels
        in_num_labels += 2 if self.use_posenc_in_input_layer else 0
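        # The two extra channels carry the (x, y) positional-encoding grid
        # that forward() concatenates to the downsampled label map.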
        self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters,
                                  kernel_size=kernel_size, stride=1,
                                  padding=padding,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type='none',
                                  nonlinearity=nonlinearity)
        if self.use_style_encoder:
            self.cbn_head_0 = base_cbn2d_block(
                8 * num_filters, 16 * num_filters)
        else:
            self.conv_head_0 = base_conv2d_block(
                8 * num_filters, 16 * num_filters)
        self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters)
        self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters)
        self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_0a = base_cbn2d_block(
                8 * num_filters, 8 * num_filters)
        else:
            self.conv_up_0a = base_conv2d_block(
                8 * num_filters, 8 * num_filters)
        self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters)
        self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_1a = base_cbn2d_block(
                4 * num_filters, 4 * num_filters)
        else:
            self.conv_up_1a = base_conv2d_block(
                4 * num_filters, 4 * num_filters)
        self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters)
        self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_2a = base_cbn2d_block(
                4 * num_filters, 4 * num_filters)
        else:
            self.conv_up_2a = base_conv2d_block(
                4 * num_filters, 4 * num_filters)
        self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters)
        self.conv_img256 = Conv2dBlock(2 * num_filters, image_channels,
                                       5, stride=1, padding=2,
                                       weight_norm_type=weight_norm_type,
                                       activation_norm_type='none',
                                       nonlinearity=nonlinearity,
                                       order='ANC')
        self.base = 16
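        # ``base`` is the factor by which forward() downsamples the label map
        # to form the initial feature map; combined with the 2x upsampling
        # stages, the output then matches the input resolution.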
        if self.out_image_small_side_size == 512:
            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
            self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
                                           5, stride=1, padding=2,
                                           weight_norm_type=weight_norm_type,
                                           activation_norm_type='none',
                                           nonlinearity=nonlinearity,
                                           order='ANC')
            self.base = 32
        if self.out_image_small_side_size == 1024:
            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
            self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
                                           5, stride=1, padding=2,
                                           weight_norm_type=weight_norm_type,
                                           activation_norm_type='none',
                                           nonlinearity=nonlinearity,
                                           order='ANC')
            self.up_4a = base_res2d_block(num_filters, num_filters // 2)
            self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2)
            self.conv_img1024 = Conv2dBlock(num_filters // 2, image_channels,
                                            5, stride=1, padding=2,
                                            weight_norm_type=weight_norm_type,
                                            activation_norm_type='none',
                                            nonlinearity=nonlinearity,
                                            order='ANC')
            self.nearest_upsample4x = NearestUpsample(scale_factor=4,
                                                      mode='nearest')
            self.base = 64
        if self.out_image_small_side_size != 256 and \
                self.out_image_small_side_size != 512 and \
                self.out_image_small_side_size != 1024:
            raise ValueError('Generation image size (%d, %d) not supported' %
                             (self.out_image_small_side_size,
                              self.out_image_small_side_size))
        self.nearest_upsample2x = NearestUpsample(scale_factor=2,
                                                  mode='nearest')
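        # A fixed 16 x 16 grid of (x, y) coordinates in [-1, 1]
        # (torch.arange(-1, 1.1, 2. / 15) yields 16 evenly spaced values),
        # used as the positional encoding; it is resized in forward().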
        xv, yv = torch.meshgrid(
            [torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)])
        self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
        self.xy = self.xy.cuda()

    def forward(self, data):
        r"""SPADE Generator forward.

        Args:
            data (dict):
                - images (N x C1 x H x W tensor): Ground truth images.
                - label (N x C2 x H x W tensor): Semantic representations.
                - z (N x style_dims tensor): Gaussian random noise.
        Returns:
            output (dict):
                - fake_images (N x 3 x H x W tensor): Fake images.
        """
        seg = data['label']
        if self.use_style_encoder:
            z = data['z']
            z = self.fc_0(z)
            z = self.fc_1(z)
        # Downsample the label map by ``self.base`` so that the smaller side
        # of the initial feature map is always 16.
        sy = math.floor(seg.size()[2] * 1.0 / self.base)
        sx = math.floor(seg.size()[3] * 1.0 / self.base)
        in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest')
        if self.use_posenc_in_input_layer:
            in_xy = F.interpolate(self.xy, size=[sy, sx], mode='bicubic')
            in_seg_xy = torch.cat(
                (in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1)
        else:
            in_seg_xy = in_seg
        # 16x16
        x = self.head_0(in_seg_xy)
        if self.use_style_encoder:
            x = self.cbn_head_0(x, z)
        else:
            x = self.conv_head_0(x)
        x = self.head_1(x, seg)
        x = self.head_2(x, seg)
        x = self.nearest_upsample2x(x)
        # 32x32
        x = self.up_0a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_0a(x, z)
        else:
            x = self.conv_up_0a(x)
        x = self.up_0b(x, seg)
        x = self.nearest_upsample2x(x)
        # 64x64
        x = self.up_1a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_1a(x, z)
        else:
            x = self.conv_up_1a(x)
        x = self.up_1b(x, seg)
        x = self.nearest_upsample2x(x)
        # 128x128
        x = self.up_2a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_2a(x, z)
        else:
            x = self.conv_up_2a(x)
        x = self.up_2b(x, seg)
        x = self.nearest_upsample2x(x)
        # 256x256
        if self.out_image_small_side_size == 256:
            x256 = self.conv_img256(x)
            x = torch.tanh(self.output_multiplier * x256)
        # 512x512
        elif self.out_image_small_side_size == 512:
            x256 = self.conv_img256(x)
            x256 = self.nearest_upsample2x(x256)
            x = self.up_3a(x, seg)
            x = self.up_3b(x, seg)
            x = self.nearest_upsample2x(x)
            x512 = self.conv_img512(x)
            x = torch.tanh(self.output_multiplier * (x256 + x512))
        # 1024x1024
        elif self.out_image_small_side_size == 1024:
            x256 = self.conv_img256(x)
            x256 = self.nearest_upsample4x(x256)
            x = self.up_3a(x, seg)
            x = self.up_3b(x, seg)
            x = self.nearest_upsample2x(x)
            x512 = self.conv_img512(x)
            x512 = self.nearest_upsample2x(x512)
            x = self.up_4a(x, seg)
            x = self.up_4b(x, seg)
            x = self.nearest_upsample2x(x)
            x1024 = self.conv_img1024(x)
            x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024))
        output = dict()
        output['fake_images'] = x
        return output


class StyleEncoder(nn.Module):
    r"""Style Encoder constructor.

    Args:
        style_enc_cfg (obj): Style encoder definition part of the yaml
            config file.
    """

    def __init__(self, style_enc_cfg):
        super(StyleEncoder, self).__init__()
        input_image_channels = style_enc_cfg.input_image_channels
        num_filters = style_enc_cfg.num_filters
        kernel_size = style_enc_cfg.kernel_size
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        style_dims = style_enc_cfg.style_dims
        weight_norm_type = style_enc_cfg.weight_norm_type
        activation_norm_type = 'none'
        nonlinearity = 'leakyrelu'
        base_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              # inplace_nonlinearity=True,
                              nonlinearity=nonlinearity)
        self.layer1 = base_conv2d_block(input_image_channels, num_filters)
        self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
        self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
        self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
        self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
        self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
        self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
        self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)

    def forward(self, input_x):
        r"""SPADE Style Encoder forward.

        Args:
            input_x (N x 3 x H x W tensor): Input images.
        Returns:
            (tuple):
                - mu (N x C tensor): Mean vectors.
                - logvar (N x C tensor): Log-variance vectors.
                - z (N x C tensor): Style code vectors.
        """
        if input_x.size(2) != 256 or input_x.size(3) != 256:
            input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
        x = self.layer1(input_x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)
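        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so sampling stays differentiable w.r.t. mu and logvar.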
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = eps.mul(std) + mu
        return mu, logvar, z
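

if __name__ == '__main__':
    # Minimal smoke test for the style encoder (added for exposition; the
    # config values below are arbitrary assumptions kept small enough to run
    # on CPU, not defaults from any shipped yaml config).
    cfg = types.SimpleNamespace(input_image_channels=3,
                                num_filters=16,
                                kernel_size=3,
                                style_dims=8,
                                weight_norm_type='none')
    encoder = StyleEncoder(cfg)
    images = torch.randn(2, 3, 256, 256)
    mu, logvar, z = encoder(images)
    # Each output is N x style_dims, i.e. torch.Size([2, 8]).
    print(mu.shape, logvar.shape, z.shape)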