import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
from .c2pGen import *
from .p2cGen import *
from .c2pDis import *

class Identity(nn.Module):
    def forward(self, x):
        return x

def get_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters and do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x): return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer

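# Usage sketch (not part of the original file; the channel counts below are illustrative
# assumptions). The returned callable is used like a layer class when building networks:
#   >>> norm_layer = get_norm_layer('instance')
#   >>> block = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU(True))
#   >>> block(torch.randn(1, 3, 32, 32)).shape
#   torch.Size([1, 64, 32, 32])
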
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of the learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For the other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler

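# Usage sketch for the 'linear' policy (not part of the original file; the hyperparameter
# values are illustrative assumptions, not values prescribed by this repository):
#   >>> from argparse import Namespace
#   >>> opt = Namespace(lr_policy='linear', epoch_count=1, n_epochs=100, n_epochs_decay=100)
#   >>> optim = torch.optim.Adam(nn.Conv2d(3, 3, 1).parameters(), lr=2e-4)
#   >>> sched = get_scheduler(optim, opt)
#   >>> sched.step()  # called once per epoch; the LR stays at 2e-4 for the first 100 epochs, then decays linearly to 0
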
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)
    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>

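# Usage sketch (not part of the original file; the layer sizes are illustrative):
#   >>> net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
#   >>> init_weights(net, init_type='kaiming')  # conv weights get Kaiming init, BatchNorm weights ~ N(1.0, 0.02)
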
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register the CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    gpu_ids = [0]  # note: the passed-in gpu_ids is overridden here; the network is kept on the CPU below
    if len(gpu_ids) > 0:
        # assert(torch.cuda.is_available())  # uncomment this when running on GPU
        net.to(torch.device("cpu"))  # for GPU use, change this to net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net

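# Usage sketch (not part of the original file). In this CPU build the gpu_ids argument
# is ignored and the network stays on the CPU, wrapped in DataParallel:
#   >>> netG = init_net(nn.Conv2d(3, 3, 1), init_type='xavier', init_gain=0.02, gpu_ids=[])
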
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: c2pGen | p2cGen | antialias
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- whether to use dropout layers.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'c2pGen':  # style_dim, mlp_dim
        net = C2PGen(input_nc, output_nc, ngf, 2, 4, 256, 256, activ='relu', pad_type='reflect')
        # print('c2pgen resblock is 8')
    elif netG == 'p2cGen':
        net = P2CGen(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
    elif netG == 'antialias':
        net = AliasNet(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)

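# Usage sketch (not part of the original file; the channel and filter counts are
# illustrative assumptions — RGB in/out and 64 base filters):
#   >>> netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='c2pGen')
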
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: CPDis | CPDis_cls
        n_layers_D (int)   -- kept for interface compatibility; not used by the CPDis variants
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    if netD == 'CPDis':
        net = CPDis(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    elif netD == 'CPDis_cls':
        net = CPDis_cls(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)

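# Usage sketch (not part of the original file). Note that CPDis/CPDis_cls are built with
# fixed internal settings (image_size=256, conv_dim=64, spectral norm), so input_nc, ndf
# and norm are accepted but not forwarded to the constructor:
#   >>> netD = define_D(input_nc=3, ndf=64, netD='CPDis')
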
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label for a fake image

        Note: Do not use sigmoid as the last layer of the discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it via BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - whether the ground truth label is for real or fake images

        Returns:
            A label tensor filled with the ground truth label, with the same size as the input
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate the loss given the discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - whether the ground truth label is for real or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss

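# Usage sketch (not part of the original file; the prediction shape is an illustrative
# assumption — raw, un-sigmoided discriminator logits, as the class docstring requires):
#   >>> criterion = GANLoss('lsgan')
#   >>> pred_fake = torch.randn(4, 1, 30, 30)
#   >>> loss_D_fake = criterion(pred_fake, False)  # push fake predictions toward 0.0
#   >>> loss_G = criterion(pred_fake, True)        # generator wants fakes scored as real (1.0)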