# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmocr.registry import MODELS
from mmocr.utils import ConfigType, MultiConfig, OptConfigType

@MODELS.register_module()
class BiFPN(BaseModule):
    """BiFPN (weighted bidirectional feature pyramid network) neck.

    Illustration of a minimal BiFPN unit::

        P7_0 -------------------------> P7_2 -------->
           |-------------|                ↑
                         ↓                |
        P6_0 ---------> P6_1 ---------> P6_2 -------->
           |-------------|--------------↑ ↑
                         ↓                |
        P5_0 ---------> P5_1 ---------> P5_2 -------->
           |-------------|--------------↑ ↑
                         ↓                |
        P4_0 ---------> P4_1 ---------> P4_2 -------->
           |-------------|--------------↑ ↑
                         |--------------↓ |
        P3_0 -------------------------> P3_2 -------->
    """
    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 repeat_times: int = 2,
                 start_level: int = 0,
                 end_level: int = -1,
                 add_extra_convs: bool = False,
                 relu_before_extra_convs: bool = False,
                 no_norm_on_lateral: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 laterial_conv1x1: bool = False,  # currently unused
                 upsample_cfg: ConfigType = dict(mode='nearest'),
                 pool_cfg: ConfigType = dict(),
                 init_cfg: MultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.upsample_cfg = upsample_cfg.copy()
        self.repeat_times = repeat_times

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        self.lateral_convs = nn.ModuleList()
        self.extra_convs = nn.ModuleList()
        self.bifpn_convs = nn.ModuleList()
        # 1x1 lateral convs project each selected backbone level to
        # ``out_channels`` (identity when the channels already match)
        for i in range(self.start_level, self.backbone_end_level):
            if in_channels[i] == out_channels:
                l_conv = nn.Identity()
            else:
                l_conv = ConvModule(
                    in_channels[i],
                    out_channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    act_cfg=act_cfg,
                    inplace=False)
            self.lateral_convs.append(l_conv)

        # stack ``repeat_times`` BiFPN layers
        for _ in range(repeat_times):
            self.bifpn_convs.append(
                BiFPNLayer(
                    channels=out_channels,
                    levels=num_outs,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    pool_cfg=pool_cfg))

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                if i == 0:
                    in_channels = self.in_channels[self.backbone_end_level - 1]
                else:
                    in_channels = out_channels
                if in_channels == out_channels:
                    extra_fpn_conv = nn.MaxPool2d(
                        kernel_size=3, stride=2, padding=1)
                else:
                    extra_fpn_conv = nn.Sequential(
                        ConvModule(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=1,
                            norm_cfg=norm_cfg,
                            act_cfg=act_cfg),
                        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
                self.extra_convs.append(extra_fpn_conv)
    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build extra levels by repeatedly downsampling the last backbone
        # level, appending one new level per extra conv
        if self.num_outs > len(laterals) and self.add_extra_convs:
            extra_source = inputs[self.backbone_end_level - 1]
            for extra_conv in self.extra_convs:
                extra_source = extra_conv(extra_source)
                laterals.append(extra_source)

        for bifpn_module in self.bifpn_convs:
            laterals = bifpn_module(laterals)
        outs = laterals
        return tuple(outs)
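
# In a config file this neck would typically be referenced through the
# registry. A hedged example (the channel counts and level numbers are
# illustrative assumptions, not taken from any shipped config):
#
#     neck=dict(
#         type='BiFPN',
#         in_channels=[256, 512, 1024, 2048],
#         out_channels=256,
#         num_outs=5,
#         add_extra_convs=True)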

def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    return x * x.sigmoid()
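
# Note: ``swish(x)`` computes the same function that PyTorch exposes as
# ``F.silu(x)``; the explicit form above mirrors the BiFPN paper's notation.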

class BiFPNLayer(BaseModule):
    """A single BiFPN layer: one top-down pass followed by one bottom-up
    pass over ``levels`` feature maps, with learnable fusion weights."""

    def __init__(self,
                 channels,
                 levels,
                 init=0.5,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=None,
                 pool_cfg=None,
                 eps=0.0001,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.act_cfg = act_cfg
        self.upsample_cfg = upsample_cfg
        self.pool_cfg = pool_cfg
        self.eps = eps
        self.levels = levels
        self.bifpn_convs = nn.ModuleList()

        # learnable fusion weights: one pair per two-input node, and one
        # triplet for each of the ``levels - 2`` three-input middle nodes
        self.weight_two_nodes = nn.Parameter(
            torch.Tensor(2, levels).fill_(init))
        self.weight_three_nodes = nn.Parameter(
            torch.Tensor(3, levels - 2).fill_(init))
        self.relu = nn.ReLU()

        # one 3x3 conv after every fused node: (levels - 1) convs for the
        # top-down pass and (levels - 1) for the bottom-up pass
        for _ in range(2):
            for _ in range(self.levels - 1):
                fpn_conv = nn.Sequential(
                    ConvModule(
                        channels,
                        channels,
                        3,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        inplace=False))
                self.bifpn_convs.append(fpn_conv)
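
    # Fast normalized fusion (as in EfficientDet): each node combines its
    # inputs as O = sum_i(w_i * I_i) / (eps + sum_j(w_j)), with w_i >= 0
    # enforced by ReLU -- a cheaper stand-in for softmax weighting.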

    def forward(self, inputs):
        assert len(inputs) == self.levels
        levels = self.levels

        # fusion weights; the division is out-of-place so that autograd does
        # not trip over an in-place update of the ReLU output
        w1 = self.relu(self.weight_two_nodes)
        w1 = w1 / (torch.sum(w1, dim=0) + self.eps)  # normalize
        w2 = self.relu(self.weight_three_nodes)  # normalized per node below

        # build top-down path
        # NOTE: ``pathtd`` aliases ``inputs``; the list is updated in place,
        # so the raw inputs are cloned first for the bottom-up skip links
        idx_bifpn = 0
        pathtd = inputs
        inputs_clone = [in_tensor.clone() for in_tensor in inputs]
        for i in range(levels - 1, 0, -1):
            _, _, h, w = pathtd[i - 1].shape
            pathtd[i - 1] = (
                w1[0, i - 1] * pathtd[i - 1] + w1[1, i - 1] *
                F.interpolate(pathtd[i], size=(h, w), mode='nearest'))
            pathtd[i - 1] = swish(pathtd[i - 1])
            pathtd[i - 1] = self.bifpn_convs[idx_bifpn](pathtd[i - 1])
            idx_bifpn += 1

        # build bottom-up path; each middle level fuses three inputs: the
        # raw input, the top-down feature, and the downsampled level below
        for i in range(0, levels - 2, 1):
            tmp_path = torch.stack([
                inputs_clone[i + 1], pathtd[i + 1],
                F.max_pool2d(pathtd[i], kernel_size=3, stride=2, padding=1)
            ],
                                   dim=-1)
            norm_weight = w2[:, i] / (w2[:, i].sum() + self.eps)
            pathtd[i + 1] = (norm_weight * tmp_path).sum(dim=-1)
            pathtd[i + 1] = swish(pathtd[i + 1])
            pathtd[i + 1] = self.bifpn_convs[idx_bifpn](pathtd[i + 1])
            idx_bifpn += 1

        # the topmost output fuses only two inputs
        pathtd[levels - 1] = (
            w1[0, levels - 1] * pathtd[levels - 1] + w1[1, levels - 1] *
            F.max_pool2d(
                pathtd[levels - 2], kernel_size=3, stride=2, padding=1))
        pathtd[levels - 1] = swish(pathtd[levels - 1])
        pathtd[levels - 1] = self.bifpn_convs[idx_bifpn](pathtd[levels - 1])
        return pathtd
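

if __name__ == '__main__':
    # Minimal smoke test -- a sketch, not part of the original module. The
    # channel counts and input resolution below are illustrative assumptions.
    neck = BiFPN(
        in_channels=[64, 128, 256],
        out_channels=64,
        num_outs=5,
        add_extra_convs=True)
    feats = [
        torch.randn(2, c, 64 >> i, 64 >> i)
        for i, c in enumerate([64, 128, 256])
    ]
    for out in neck(feats):
        # five outputs, all with 64 channels, spatial size halving per level
        print(out.shape)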