```python
import torch
import torch.nn as nn
from torchvision.ops import deform_conv2d


class DeformableConv2d(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=False):
        super(DeformableConv2d, self).__init__()

        assert isinstance(kernel_size, (tuple, int))
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding

        # Predicts two offsets (x, y) per kernel sampling location.
        self.offset_conv = nn.Conv2d(in_channels,
                                     2 * kernel_size[0] * kernel_size[1],
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=self.padding,
                                     bias=True)
        # Zero-initialize so the layer starts out behaving like a regular convolution.
        nn.init.constant_(self.offset_conv.weight, 0.)
        nn.init.constant_(self.offset_conv.bias, 0.)

        # Predicts one modulation scalar (mask) per kernel sampling location.
        self.modulator_conv = nn.Conv2d(in_channels,
                                        1 * kernel_size[0] * kernel_size[1],
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=self.padding,
                                        bias=True)
        nn.init.constant_(self.modulator_conv.weight, 0.)
        nn.init.constant_(self.modulator_conv.bias, 0.)

        # Holds the weights that are applied at the deformed sampling locations.
        self.regular_conv = nn.Conv2d(in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=self.padding,
                                      bias=bias)

    def forward(self, x):
        # Offsets can optionally be clamped to a fraction of the input size, e.g.:
        # h, w = x.shape[2:]
        # max_offset = max(h, w) / 4.
        offset = self.offset_conv(x)  # .clamp(-max_offset, max_offset)
        # Sigmoid maps to (0, 1); scaling by 2 centers the modulation around 1.
        modulator = 2. * torch.sigmoid(self.modulator_conv(x))

        x = deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            padding=self.padding,
            mask=modulator,
            stride=self.stride,
        )
        return x
```
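The layer is a drop-in replacement for `nn.Conv2d`. Below is a minimal usage sketch (the input shape and channel counts are illustrative assumptions, not part of the original snippet); the `mask` argument of `deform_conv2d` requires torchvision 0.9.0 or newer.

```python
# Hypothetical example: run a random feature map through the layer.
# With the defaults above (kernel_size=3, stride=1, padding=1),
# the spatial dimensions are preserved.
layer = DeformableConv2d(in_channels=64, out_channels=128)
x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
y = layer(x)
print(y.shape)                   # torch.Size([2, 128, 32, 32])
```

Because the offset and modulator branches are zero-initialized, the layer initially samples on the regular grid with a modulation of 1 (2 * sigmoid(0)), so training starts from the behavior of a standard convolution and learns deformations from there.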