Spaces:
Build error
Build error
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ResBlock(nn.Module):
    """Two-stage convolutional residual block with optional 2x resampling.

    The main path applies ``[InstanceNorm2d] -> LeakyReLU(0.2) -> Conv3x3``
    twice, with the optional resampling layer placed between the two stages.
    The side path is a 1x1 convolution followed by the same kind of
    resampling layer. ``forward`` returns the element-wise sum of both paths.

    Args:
        in_channel: Channel count of the input feature map.
        out_channel: Channel count produced by both paths.
        down_sample: If true, append ``AvgPool2d(kernel_size=2)`` to both
            paths. Takes precedence over ``up_sample`` when both are set.
        up_sample: If true (and ``down_sample`` is false), append a bilinear
            ``Upsample(scale_factor=2)`` to both paths.
        norm: If true, put an ``InstanceNorm2d`` in front of each activation.
    """

    def __init__(self, in_channel, out_channel, down_sample=False, up_sample=False, norm=True):
        super().__init__()

        main_layers = []
        if norm:
            main_layers.append(nn.InstanceNorm2d(in_channel))
        main_layers += [
            # This activation may only run in place when a norm layer precedes
            # it: the norm layer hands it a fresh tensor. Without the norm
            # layer it would operate directly on the caller's input, silently
            # overwriting it (so the side path would see the activated tensor)
            # and raising for leaf tensors that require grad.
            nn.LeakyReLU(0.2, inplace=bool(norm)),
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
        ]
        main_layers += self._resample_layers(down_sample, up_sample)
        if norm:
            main_layers.append(nn.InstanceNorm2d(out_channel))
        main_layers += [
            # Safe in place: its input is always produced inside this block.
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
        ]
        self.main_path = nn.Sequential(*main_layers)

        side_layers = [nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)]
        side_layers += self._resample_layers(down_sample, up_sample)
        self.side_path = nn.Sequential(*side_layers)

    @staticmethod
    def _resample_layers(down_sample, up_sample):
        """Return the (possibly empty) list holding the resampling layer."""
        if down_sample:
            return [nn.AvgPool2d(kernel_size=2)]
        if up_sample:
            return [nn.Upsample(scale_factor=2, mode="bilinear")]
        return []

    def forward(self, x):
        """Return ``main_path(x) + side_path(x)`` without modifying ``x``."""
        x1 = self.main_path(x)
        x2 = self.side_path(x)
        return x1 + x2
class AdaIn(nn.Module):
    """Instance-normalize a feature map, then scale and shift it per channel.

    The per-channel scale and shift are each predicted from a style vector by
    a separate linear layer.

    Args:
        in_channel: Number of channels of the feature map being modulated
            (output width of both linear layers).
        vector_size: Width of the style vector (input width of both linear
            layers).
    """

    def __init__(self, in_channel, vector_size):
        super().__init__()
        # Numerical-stability term handed to the instance normalization.
        self.eps = 1e-5
        self.std_style_fc = nn.Linear(vector_size, in_channel)
        self.mean_style_fc = nn.Linear(vector_size, in_channel)

    def forward(self, x, style_vector):
        """Return ``scale * instance_norm(x) + shift``.

        Args:
            x: Feature map; presumably ``(N, in_channel, H, W)`` - confirm
                against the caller.
            style_vector: Tensor whose last dimension is ``vector_size``.

        Returns:
            A new tensor; ``x`` itself is not modified.
        """
        # Two trailing singleton axes let the per-channel values broadcast
        # over the trailing (spatial) dimensions of ``x``.
        scale = self.std_style_fc(style_vector)[..., None, None]
        shift = self.mean_style_fc(style_vector)[..., None, None]
        # Pass the stored eps explicitly so the attribute is actually honoured
        # instead of relying on the function's own default.
        normalized = F.instance_norm(x, eps=self.eps)
        return scale * normalized + shift
class AdaInResBlock(nn.Module):
    """Residual block whose two main-path stages are modulated by ``AdaIn``.

    Main path: ``AdaIn -> LeakyReLU(0.2) -> Conv3x3 [-> Upsample x2] ->
    AdaIn -> LeakyReLU(0.2) -> Conv3x3``. Side path: ``Conv1x1 [-> Upsample
    x2]``. ``forward`` returns the element-wise sum of both paths.

    Args:
        in_channel: Channel count of the input feature map.
        out_channel: Channel count produced by both paths.
        up_sample: If true, both paths include a bilinear
            ``Upsample(scale_factor=2)``.
        vector_size: Width of the conditioning vector fed to both ``AdaIn``
            layers. Defaults to ``257 + 512``, the value that used to be
            hard-coded (presumably two concatenated embeddings of those
            widths - confirm against the caller).
    """

    def __init__(self, in_channel, out_channel, up_sample=False, vector_size=257 + 512):
        super().__init__()
        self.vector_size = vector_size
        self.up_sample = up_sample

        self.adain1 = AdaIn(in_channel, self.vector_size)
        self.adain2 = AdaIn(out_channel, self.vector_size)

        # The in-place activations below are fed by the AdaIn outputs, never
        # by the block input directly.
        first_stage = [
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
        ]
        if up_sample:
            first_stage.append(nn.Upsample(scale_factor=2, mode="bilinear"))
        self.main_path1 = nn.Sequential(*first_stage)

        self.main_path2 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
        )

        shortcut = [nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)]
        if up_sample:
            shortcut.append(nn.Upsample(scale_factor=2, mode="bilinear"))
        self.side_path = nn.Sequential(*shortcut)

    def forward(self, x, id_vector):
        """Run both paths and return their sum.

        Args:
            x: Input feature map with ``in_channel`` channels.
            id_vector: Conditioning vector passed to both ``AdaIn`` layers;
                its last dimension must equal ``vector_size``.
        """
        x1 = self.adain1(x, id_vector)
        x1 = self.main_path1(x1)
        x2 = self.side_path(x)
        x1 = self.adain2(x1, id_vector)
        x1 = self.main_path2(x1)
        return x1 + x2
class UpSamplingBlock(nn.Module):
    """Upsample 256-channel features and emit a 3-channel and a 1-channel map.

    The trunk is two ``ResBlock(256, 256, up_sample=True)`` in sequence. Two
    heads then read the trunk output: ``i_r_net`` (LeakyReLU followed by a
    3x3 convolution to 3 channels) and ``m_r_net`` (a 3x3 convolution to
    1 channel followed by a sigmoid).
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(*[ResBlock(256, 256, up_sample=True) for _ in range(2)])
        self.i_r_net = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1),
        )
        self.m_r_net = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the pair ``(i_r, m_r)`` computed from the trunk output."""
        features = self.net(x)
        # NOTE: the order of the next two calls matters. i_r_net starts with
        # an in-place LeakyReLU, so it overwrites ``features``; m_r_net
        # therefore receives the activated tensor. Swapping the calls would
        # change m_r_net's output.
        i_r_out = self.i_r_net(features)
        m_r_out = self.m_r_net(features)
        return i_r_out, m_r_out