#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Union, Tuple

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["MobileOneBlock", "reparameterize_model"]


class SEBlock(nn.Module):
    """Squeeze and Excite module.

    Pytorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Construct a Squeeze and Excite Module.

        Args:
            in_channels: Number of input channels.
            rd_ratio: Input channel reduction ratio.
        """
        super(SEBlock, self).__init__()
        self.reduce = nn.Conv2d(
            in_channels=in_channels,
            out_channels=int(in_channels * rd_ratio),
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.expand = nn.Conv2d(
            in_channels=int(in_channels * rd_ratio),
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        b, c, h, w = inputs.size()
        x = F.avg_pool2d(inputs, kernel_size=[h, w])
        x = self.reduce(x)
        x = F.relu(x)
        x = self.expand(x)
        x = torch.sigmoid(x)
        x = x.view(-1, c, 1, 1)
        return inputs * x
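

# Usage sketch for SEBlock (illustrative only; not part of the original module):
#   se = SEBlock(in_channels=64)
#   y = se(torch.randn(2, 64, 32, 32))  # same shape as the input, rescaled per channel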


class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train-time
    and a plain-CNN style architecture at inference time.
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        use_se: bool = False,
        use_act: bool = True,
        use_scale_branch: bool = True,
        num_conv_branches: int = 1,
        activation: nn.Module = nn.GELU(),
    ) -> None:
        """Construct a MobileOneBlock module.

        Args:
            in_channels: Number of channels in the input.
            out_channels: Number of channels produced by the block.
            kernel_size: Size of the convolution kernel.
            stride: Stride size.
            padding: Zero-padding size.
            dilation: Kernel dilation factor.
            groups: Group number.
            inference_mode: If True, instantiates model in inference mode.
            use_se: Whether to use SE-ReLU activations.
            use_act: Whether to use activation. Default: ``True``
            use_scale_branch: Whether to use scale branch. Default: ``True``
            num_conv_branches: Number of linear conv branches.
            activation: Activation module applied when ``use_act`` is True.
                Default: ``nn.GELU()``
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()

        if use_act:
            self.activation = activation
        else:
            self.activation = nn.Identity()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection
            self.rbr_skip = (
                nn.BatchNorm2d(num_features=in_channels)
                if out_channels == in_channels and stride == 1
                else None
            )

            # Re-parameterizable conv branches
            if num_conv_branches > 0:
                rbr_conv = list()
                for _ in range(self.num_conv_branches):
                    rbr_conv.append(
                        self._conv_bn(kernel_size=kernel_size, padding=padding)
                    )
                self.rbr_conv = nn.ModuleList(rbr_conv)
            else:
                self.rbr_conv = None

            # Re-parameterizable scale branch
            self.rbr_scale = None
            if not isinstance(kernel_size, int):
                kernel_size = kernel_size[0]
            if (kernel_size > 1) and use_scale_branch:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
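
    # Training-time computation implemented by forward() below:
    #   y = activation(se(rbr_skip(x) + rbr_scale(x) + sum_k rbr_conv[k](x)))
    # After reparameterize(), the same mapping is computed by the single fused
    # convolution ``reparam_conv``.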

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale branch output
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Other branches
        out = scale_out + identity_out
        if self.rbr_conv is not None:
            for ix in range(self.num_conv_branches):
                out += self.rbr_conv[ix](x)

        return self.activation(self.se(out))

    def reparameterize(self):
        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for para in self.parameters():
            para.detach_()
        self.__delattr__("rbr_conv")
        self.__delattr__("rbr_scale")
        if hasattr(self, "rbr_skip"):
            self.__delattr__("rbr_skip")

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        if self.rbr_conv is not None:
            for ix in range(self.num_conv_branches):
                _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
                kernel_conv += _kernel
                bias_conv += _bias
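
        # All branches are linear (conv + BN) before the non-linearity, so their
        # fused per-branch kernels and biases can simply be summed into a single
        # kernel/bias pair that reproduces the multi-branch output exactly.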
        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(
        self, branch: Union[nn.Sequential, nn.BatchNorm2d]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with preceding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        Args:
            branch: Sequence of ops to be fused.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
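            # A BatchNorm-only skip branch is equivalent to a convolution with an
            # identity kernel (a single 1 at the spatial center of each channel's
            # own filter) followed by that BatchNorm; build and cache that kernel.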
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups
                kernel_size = self.kernel_size
                if isinstance(self.kernel_size, int):
                    kernel_size = (self.kernel_size, self.kernel_size)
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_channels):
                    kernel_value[
                        i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
                    ] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
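
        # Fold BN into the conv weights: W' = W * gamma / sqrt(running_var + eps),
        # b' = beta - running_mean * gamma / sqrt(running_var + eps).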
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list


def reparameterize_model(model: torch.nn.Module) -> nn.Module:
    """Method returns a model where a multi-branched structure
    used in training is re-parameterized into a single branch
    for inference.

    Args:
        model: MobileOne model in train mode.

    Returns:
        MobileOne model in inference mode.
    """
    # Avoid editing original graph
    model = copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, "reparameterize"):
            module.reparameterize()
    return model
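

if __name__ == "__main__":
    # Minimal sanity-check sketch (illustrative; not part of the original module).
    # It builds a single MobileOneBlock, reparameterizes it, and verifies that the
    # fused block reproduces the training-time outputs.
    block = MobileOneBlock(
        in_channels=16,
        out_channels=16,
        kernel_size=3,
        stride=1,
        padding=1,
        num_conv_branches=2,
    )
    block.eval()  # use BN running statistics so both paths are deterministic
    x = torch.randn(1, 16, 32, 32)
    with torch.no_grad():
        y_train = block(x)
        fused = reparameterize_model(block)
        y_fused = fused(x)
    # The two outputs should match up to numerical precision.
    print(torch.allclose(y_train, y_fused, atol=1e-5))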