Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| from functools import partial | |
| import torch | |
| import torch.nn as nn | |
| from imaginaire.layers import Conv2dBlock | |
| class NonLocal2dBlock(nn.Module): | |
| r"""Self attention Layer | |
| Args: | |
| in_channels (int): Number of channels in the input tensor. | |
| scale (bool, optional, default=True): If ``True``, scale the | |
| output by a learnable parameter. | |
| clamp (bool, optional, default=``False``): If ``True``, clamp the | |
| scaling parameter to (-1, 1). | |
| weight_norm_type (str, optional, default='none'): | |
| Type of weight normalization. | |
| ``'none'``, ``'spectral'``, ``'weight'``. | |
| weight_norm_params (obj, optional, default=None): | |
| Parameters of weight normalization. | |
| If not ``None``, weight_norm_params.__dict__ will be used as | |
| keyword arguments when initializing weight normalization. | |
| bias (bool, optional, default=True): If ``True``, adds bias in the | |
| convolutional blocks. | |
| """ | |
| def __init__(self, | |
| in_channels, | |
| scale=True, | |
| clamp=False, | |
| weight_norm_type='none', | |
| weight_norm_params=None, | |
| bias=True): | |
| super(NonLocal2dBlock, self).__init__() | |
| self.clamp = clamp | |
| self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0 | |
| self.in_channels = in_channels | |
| base_conv2d_block = partial(Conv2dBlock, | |
| kernel_size=1, | |
| stride=1, | |
| padding=0, | |
| weight_norm_type=weight_norm_type, | |
| weight_norm_params=weight_norm_params, | |
| bias=bias) | |
| self.theta = base_conv2d_block(in_channels, in_channels // 8) | |
| self.phi = base_conv2d_block(in_channels, in_channels // 8) | |
| self.g = base_conv2d_block(in_channels, in_channels // 2) | |
| self.out_conv = base_conv2d_block(in_channels // 2, in_channels) | |
| self.softmax = nn.Softmax(dim=-1) | |
| self.max_pool = nn.MaxPool2d(2) | |
| def forward(self, x): | |
| r""" | |
| Args: | |
| x (tensor) : input feature maps (B X C X W X H) | |
| Returns: | |
| (tuple): | |
| - out (tensor) : self attention value + input feature | |
| - attention (tensor): B x N x N (N is Width*Height) | |
| """ | |
| n, c, h, w = x.size() | |
| theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1) | |
| phi = self.phi(x) | |
| phi = self.max_pool(phi).view(n, -1, h * w // 4) | |
| energy = torch.bmm(theta, phi) | |
| attention = self.softmax(energy) | |
| g = self.g(x) | |
| g = self.max_pool(g).view(n, -1, h * w // 4) | |
| out = torch.bmm(g, attention.permute(0, 2, 1)) | |
| out = out.view(n, c // 2, h, w) | |
| out = self.out_conv(out) | |
| if self.clamp: | |
| out = self.gamma.clamp(-1, 1) * out + x | |
| else: | |
| out = self.gamma * out + x | |
| return out | |