Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from torch.nn.modules.utils import _pair | |
| from detectron2.layers.wrappers import _NewEmptyTensorOp | |
class TridentConv(nn.Module):
    """Multi-branch 2-D convolution whose branches share one weight.

    Every branch convolves its own input with the same ``weight`` (and
    ``bias``), but each branch uses its own padding and dilation. In training
    mode, or whenever ``test_branch_idx == -1``, all ``num_branch`` branches
    are computed. Otherwise (inference with a selected branch) only the branch
    at ``test_branch_idx`` is computed, on a single input.

    Args:
        in_channels (int): channels of each branch input.
        out_channels (int): channels of each branch output.
        kernel_size (int or tuple): kernel size shared by all branches.
        stride (int or tuple): stride shared by all branches.
        paddings (int or sequence): per-branch padding; an int is repeated
            for every branch.
        dilations (int or sequence): per-branch dilation; an int is repeated
            for every branch.
        groups (int): number of convolution groups.
        num_branch (int): number of branches.
        test_branch_idx (int): branch run at inference; -1 runs all branches.
        bias (bool): whether to learn an additive bias.
        norm (callable or None): applied to every branch output when not None.
        activation (callable or None): applied after ``norm`` when not None.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        paddings=0,
        dilations=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        # A scalar padding/dilation means "the same value for every branch".
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation
        assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1, (
            "num_branch, len(paddings) and len(dilations) must all be equal"
        )

        # One weight/bias shared by every branch; branches differ only in
        # padding and dilation.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def _conv_output_shape(self, input_shape, padding, dilation):
        """Return the ``[N, out_channels, H_out, W_out]`` shape ``F.conv2d`` would produce.

        Applies the standard conv2d output-size formula to the last two
        dimensions of ``input_shape``. It uses this module's kernel size and
        stride and the given branch ``padding`` / ``dilation`` (2-tuples).
        ``input_shape[0]`` is taken as the batch size.
        """
        out_hw = [
            (size + 2 * p - (d * (k - 1) + 1)) // s + 1
            for size, p, d, k, s in zip(
                input_shape[-2:], padding, dilation, self.kernel_size, self.stride
            )
        ]
        return [input_shape[0], self.weight.shape[0]] + out_hw

    def forward(self, inputs):
        """Convolve each active branch's input with the shared weight.

        Args:
            inputs (list[Tensor]): one tensor per active branch. Pass
                ``num_branch`` tensors in training mode or when
                ``test_branch_idx == -1``, and exactly one otherwise.
                Presumably batched ``(N, C, H, W)`` as consumed by
                ``F.conv2d``; confirm against callers.

        Returns:
            list[Tensor]: one output per input, after ``norm`` and
            ``activation`` (when set). If ``inputs[0]`` has no elements, the
            result is empty tensors of the shape each branch's convolution
            would produce, and ``norm``/``activation`` are not applied.
        """
        run_all_branches = self.training or self.test_branch_idx == -1
        num_branch = self.num_branch if run_all_branches else 1
        assert len(inputs) == num_branch

        # (padding, dilation) for each active branch, aligned with `inputs`.
        if run_all_branches:
            branch_params = list(zip(self.paddings, self.dilations))
        else:
            branch_params = [
                (self.paddings[self.test_branch_idx], self.dilations[self.test_branch_idx])
            ]

        if inputs[0].numel() == 0:
            # Each branch has its own padding/dilation, so each empty output
            # gets its own shape.
            return [
                _NewEmptyTensorOp.apply(x, self._conv_output_shape(x.shape, padding, dilation))
                for x, (padding, dilation) in zip(inputs, branch_params)
            ]

        outputs = [
            F.conv2d(x, self.weight, self.bias, self.stride, padding, dilation, self.groups)
            for x, (padding, dilation) in zip(inputs, branch_params)
        ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs

    def extra_repr(self):
        """Summarize the constructor configuration for ``repr(module)``."""
        fields = [
            "in_channels=" + str(self.in_channels),
            "out_channels=" + str(self.out_channels),
            "kernel_size=" + str(self.kernel_size),
            "num_branch=" + str(self.num_branch),
            "test_branch_idx=" + str(self.test_branch_idx),
            "stride=" + str(self.stride),
            "paddings=" + str(self.paddings),
            "dilations=" + str(self.dilations),
            "groups=" + str(self.groups),
            "bias=" + str(self.with_bias),
        ]
        return ", ".join(fields)