Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from torchvision.models import resnet | |
| from typing import Optional, Callable | |
class ConvBlock(nn.Module):
    """Two stacked ``conv3x3 -> norm -> gate`` stages.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
        gate: Activation applied after each normalisation layer. The value is
            stored unchanged and called directly on tensors in ``forward``,
            so it must be a ready-to-use module instance (not a factory).
            Defaults to ``nn.ReLU(inplace=True)``.
        norm_layer: Factory invoked with ``out_channels`` to build each
            normalisation layer. Defaults to ``nn.BatchNorm2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        gate: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        # Sub-modules are registered in the order gate, conv1, bn1, conv2, bn2.
        self.gate = nn.ReLU(inplace=True) if gate is None else gate
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        self.conv1 = resnet.conv3x3(in_channels, out_channels)
        self.bn1 = make_norm(out_channels)
        self.conv2 = resnet.conv3x3(out_channels, out_channels)
        self.bn2 = make_norm(out_channels)

    def forward(self, x):
        """Run both stages on ``x`` and return the resulting feature map.

        Each stage yields ``out_channels`` channels (B x out_channels x H x W
        for a B x in_channels x H x W input).
        """
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.gate(norm(conv(x)))
        return x
| # copied from torchvision\models\resnet.py#27->BasicBlock | |
class ResBlock(nn.Module):
    """Residual block of two 3x3 convolutions with a configurable activation.

    Mirrors torchvision's ``BasicBlock``, except that the activation (``gate``)
    is injected by the caller instead of being hard-wired.

    Args:
        inplanes: Channels of the input feature map.
        planes: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to build the shortcut
            branch; when ``None`` the input itself is the shortcut.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Must not exceed 1.
        gate: Activation module instance, stored unchanged and called on
            tensors. Defaults to ``nn.ReLU(inplace=True)``.
        norm_layer: Factory invoked with ``planes`` to build each
            normalisation layer. Defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        gate: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        self.gate = nn.ReLU(inplace=True) if gate is None else gate
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Only the plain BasicBlock configuration is implemented.
        if not (groups == 1 and base_width == 64):
            raise ValueError("ResBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in ResBlock")

        # With stride != 1 both conv1 and the downsample branch reduce the
        # spatial resolution of the input.
        self.conv1 = resnet.conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.conv2 = resnet.conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``gate(main_path(x) + shortcut(x))``."""
        # Main path: conv1 -> bn1 -> gate -> conv2 -> bn2.
        residual = self.bn2(self.conv2(self.gate(self.bn1(self.conv1(x)))))

        # Shortcut path: the raw input unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.gate(residual)
class ALNet(nn.Module):
    """Multi-scale convolutional network producing a score map and a descriptor map.

    Four encoder stages run at full, 1/2, 1/8 and 1/32 resolution. Each stage
    output is projected to ``dim // 4`` channels, upsampled back to the input
    resolution and concatenated into a ``dim``-channel tensor, which the head
    turns into ``dim`` descriptor channels plus one score channel.

    Args:
        c1: Output channels of encoder stage 1 (full resolution).
        c2: Output channels of encoder stage 2 (1/2 resolution).
        c3: Output channels of encoder stage 3 (1/8 resolution).
        c4: Output channels of encoder stage 4 (1/32 resolution).
        dim: Descriptor dimension; must be a multiple of 4 because four
            ``dim // 4`` projections are concatenated to feed the head.
        single_head: When ``False`` an extra ``dim -> dim`` 1x1 convolution
            followed by the gate is inserted before the final head.

    Raises:
        ValueError: If ``dim`` is not a multiple of 4.
    """

    def __init__(
        self,
        c1: int = 32,
        c2: int = 64,
        c3: int = 128,
        c4: int = 128,
        dim: int = 128,
        single_head: bool = True,
    ):
        super().__init__()

        # The head convolutions expect exactly ``dim`` input channels, but the
        # aggregation produces 4 * (dim // 4); fail early instead of in forward.
        if dim % 4 != 0:
            raise ValueError(f"dim must be a multiple of 4, got {dim}")

        # One activation instance is shared by every stage and the head.
        self.gate = nn.ReLU(inplace=True)

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool4 = nn.MaxPool2d(kernel_size=4, stride=4)

        # ================================== feature encoder
        self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d)
        self.block2 = ResBlock(
            inplanes=c1,
            planes=c2,
            stride=1,
            downsample=nn.Conv2d(c1, c2, 1),
            gate=self.gate,
            norm_layer=nn.BatchNorm2d,
        )
        self.block3 = ResBlock(
            inplanes=c2,
            planes=c3,
            stride=1,
            downsample=nn.Conv2d(c2, c3, 1),
            gate=self.gate,
            norm_layer=nn.BatchNorm2d,
        )
        self.block4 = ResBlock(
            inplanes=c3,
            planes=c4,
            stride=1,
            downsample=nn.Conv2d(c3, c4, 1),
            gate=self.gate,
            norm_layer=nn.BatchNorm2d,
        )

        # ================================== feature aggregation
        self.conv1 = resnet.conv1x1(c1, dim // 4)
        self.conv2 = resnet.conv1x1(c2, dim // 4)
        self.conv3 = resnet.conv1x1(c3, dim // 4)
        # block4 emits c4 channels, so conv4 must consume c4 (not dim);
        # the two only coincide when c4 == dim.
        self.conv4 = resnet.conv1x1(c4, dim // 4)

        self.upsample2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )
        # Not used by forward(); kept so the attribute remains available.
        self.upsample4 = nn.Upsample(
            scale_factor=4, mode="bilinear", align_corners=True
        )
        self.upsample8 = nn.Upsample(
            scale_factor=8, mode="bilinear", align_corners=True
        )
        self.upsample32 = nn.Upsample(
            scale_factor=32, mode="bilinear", align_corners=True
        )

        # ================================== detector and descriptor head
        self.single_head = single_head
        if not self.single_head:
            self.convhead1 = resnet.conv1x1(dim, dim)
        self.convhead2 = resnet.conv1x1(dim, dim + 1)

    def forward(self, image):
        """Compute the score map and descriptor map for ``image``.

        Args:
            image: Tensor of shape B x 3 x H x W. H and W should be multiples
                of 32; otherwise the pooled-then-upsampled maps will not match
                the full-resolution map when they are concatenated.

        Returns:
            Tuple ``(scores_map, descriptor_map)``: ``scores_map`` is
            B x 1 x H x W (sigmoid of the last head channel) and
            ``descriptor_map`` is B x dim x H x W (the remaining channels,
            returned as produced by the head).
        """
        # ================================== feature encoder
        x1 = self.block1(image)  # B x c1 x H x W
        x2 = self.pool2(x1)
        x2 = self.block2(x2)  # B x c2 x H/2 x W/2
        x3 = self.pool4(x2)
        x3 = self.block3(x3)  # B x c3 x H/8 x W/8
        x4 = self.pool4(x3)
        x4 = self.block4(x4)  # B x c4 x H/32 x W/32

        # ================================== feature aggregation
        x1 = self.gate(self.conv1(x1))  # B x dim//4 x H x W
        x2 = self.gate(self.conv2(x2))  # B x dim//4 x H/2 x W/2
        x3 = self.gate(self.conv3(x3))  # B x dim//4 x H/8 x W/8
        x4 = self.gate(self.conv4(x4))  # B x dim//4 x H/32 x W/32

        # Bring every scale back to full resolution before concatenation.
        x2_up = self.upsample2(x2)  # B x dim//4 x H x W
        x3_up = self.upsample8(x3)  # B x dim//4 x H x W
        x4_up = self.upsample32(x4)  # B x dim//4 x H x W
        x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)  # B x dim x H x W

        # ================================== detector and descriptor head
        if not self.single_head:
            x1234 = self.gate(self.convhead1(x1234))
        x = self.convhead2(x1234)  # B x dim+1 x H x W

        # The last channel is the detection score; the rest are descriptors.
        descriptor_map = x[:, :-1, :, :]
        scores_map = torch.sigmoid(x[:, -1, :, :]).unsqueeze(1)

        return scores_map, descriptor_map
if __name__ == "__main__":
    # Smoke run: build a small ALNet, profile one forward pass on a random
    # 640x480 input with thop, and print the two figures it reports.
    from thop import profile

    model = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True)
    dummy_input = torch.randn(1, 3, 640, 480)
    flops, params = profile(model, inputs=(dummy_input,), verbose=False)

    # NOTE(review): params / 1e3 is a parameter count in thousands, although
    # the printed label says "KB" - confirm the intended unit before relying
    # on this output.
    report = (
        ("{:<30} {:<8} GFLops", "Computational complexity: ", flops / 1e9),
        ("{:<30} {:<8} KB", "Number of parameters: ", params / 1e3),
    )
    for template, label, value in report:
        print(template.format(label, value))