import numpy as np
import torch
import torch.nn as nn
import torchvision


def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) for conv weights,
    # N(1, 0.02) for BatchNorm2d scales with zero bias.
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class MultiscaleDiscriminator(nn.Module):
    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        num_D=3,
        getIntermFeat=False,
        finetune=False,
    ):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        # One PatchGAN discriminator per scale; layer0 ends up operating on
        # the most-downsampled input (see forward below).
        for i in range(num_D):
            netD = NLayerDiscriminator(
                input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
            )
            if getIntermFeat:
                # Register each stage separately so intermediate features
                # can be retrieved by name.
                for j in range(n_layers + 2):
                    setattr(
                        self,
                        "scale" + str(i) + "_layer" + str(j),
                        getattr(netD, "model" + str(j)),
                    )
            else:
                setattr(self, "layer" + str(i), netD.model)

        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )

        # weights_init only initializes the single module it receives, so it
        # must be applied recursively to every submodule; calling
        # weights_init(self) directly would be a no-op.
        self.apply(weights_init)

        if finetune:
            # Freeze all weights, then re-enable gradients only for the
            # parameters whose names contain "layer0".
            self.requires_grad_(False)
            for name, param in self.named_parameters():
                if "layer0" in name:
                    param.requires_grad = True

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            # model is a list of stages; record every stage's output.
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [
                    getattr(self, "scale" + str(num_D - 1 - i) + "_layer" + str(j))
                    for j in range(self.n_layers + 2)
                ]
            else:
                model = getattr(self, "layer" + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            # Halve the resolution before running the next discriminator.
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        getIntermFeat=False,
    ):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [
            [
                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True),
            ]
        ]
        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)  # double the channels, capped at 512
            sequence += [
                [
                    nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                    norm_layer(nf),
                    nn.LeakyReLU(0.2, True),
                ]
            ]

        # One more block at stride 1, then a 1-channel per-patch prediction map.
        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [
            [
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ]
        ]
        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            # Keep each block addressable so intermediate features can be returned.
            for n in range(len(sequence)):
                setattr(self, "model" + str(n), nn.Sequential(*sequence[n]))
        else:
            # Flatten all blocks into a single sequential model.
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers + 2):
                model = getattr(self, "model" + str(n))
                res.append(model(res[-1]))
            return res[1:]  # drop the raw input, keep each stage's output
        else:
            return self.model(input)
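

# Minimal usage sketch (the 3-channel input and 256x256 resolution below are
# assumptions for illustration, not fixed by the classes above). With
# getIntermFeat=False, forward returns one single-element list per scale,
# each holding that scale's patch prediction map; the map's spatial size
# roughly halves from one scale to the next.
if __name__ == "__main__":
    netD = MultiscaleDiscriminator(input_nc=3, ndf=64, n_layers=3, num_D=3)
    x = torch.randn(1, 3, 256, 256)
    for scale, out in enumerate(netD(x)):
        print(scale, out[-1].shape)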