Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import warnings | |
| import torch | |
| from torch import nn | |
| from imaginaire.layers import Conv2dBlock, Res2dBlock | |
class Discriminator(nn.Module):
    r"""Discriminator in the improved FUNIT baseline in the COCO-FUNIT paper.

    Thin wrapper that scores translated images, real style images and
    (optionally) reconstructed images with one shared ``ResDiscriminator``.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config file.
            All of its attributes are forwarded as keyword arguments to
            ``ResDiscriminator``.
        data_cfg (obj): Data definition part of the yaml config file. Not
            used by this class.
    """

    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        res_kwargs = vars(dis_cfg)
        self.model = ResDiscriminator(**res_kwargs)

    def forward(self, data, net_G_output, recon=True):
        r"""Improved FUNIT discriminator forward function.

        Args:
            data (dict): Training data at the current iteration. Reads
                ``'labels_content'``, ``'labels_style'`` and
                ``'images_style'``.
            net_G_output (dict): Fake data generated at the current
                iteration. Reads ``'images_trans'`` and, when ``recon`` is
                true, ``'images_recon'``.
            recon (bool): If ``True``, also classifies reconstructed images.

        Returns:
            dict: ``fake_out_trans`` / ``fake_features_trans`` and
            ``real_out_style`` / ``real_features_style``, plus
            ``fake_out_recon`` / ``fake_features_recon`` when ``recon`` is
            true. Each pair is the (output, features) result of one pass
            through the shared model.
        """
        # Both label sets are looked up up-front, independent of ``recon``.
        content_labels = data['labels_content']
        style_labels = data['labels_style']
        scores = {}

        def _score(images, labels, out_key, feat_key):
            # One pass through the shared discriminator; keep both outputs.
            logits, feats = self.model(images, labels)
            scores[out_key] = logits
            scores[feat_key] = feats

        # Translated and real style images are both scored with style labels.
        _score(net_G_output['images_trans'], style_labels,
               'fake_out_trans', 'fake_features_trans')
        _score(data['images_style'], style_labels,
               'real_out_style', 'real_features_style')
        if recon:
            # Reconstructions are scored with the content labels.
            _score(net_G_output['images_recon'], content_labels,
                   'fake_out_recon', 'fake_features_recon')
        return scores
class ResDiscriminator(nn.Module):
    r"""Residual discriminator architecture used in the FUNIT paper.

    Layout:
    - A 7x7 stem ``Conv2dBlock``.
    - ``num_layers`` stages, each made of two ``Res2dBlock`` modules. Every
      stage except the last is followed by ``ReflectionPad2d(1)`` and a 3x3,
      stride-2 ``AvgPool2d``.
    - The channel count doubles per stage, capped at ``max_num_filters``.
    - A 1x1 ``Conv2dBlock`` produces a score map.
    - A class embedding is projected onto the spatially pooled features and
      added to that map (projection discriminator).

    Args:
        image_channels (int): Number of channels of the input images.
        num_classes (int): Number of rows of the class-label embedding.
        num_filters (int): Output channels of the stem convolution.
        max_num_filters (int): Upper bound on the channels of any stage.
        num_layers (int): Number of residual stages.
        padding_mode (str): Padding mode passed to the conv/res blocks.
        weight_norm_type (str): Weight normalization passed to the blocks.
        **kwargs: Ignored. A warning is emitted for every key except
            ``'type'``.
    """

    def __init__(self,
                 image_channels=3,
                 num_classes=119,
                 num_filters=64,
                 max_num_filters=1024,
                 num_layers=6,
                 padding_mode='reflect',
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        for key in kwargs:
            # 'type' is tolerated silently; presumably it is the config's
            # class selector rather than a real argument.
            if key != 'type':
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))
        # Settings shared by every residual block.
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type='none',
                           weight_norm_type=weight_norm_type,
                           bias=[True, True, True],
                           nonlinearity='leakyrelu',
                           order='NACNAC')

        # Stem: padding (k - 1) // 2 keeps H, W unchanged for a stride-1
        # conv (assuming Conv2dBlock's positional args are kernel, stride,
        # padding).
        first_kernel_size = 7
        first_padding = (first_kernel_size - 1) // 2
        model = [Conv2dBlock(image_channels, num_filters,
                             first_kernel_size, 1, first_padding,
                             padding_mode=padding_mode,
                             weight_norm_type=weight_norm_type)]
        for i in range(num_layers):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            model += [Res2dBlock(num_filters_prev, num_filters_prev,
                                 **conv_params),
                      Res2dBlock(num_filters_prev, num_filters,
                                 **conv_params)]
            # Downsample between stages, but not after the final one.
            if i != num_layers - 1:
                model += [nn.ReflectionPad2d(1),
                          nn.AvgPool2d(3, stride=2)]
        self.model = nn.Sequential(*model)
        # After the loop, ``num_filters`` is the channel count of the last
        # stage, which both the score head and the embedding must match.
        self.classifier = Conv2dBlock(num_filters, 1, 1, 1, 0,
                                      nonlinearity='leakyrelu',
                                      weight_norm_type=weight_norm_type,
                                      order='NACNAC')
        self.embedder = nn.Embedding(num_classes, num_filters)

    def forward(self, images, labels=None):
        r"""Forward function of the projection discriminator.

        Args:
            images (image tensor): Images inputted to the discriminator
                (presumably ``(N, C, H, W)``).
            labels (long int tensor): Class labels of the images, one per
                sample. If ``None``, only the pooled features are returned.

        Returns:
            If ``labels`` is ``None``: ``features_1x1``, the backbone
            features averaged over the last two axes.
            Otherwise: a tuple ``(outputs, features_1x1)``. ``outputs`` is
            the classifier map plus, per sample, the dot product of the label
            embedding with ``features_1x1``, broadcast over the map.

        Raises:
            AssertionError: If ``labels`` is given and its batch size differs
                from that of ``images``.
        """
        # Validate only when labels are provided; the unconditional path must
        # not touch ``labels``.
        if labels is not None:
            assert images.size(0) == labels.size(0)
        features = self.model(images)
        # Global average pool over the two trailing (spatial) axes.
        features_1x1 = features.mean(3).mean(2)
        if labels is None:
            # The score map would be discarded, so skip the classifier.
            return features_1x1
        outputs = self.classifier(features)
        embeddings = self.embedder(labels)
        projection = torch.sum(embeddings * features_1x1, dim=1,
                               keepdim=True).view(images.size(0), 1, 1, 1)
        # Out-of-place add avoids mutating the classifier output in place.
        outputs = outputs + projection
        return outputs, features_1x1