Spaces:
Build error
Build error
| import numpy as np | |
| import torch.nn as nn | |
| from models.model_blocks import ResBlock | |
class Discriminator(nn.Module):
    """Convolutional discriminator: stem conv -> down-sampling ResBlocks -> conv head.

    The channel width starts at ``ndf`` and doubles with each of the first three
    ResBlocks. It then stays capped at ``ndf * 8`` for any further blocks. With
    the default ``ndf=64`` and ``n_layers >= 3``, this reproduces the original
    fixed 512-channel layout exactly. Other ``ndf`` / ``n_layers`` values now
    produce consistent channel counts instead of mismatched layers.

    Args:
        input_nc: number of channels of the input tensor.
        ndf: channel width produced by the first (stem) convolution.
        n_layers: number of ``ResBlock(..., down_sample=True)`` stages.
    """

    def __init__(self, input_nc, ndf=64, n_layers=6):
        super(Discriminator, self).__init__()
        # Upper bound on the channel multiplier; ndf * 8 == 512 for ndf=64.
        max_mult = 8

        # Stem: 3x3, stride-1, padding-1 conv mapping input_nc -> ndf channels.
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)]

        # Track the running channel count so every layer's input matches the
        # previous layer's output, whatever ndf / n_layers are.
        in_ch = ndf
        for i in range(n_layers):
            out_ch = ndf * min(2 ** (i + 1), max_mult)
            sequence += [ResBlock(in_ch, out_ch, down_sample=True, norm=False)]
            in_ch = out_ch

        # Head: a 4x4 unpadded conv (needs a feature map of at least 4x4; the
        # amount each ResBlock down-samples is defined in models.model_blocks --
        # confirm against the expected input size), then a 1x1 conv to 2 channels.
        sequence += [
            nn.Conv2d(in_ch, in_ch, kernel_size=4, stride=1, padding=0),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_ch, 2, kernel_size=1, stride=1, padding=0),
            # NOTE(review): an activation on the final 2-channel output is
            # unusual for a discriminator; kept to preserve existing behavior.
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.sequence = nn.Sequential(*sequence)

    def forward(self, input):
        """Run ``input`` through the full layer stack and return the 2-channel map."""
        return self.sequence(input)