import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from .basic_layer import *

# from pytorch_metric_learning import losses

'''
Margin code is borrowed from https://github.com/MuggleWang/CosFace_pytorch and https://github.com/wujiyang/Face_Pytorch.
'''


def cosine_sim(x1, x2, dim=1, eps=1e-8):
    """Pairwise cosine similarity between the rows of x1 and the rows of x2."""
    ip = torch.mm(x1, x2.t())      # inner products, shape [N1, N2]
    w1 = torch.norm(x1, 2, dim)    # row norms of x1, shape [N1]
    w2 = torch.norm(x2, 2, dim)    # row norms of x2, shape [N2]
    return ip / torch.ger(w1, w2).clamp(min=eps)
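

# Hedged sanity check (not part of the original file): cosine_sim should agree with the
# normalize-then-linear formulation used by the margin heads below. The helper name
# _check_cosine_sim is illustrative only.
def _check_cosine_sim():
    x = torch.randn(4, 512)
    w = torch.randn(7, 512)
    a = cosine_sim(x, w)
    b = F.linear(F.normalize(x), F.normalize(w))
    assert torch.allclose(a, b, atol=1e-5)
    return a.shape  # torch.Size([4, 7])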
class MarginCosineProduct(nn.Module):
    r"""Implementation of the large-margin cosine distance (CosFace).

    Args:
        in_features: size of each input sample
        out_features: size of each output sample (number of classes)
        s: norm (scale) applied to the cosine logits
        m: additive cosine margin subtracted from the target-class logit
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        super(MarginCosineProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # class-center weights, shape [out_features, in_features]
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        # stdv = 1. / math.sqrt(self.weight.size(1))
        # self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, label):
        # cosine similarity between features [N, in_features] and class weights [out_features, in_features]
        cosine = cosine_sim(input, self.weight)
        # cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # --------------------------- convert label to one-hot ---------------------------
        # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1.0)
        # subtract the margin m from the target-class cosine only, then scale by s
        output = self.s * (cosine - one_hot * self.m)
        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', s=' + str(self.s) \
               + ', m=' + str(self.m) + ')'
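

# Hedged usage sketch (not in the original file): the margin heads return scaled logits,
# so they are meant to be trained with plain cross-entropy. The feature dimension 512 and
# 7 classes mirror the values used by CPDis_cls further below.
def _example_cosface_step():
    head = MarginCosineProduct(in_features=512, out_features=7)
    feats = torch.randn(8, 512)                 # e.g. embeddings from a backbone
    labels = torch.randint(0, 7, (8,))
    logits = head(feats, labels)                # [8, 7], margin applied to the target class
    loss = F.cross_entropy(logits, labels)
    loss.backward()
    return loss.item()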
class ArcMarginProduct(nn.Module):
    """Additive angular margin (ArcFace): the target-class logit becomes cos(theta + m)."""

    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # keep cos(theta + m) monotonically decreasing while theta is in [0°, 180°]
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label):
        # cos(theta)
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # cos(theta + m); the clamp guards against sqrt of a slightly negative value
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
        # one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        # use cos(theta + m) for the target class and cos(theta) for all other classes
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s
        return output
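

# Hedged sanity check (not in the original file): with a margin m > 0, the target-class
# logit from ArcMarginProduct never exceeds the plain scaled cosine, while non-target
# logits are untouched.
def _check_arc_margin():
    head = ArcMarginProduct(in_feature=128, out_feature=10, s=32.0, m=0.5)
    x = torch.randn(4, 128)
    labels = torch.randint(0, 10, (4,))
    with torch.no_grad():
        logits = head(x, labels)
        plain = 32.0 * F.linear(F.normalize(x), F.normalize(head.weight))
    idx = torch.arange(4)
    assert (logits[idx, labels] <= plain[idx, labels]).all()
    return logits.shape  # torch.Size([4, 10])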
class MultiMarginProduct(nn.Module):
    """Combined margin: additive angular margin m1 (ArcFace) plus additive cosine margin m2 (CosFace)."""

    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False):
        super(MultiMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m1 = math.cos(m1)
        self.sin_m1 = math.sin(m1)
        # keep cos(theta + m1) monotonically decreasing while theta is in [0°, 180°]
        self.th = math.cos(math.pi - m1)
        self.mm = math.sin(math.pi - m1) * m1

    def forward(self, x, label):
        # cos(theta)
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # cos(theta + m1); the clamp guards against sqrt of a slightly negative value
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m1 - sine * self.sin_m1
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # additive angular margin
        output = output - one_hot * self.m2                    # additive cosine margin
        output = output * self.s
        return output
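

# Hedged sanity check (not in the original file): with m1 = 0, MultiMarginProduct collapses
# to the CosFace-style head above. Weights are copied so both heads share class centers.
def _check_multi_margin():
    x = torch.randn(4, 128)
    labels = torch.randint(0, 10, (4,))
    multi = MultiMarginProduct(128, 10, s=32.0, m1=0.0, m2=0.35)
    cos_head = MarginCosineProduct(128, 10, s=32.0, m=0.35)
    cos_head.weight.data.copy_(multi.weight.data)
    with torch.no_grad():
        assert torch.allclose(multi(x, labels), cos_head(x, labels), atol=1e-5)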
class CPDis(nn.Module):
    """PatchGAN discriminator (optionally with spectral normalization)."""

    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
        super(CPDis, self).__init__()
        layers = []
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        # final 1-channel patch map
        if norm == 'SN':
            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

    def forward(self, x):
        if x.ndim == 5:
            x = x.squeeze(0)
        assert x.ndim == 4, x.ndim
        h = self.main(x)
        # out_real = self.conv1(h)
        out_makeup = self.conv1(h)
        # return out_real.squeeze(), out_makeup.squeeze()
        return out_makeup
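

# Hedged shape check (not in the original file): with the default settings a 256x256
# input is reduced to a 30x30 single-channel patch map.
def _check_cpdis_shapes():
    disc = CPDis(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    with torch.no_grad():
        out = disc(torch.randn(1, 3, 256, 256))
    assert out.shape == (1, 1, 30, 30)
    return out.shape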
class CPDis_cls(nn.Module):
    """PatchGAN discriminator with an additional 7-way margin classification head."""

    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
        super(CPDis_cls, self).__init__()
        layers = []
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
        # classification head (curr_dim is 512 with the default conv_dim/repeat_num).
        # Defined for both norm settings, since forward() always uses it.
        self.classifier_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier_conv = nn.Conv2d(512, 512, 1, 1, 0)
        self.classifier = MarginCosineProduct(512, 7)  # ArcMarginProduct(512, 7)
        print("Using Large Margin Cosine Loss.")

    def forward(self, x, label):
        if x.ndim == 5:
            x = x.squeeze(0)
        assert x.ndim == 4, x.ndim
        h = self.main(x)                           # [N, 512, 31, 31]
        out_cls = self.classifier_pool(h)          # [N, 512, 1, 1]
        out_cls = self.classifier_conv(out_cls)
        out_cls = torch.squeeze(out_cls, -1)
        out_cls = torch.squeeze(out_cls, -1)       # [N, 512]
        out_cls = self.classifier(out_cls, label)  # [N, 7] margin logits
        out_makeup = self.conv1(h)                 # [N, 1, 30, 30]
        # return out_real.squeeze(), out_makeup.squeeze()
        return out_makeup, out_cls
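

# Hedged usage sketch (not in the original file): CPDis_cls returns both the patch-level
# real/fake map and the 7-way margin logits for the given labels.
def _check_cpdis_cls_shapes():
    disc = CPDis_cls(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    labels = torch.randint(0, 7, (1,))
    with torch.no_grad():
        out_makeup, out_cls = disc(torch.randn(1, 3, 256, 256), labels)
    assert out_makeup.shape == (1, 1, 30, 30)
    assert out_cls.shape == (1, 7)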
class SpectralNorm(object):
    """Spectral normalization via power iteration, installed as a forward pre-hook."""

    def __init__(self):
        self.name = "weight"
        # print(self.name)
        self.power_iterations = 1

    def compute_weight(self, module):
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)

    @staticmethod
    def apply(module):
        name = "weight"
        fn = SpectralNorm()
        try:
            u = getattr(module, name + "_u")
            v = getattr(module, name + "_v")
            w = getattr(module, name + "_bar")
        except AttributeError:
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            w_bar = Parameter(w.data)
            # del module._parameters[name]
            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)
            # remove w from the parameter list
            del module._parameters[name]
        setattr(module, name, fn.compute_weight(module))
        # recompute the weight before every forward()
        module.register_forward_pre_hook(fn)
        return fn

    def remove(self, module):
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        setattr(module, self.name, self.compute_weight(module))
def spectral_norm(module):
    SpectralNorm.apply(module)
    return module


def remove_spectral_norm(module):
    name = 'weight'
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
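

# Hedged usage sketch (not in the original file): spectral_norm() registers weight_u /
# weight_v / weight_bar on a layer and rescales the weight by its largest singular value
# before every forward pass; remove_spectral_norm() restores a plain weight parameter.
def _example_spectral_norm():
    conv = spectral_norm(nn.Conv2d(3, 8, kernel_size=3, padding=1))
    assert hasattr(conv, "weight_bar") and hasattr(conv, "weight_u")
    y = conv(torch.randn(1, 3, 16, 16))   # forward pre-hook recomputes the weight
    remove_spectral_norm(conv)
    assert "weight" in dict(conv.named_parameters())
    return y.shape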