# Deep-Multi-scale/models/networks.py
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torchvision import models
import numpy as np
import torch.nn.utils.spectral_norm as SpectralNorm
from torch.autograd import Function
from math import sqrt
import os
from sync_batchnorm import convert_model
###############################################################################
# Helper Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
elif norm_type == 'none':
norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def get_scheduler(optimizer, opt):
if opt.lr_policy == 'lambda':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
return scheduler
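# Example (illustrative sketch, not from the original code): `opt` is the
# training-options namespace used throughout this repo; the field names below
# mirror the ones referenced in get_scheduler, and `netG` stands for any model.
#   import argparse
#   opt = argparse.Namespace(lr_policy='step', lr_decay_iters=50)
#   optimizer = torch.optim.Adam(netG.parameters(), lr=2e-4)
#   scheduler = get_scheduler(optimizer, opt)
#   scheduler.step()  # advance the schedule once per epoch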
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_flag=True):
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net = convert_model(net)
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids)
if init_flag:
init_weights(net, init_type, gain=init_gain)
return net
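# Example (hedged sketch): the typical call chain used by define_G below.
# `MyNet` is a placeholder module; convert_model/DataParallel are applied only
# when gpu_ids is non-empty and CUDA is available.
#   net = init_net(MyNet(), init_type='kaiming', init_gain=0.02, gpu_ids=[0])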
# Adaptive instance normalization (AdaIN) helpers: per-channel statistics and the transfer itself
def calc_mean_std(feat, eps=1e-5):
# eps is a small value added to the variance to avoid divide-by-zero.
size = feat.size()
assert (len(size) == 3)
C, _ = size[:2]
feat_var = feat.contiguous().view(C, -1).var(dim=1) + eps
feat_std = feat_var.sqrt().view(C, 1, 1)
feat_mean = feat.contiguous().view(C, -1).mean(dim=1).view(C, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat): # content_feat is degraded feature, style is ref feature
assert (content_feat.size()[:1] == style_feat.size()[:1])
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(
size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
def calc_mean_std_4D(feat, eps=1e-5):
# eps is a small value added to the variance to avoid divide-by-zero.
size = feat.size()
assert (len(size) == 4)
N, C = size[:2]
feat_var = feat.view(N, C, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(N, C, 1, 1)
feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degraded feature
# assert (content_feat.size()[:2] == style_feat.size()[:2])
size = content_feat.size()
style_mean, style_std = calc_mean_std_4D(style_feat)
content_mean, content_std = calc_mean_std_4D(content_feat)
normalized_feat = (content_feat - content_mean.expand(
size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
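# Example (illustrative, shapes made up): AdaIN only matches per-channel
# statistics, so content and style may differ spatially.
#   content = torch.randn(2, 512, 32, 32)  # degraded-image features
#   style = torch.randn(2, 512, 20, 20)    # reference/dictionary features
#   out = adaptive_instance_normalization_4D(content, style)
#   # out.shape == (2, 512, 32, 32): content structure, style statistics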
def define_G(which_model_netG, gpu_ids=[]):
if which_model_netG == 'UNetDictFace':
netG = UNetDictFace(64)
init_flag = False
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
return init_net(netG, 'normal', 0.02, gpu_ids, init_flag)
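# Example usage (illustrative): build the dictionary-guided face generator.
# init_flag is False for UNetDictFace, presumably because it loads pretrained
# VGG weights and fixed part dictionaries that re-initialization would clobber.
#   netG = define_G('UNetDictFace', gpu_ids=[])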
##############################################################################
# Classes
##############################################################################
def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
return nn.Sequential(
SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
# conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias),
# nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2),
SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
)
class MSDilateBlock(nn.Module):
def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
super(MSDilateBlock, self).__init__()
self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias)
self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias)
self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias)
self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias)
self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(x)
conv3 = self.conv3(x)
conv4 = self.conv4(x)
cat = torch.cat([conv1, conv2, conv3, conv4], 1)
out = self.convi(cat) + x
return out
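# Shape sketch (assumed sizes): every branch pads by ((k-1)//2)*dilation, so
# all four outputs keep the input resolution and the channel-wise concat plus
# residual add are always valid.
#   block = MSDilateBlock(512, dilation=[4, 3, 2, 1])
#   y = block(torch.randn(1, 512, 32, 32))  # y.shape == (1, 512, 32, 32)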
##############################UNetFace#########################
class AdaptiveInstanceNorm(nn.Module):
def __init__(self, in_channel):
super().__init__()
self.norm = nn.InstanceNorm2d(in_channel)
def forward(self, input, style):
style_mean, style_std = calc_mean_std_4D(style)
out = self.norm(input)
size = input.size()
out = style_std.expand(size) * out + style_mean.expand(size)
return out
class BlurFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)
grad_input = F.conv2d(
grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]
)
return grad_input
@staticmethod
def backward(ctx, gradgrad_output):
kernel, kernel_flip = ctx.saved_tensors
grad_input = F.conv2d(
gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]
)
return grad_input, None, None
class BlurFunction(Function):
@staticmethod
def forward(ctx, input, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)
output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])
return output
@staticmethod
def backward(ctx, grad_output):
kernel, kernel_flip = ctx.saved_tensors
grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
return grad_input, None, None
blur = BlurFunction.apply
class Blur(nn.Module):
def __init__(self, channel):
super().__init__()
weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
weight = weight.view(1, 1, 3, 3)
weight = weight / weight.sum()
weight_flip = torch.flip(weight, [2, 3])
self.register_buffer('weight', weight.repeat(channel, 1, 1, 1))
self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1))
def forward(self, input):
return blur(input, self.weight, self.weight_flip)
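# Sketch (illustrative): Blur applies a fixed 3x3 binomial kernel depthwise
# (groups = channels); BlurFunction supplies a custom backward that convolves
# gradients with the flipped kernel.
#   y = Blur(64)(torch.randn(1, 64, 16, 16))  # same shape, low-pass filtered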
class EqualLR:
def __init__(self, name):
self.name = name
def compute_weight(self, module):
weight = getattr(module, self.name + '_orig')
fan_in = weight.data.size(1) * weight.data[0][0].numel()
return weight * sqrt(2 / fan_in)
@staticmethod
def apply(module, name):
fn = EqualLR(name)
weight = getattr(module, name)
del module._parameters[name]
module.register_parameter(name + '_orig', nn.Parameter(weight.data))
module.register_forward_pre_hook(fn)
return fn
def __call__(self, module, input):
weight = self.compute_weight(module)
setattr(module, self.name, weight)
def equal_lr(module, name='weight'):
EqualLR.apply(module, name)
return module
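# Sketch (illustrative): the equalized learning-rate trick from StyleGAN; the
# stored parameter is renamed to `weight_orig` and rescaled by sqrt(2 / fan_in)
# on every forward pass via a forward pre-hook.
#   conv = equal_lr(nn.Conv2d(64, 64, 3, padding=1))
#   # conv.weight is recomputed from conv.weight_orig at call time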
class EqualConv2d(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
conv = nn.Conv2d(*args, **kwargs)
conv.weight.data.normal_()
conv.bias.data.zero_()
self.conv = equal_lr(conv)
def forward(self, input):
return self.conv(input)
class NoiseInjection(nn.Module):
def __init__(self, channel):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
def forward(self, image, noise):
return image + self.weight * noise
class StyledUpBlock(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False):
super().__init__()
if upsample:
self.conv1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
Blur(out_channel),
# EqualConv2d(in_channel, out_channel, kernel_size, padding=padding),
SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
nn.LeakyReLU(0.2),
)
else:
self.conv1 = nn.Sequential(
Blur(in_channel),
# EqualConv2d(in_channel, out_channel, kernel_size, padding=padding)
SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
nn.LeakyReLU(0.2),
)
self.convup = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
# EqualConv2d(out_channel, out_channel, kernel_size, padding=padding),
SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
nn.LeakyReLU(0.2),
# Blur(out_channel),
)
# self.noise1 = equal_lr(NoiseInjection(out_channel))
# self.adain1 = AdaptiveInstanceNorm(out_channel)
self.lrelu1 = nn.LeakyReLU(0.2)
# self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
# self.noise2 = equal_lr(NoiseInjection(out_channel))
# self.adain2 = AdaptiveInstanceNorm(out_channel)
# self.lrelu2 = nn.LeakyReLU(0.2)
self.ScaleModel1 = nn.Sequential(
# Blur(in_channel),
SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
# nn.Conv2d(in_channel,out_channel,3, 1, 1),
nn.LeakyReLU(0.2, True),
SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
# nn.Conv2d(out_channel, out_channel, 3, 1, 1)
)
self.ShiftModel1 = nn.Sequential(
# Blur(in_channel),
SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
# nn.Conv2d(in_channel,out_channel,3, 1, 1),
nn.LeakyReLU(0.2, True),
SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
nn.Sigmoid(),
# nn.Conv2d(out_channel, out_channel, 3, 1, 1)
)
def forward(self, input, style):
out = self.conv1(input)
# out = self.noise1(out, noise)
out = self.lrelu1(out)
Shift1 = self.ShiftModel1(style)
Scale1 = self.ScaleModel1(style)
out = out * Scale1 + Shift1
# out = self.adain1(out, style)
outup = self.convup(out)
return outup
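# Shape sketch (assumed sizes): the style branch predicts spatial scale/shift
# maps from the reference feature, so `style` must have `in_channel` channels
# and match the spatial size of conv1's output.
#   up = StyledUpBlock(512, 256)
#   out = up(torch.randn(1, 512, 32, 32), torch.randn(1, 512, 32, 32))
#   # out.shape == (1, 256, 64, 64): convup upsamples by a factor of 2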
##############################################################################
##Face Dictionary
##############################################################################
class VGGFeat(torch.nn.Module):
"""
Input: (B, C, H, W), RGB, [-1, 1]
"""
def __init__(self, weight_path='./weights/vgg19.pth'):
super().__init__()
self.model = models.vgg19(pretrained=False)
self.build_vgg_layers()
self.model.load_state_dict(torch.load(weight_path))
self.register_parameter("RGB_mean", nn.Parameter(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)))
self.register_parameter("RGB_std", nn.Parameter(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)))
# self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
def build_vgg_layers(self):
vgg_pretrained_features = self.model.features
self.features = []
# feature_layers = [0, 3, 8, 17, 26, 35]
feature_layers = [0, 8, 17, 26, 35]
for i in range(len(feature_layers)-1):
module_layers = torch.nn.Sequential()
for j in range(feature_layers[i], feature_layers[i+1]):
module_layers.add_module(str(j), vgg_pretrained_features[j])
self.features.append(module_layers)
self.features = torch.nn.ModuleList(self.features)
def preprocess(self, x):
x = (x + 1) / 2
x = (x - self.RGB_mean) / self.RGB_std
if x.shape[3] < 224:
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
return x
def forward(self, x):
x = self.preprocess(x)
features = []
for m in self.features:
# print(m)
x = m(x)
features.append(x)
return features
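# Shape sketch (for a 512x512 input in [-1, 1], and assuming the VGG weight
# file exists at weight_path): the four stages end at VGG19's conv2_2, conv3_4,
# conv4_4 and conv5_4, giving the pyramid listed near the end of
# UNetDictFace.forward:
#   feats = VGGFeat()(torch.randn(1, 3, 512, 512))
#   # [(1,128,256,256), (1,256,128,128), (1,512,64,64), (1,512,32,32)]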
def compute_sum(x, axis=None, keepdim=False):
    # Reduce over the given axes one at a time (highest axis first, so earlier
    # indices stay valid). `axis is None` means sum over all dimensions.
    if axis is None:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
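# Example:
#   compute_sum(torch.ones(2, 3, 4), axis=[1, 2], keepdim=True).shape
#   # -> torch.Size([2, 1, 1])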
def ToRGB(in_channel):
return nn.Sequential(
SpectralNorm(nn.Conv2d(in_channel,in_channel,3, 1, 1)),
nn.LeakyReLU(0.2),
SpectralNorm(nn.Conv2d(in_channel,3,3, 1, 1))
)
def AttentionBlock(in_channel):
return nn.Sequential(
SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
nn.LeakyReLU(0.2),
SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))
)
class UNetDictFace(nn.Module):
def __init__(self, ngf=64, dictionary_path='./DictionaryCenter512'):
super().__init__()
self.part_sizes = np.array([80,80,50,110]) # size for 512
self.feature_sizes = np.array([256,128,64,32])
self.channel_sizes = np.array([128,256,512,512])
Parts = ['left_eye','right_eye','nose','mouth']
self.Dict_256 = {}
self.Dict_128 = {}
self.Dict_64 = {}
self.Dict_32 = {}
for j,i in enumerate(Parts):
f_256 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_256_center.npy'.format(i)), allow_pickle=True))
f_256_reshape = f_256.reshape(f_256.size(0),self.channel_sizes[0],self.part_sizes[j]//2,self.part_sizes[j]//2)
max_256 = torch.max(torch.sqrt(compute_sum(torch.pow(f_256_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4]))
self.Dict_256[i] = f_256_reshape #/ max_256
f_128 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_128_center.npy'.format(i)), allow_pickle=True))
f_128_reshape = f_128.reshape(f_128.size(0),self.channel_sizes[1],self.part_sizes[j]//4,self.part_sizes[j]//4)
max_128 = torch.max(torch.sqrt(compute_sum(torch.pow(f_128_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4]))
self.Dict_128[i] = f_128_reshape #/ max_128
f_64 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_64_center.npy'.format(i)), allow_pickle=True))
f_64_reshape = f_64.reshape(f_64.size(0),self.channel_sizes[2],self.part_sizes[j]//8,self.part_sizes[j]//8)
max_64 = torch.max(torch.sqrt(compute_sum(torch.pow(f_64_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4]))
self.Dict_64[i] = f_64_reshape #/ max_64
f_32 = torch.from_numpy(np.load(os.path.join(dictionary_path, '{}_32_center.npy'.format(i)), allow_pickle=True))
f_32_reshape = f_32.reshape(f_32.size(0),self.channel_sizes[3],self.part_sizes[j]//16,self.part_sizes[j]//16)
max_32 = torch.max(torch.sqrt(compute_sum(torch.pow(f_32_reshape, 2), axis=[1, 2, 3], keepdim=True)),torch.FloatTensor([1e-4]))
self.Dict_32[i] = f_32_reshape #/ max_32
self.le_256 = AttentionBlock(128)
self.le_128 = AttentionBlock(256)
self.le_64 = AttentionBlock(512)
self.le_32 = AttentionBlock(512)
self.re_256 = AttentionBlock(128)
self.re_128 = AttentionBlock(256)
self.re_64 = AttentionBlock(512)
self.re_32 = AttentionBlock(512)
self.no_256 = AttentionBlock(128)
self.no_128 = AttentionBlock(256)
self.no_64 = AttentionBlock(512)
self.no_32 = AttentionBlock(512)
self.mo_256 = AttentionBlock(128)
self.mo_128 = AttentionBlock(256)
self.mo_64 = AttentionBlock(512)
self.mo_32 = AttentionBlock(512)
#norm
self.VggExtract = VGGFeat()
######################
        self.MSDilate = MSDilateBlock(ngf*8, dilation=[4, 3, 2, 1])
        self.up0 = StyledUpBlock(ngf*8, ngf*8)
        self.up1 = StyledUpBlock(ngf*8, ngf*4)
        self.up2 = StyledUpBlock(ngf*4, ngf*2)
        self.up3 = StyledUpBlock(ngf*2, ngf)
self.up4 = nn.Sequential( # 128
# nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
# nn.BatchNorm2d(32),
nn.LeakyReLU(0.2),
UpResBlock(ngf),
UpResBlock(ngf),
# SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
nn.Tanh()
)
self.to_rgb0 = ToRGB(ngf*8)
self.to_rgb1 = ToRGB(ngf*4)
self.to_rgb2 = ToRGB(ngf*2)
self.to_rgb3 = ToRGB(ngf*1)
# for param in self.BlurInputConv.parameters():
# param.requires_grad = False
def forward(self,input, part_locations):
VggFeatures = self.VggExtract(input)
# for b in range(input.size(0)):
b = 0
UpdateVggFeatures = []
for i, f_size in enumerate(self.feature_sizes):
cur_feature = VggFeatures[i]
update_feature = cur_feature.clone() #* 0
cur_part_sizes = self.part_sizes // (512/f_size)
dicts_feature = getattr(self, 'Dict_'+str(f_size))
LE_Dict_feature = dicts_feature['left_eye'].to(input)
RE_Dict_feature = dicts_feature['right_eye'].to(input)
NO_Dict_feature = dicts_feature['nose'].to(input)
MO_Dict_feature = dicts_feature['mouth'].to(input)
le_location = (part_locations[0][b] // (512/f_size)).int()
re_location = (part_locations[1][b] // (512/f_size)).int()
no_location = (part_locations[2][b] // (512/f_size)).int()
mo_location = (part_locations[3][b] // (512/f_size)).int()
LE_feature = cur_feature[:,:,le_location[1]:le_location[3],le_location[0]:le_location[2]].clone()
RE_feature = cur_feature[:,:,re_location[1]:re_location[3],re_location[0]:re_location[2]].clone()
NO_feature = cur_feature[:,:,no_location[1]:no_location[3],no_location[0]:no_location[2]].clone()
MO_feature = cur_feature[:,:,mo_location[1]:mo_location[3],mo_location[0]:mo_location[2]].clone()
#resize
LE_feature_resize = F.interpolate(LE_feature,(LE_Dict_feature.size(2),LE_Dict_feature.size(3)),mode='bilinear',align_corners=False)
RE_feature_resize = F.interpolate(RE_feature,(RE_Dict_feature.size(2),RE_Dict_feature.size(3)),mode='bilinear',align_corners=False)
NO_feature_resize = F.interpolate(NO_feature,(NO_Dict_feature.size(2),NO_Dict_feature.size(3)),mode='bilinear',align_corners=False)
MO_feature_resize = F.interpolate(MO_feature,(MO_Dict_feature.size(2),MO_Dict_feature.size(3)),mode='bilinear',align_corners=False)
LE_Dict_feature_norm = adaptive_instance_normalization_4D(LE_Dict_feature, LE_feature_resize)
RE_Dict_feature_norm = adaptive_instance_normalization_4D(RE_Dict_feature, RE_feature_resize)
NO_Dict_feature_norm = adaptive_instance_normalization_4D(NO_Dict_feature, NO_feature_resize)
MO_Dict_feature_norm = adaptive_instance_normalization_4D(MO_Dict_feature, MO_feature_resize)
LE_score = F.conv2d(LE_feature_resize, LE_Dict_feature_norm)
LE_score = F.softmax(LE_score.view(-1),dim=0)
LE_index = torch.argmax(LE_score)
LE_Swap_feature = F.interpolate(LE_Dict_feature_norm[LE_index:LE_index+1], (LE_feature.size(2), LE_feature.size(3)))
LE_Attention = getattr(self, 'le_'+str(f_size))(LE_Swap_feature-LE_feature)
LE_Att_feature = LE_Attention * LE_Swap_feature
RE_score = F.conv2d(RE_feature_resize, RE_Dict_feature_norm)
RE_score = F.softmax(RE_score.view(-1),dim=0)
RE_index = torch.argmax(RE_score)
RE_Swap_feature = F.interpolate(RE_Dict_feature_norm[RE_index:RE_index+1], (RE_feature.size(2), RE_feature.size(3)))
RE_Attention = getattr(self, 're_'+str(f_size))(RE_Swap_feature-RE_feature)
RE_Att_feature = RE_Attention * RE_Swap_feature
NO_score = F.conv2d(NO_feature_resize, NO_Dict_feature_norm)
NO_score = F.softmax(NO_score.view(-1),dim=0)
NO_index = torch.argmax(NO_score)
NO_Swap_feature = F.interpolate(NO_Dict_feature_norm[NO_index:NO_index+1], (NO_feature.size(2), NO_feature.size(3)))
NO_Attention = getattr(self, 'no_'+str(f_size))(NO_Swap_feature-NO_feature)
NO_Att_feature = NO_Attention * NO_Swap_feature
MO_score = F.conv2d(MO_feature_resize, MO_Dict_feature_norm)
MO_score = F.softmax(MO_score.view(-1),dim=0)
MO_index = torch.argmax(MO_score)
MO_Swap_feature = F.interpolate(MO_Dict_feature_norm[MO_index:MO_index+1], (MO_feature.size(2), MO_feature.size(3)))
MO_Attention = getattr(self, 'mo_'+str(f_size))(MO_Swap_feature-MO_feature)
MO_Att_feature = MO_Attention * MO_Swap_feature
update_feature[:,:,le_location[1]:le_location[3],le_location[0]:le_location[2]] = LE_Att_feature + LE_feature
update_feature[:,:,re_location[1]:re_location[3],re_location[0]:re_location[2]] = RE_Att_feature + RE_feature
update_feature[:,:,no_location[1]:no_location[3],no_location[0]:no_location[2]] = NO_Att_feature + NO_feature
update_feature[:,:,mo_location[1]:mo_location[3],mo_location[0]:mo_location[2]] = MO_Att_feature + MO_feature
UpdateVggFeatures.append(update_feature)
        fea_vgg = self.MSDilate(VggFeatures[3])
        # new version
        fea_up0 = self.up0(fea_vgg, UpdateVggFeatures[3])
        # out1 = F.interpolate(fea_up0, (512, 512))
        # out1 = self.to_rgb0(out1)
        fea_up1 = self.up1(fea_up0, UpdateVggFeatures[2])
        # out2 = F.interpolate(fea_up1, (512, 512))
        # out2 = self.to_rgb1(out2)
        fea_up2 = self.up2(fea_up1, UpdateVggFeatures[1])
        # out3 = F.interpolate(fea_up2, (512, 512))
        # out3 = self.to_rgb2(out3)
        fea_up3 = self.up3(fea_up2, UpdateVggFeatures[0])
        # out4 = F.interpolate(fea_up3, (512, 512))
        # out4 = self.to_rgb3(out4)
        output = self.up4(fea_up3)
        return output  # + out4 + out3 + out2 + out1
#0 128 * 256 * 256
#1 256 * 128 * 128
#2 512 * 64 * 64
#3 512 * 32 * 32
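# Illustrative usage (hedged; the exact location format comes from the data
# loader, which is not shown here). part_locations is a list of four tensors
# [left_eye, right_eye, nose, mouth], each holding (x1, y1, x2, y2) boxes in
# 512x512 coordinates; note that forward() only reads the boxes of the first
# image in the batch (b = 0). Requires the dictionary .npy files and VGG
# weights on disk.
#   net = UNetDictFace(64)
#   restored = net(degraded_512, part_locations)  # (B, 3, 512, 512) in [-1, 1]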
class UpResBlock(nn.Module):
def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
super(UpResBlock, self).__init__()
self.Model = nn.Sequential(
# SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
conv_layer(dim, dim, 3, 1, 1),
# norm_layer(dim),
nn.LeakyReLU(0.2,True),
# SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
conv_layer(dim, dim, 3, 1, 1),
)
def forward(self, x):
out = x + self.Model(x)
return out
class VggClassNet(nn.Module):
def __init__(self, select_layer = ['0','5','10','19']):
super(VggClassNet, self).__init__()
self.select = select_layer
self.vgg = models.vgg19(pretrained=True).features
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
features = []
for name, layer in self.vgg._modules.items():
x = layer(x)
if name in self.select:
features.append(x)
return features
if __name__ == '__main__':
    print('networks.py: network definitions module')