Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding:utf-8 -*- | |
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from thop import profile | |
| from copy import deepcopy | |
| __all__ = [ | |
| "fuse_conv_and_bn", | |
| "fuse_model", | |
| "get_model_info", | |
| "replace_module", | |
| ] | |
| def get_model_info(model, tsize): | |
| stride = 64 | |
| img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) | |
| flops, params = profile(deepcopy(model), inputs=(img,), verbose=False) | |
| params /= 1e6 | |
| flops /= 1e9 | |
| flops *= tsize[0] * tsize[1] / stride / stride * 2 # Gflops | |
| info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops) | |
| return info | |
| def fuse_conv_and_bn(conv, bn): | |
| # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ | |
| fusedconv = ( | |
| nn.Conv2d( | |
| conv.in_channels, | |
| conv.out_channels, | |
| kernel_size=conv.kernel_size, | |
| stride=conv.stride, | |
| padding=conv.padding, | |
| groups=conv.groups, | |
| bias=True, | |
| ) | |
| .requires_grad_(False) | |
| .to(conv.weight.device) | |
| ) | |
| # prepare filters | |
| w_conv = conv.weight.clone().view(conv.out_channels, -1) | |
| w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) | |
| fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) | |
| # prepare spatial bias | |
| b_conv = ( | |
| torch.zeros(conv.weight.size(0), device=conv.weight.device) | |
| if conv.bias is None | |
| else conv.bias | |
| ) | |
| b_bn = bn.bias - bn.weight.mul(bn.running_mean).div( | |
| torch.sqrt(bn.running_var + bn.eps) | |
| ) | |
| fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) | |
| return fusedconv | |
| def fuse_model(model): | |
| from yolox.models.network_blocks import BaseConv | |
| for m in model.modules(): | |
| if type(m) is BaseConv and hasattr(m, "bn"): | |
| m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv | |
| delattr(m, "bn") # remove batchnorm | |
| m.forward = m.fuseforward # update forward | |
| return model | |
| def replace_module(module, replaced_module_type, new_module_type, replace_func=None): | |
| """ | |
| Replace given type in module to a new type. mostly used in deploy. | |
| Args: | |
| module (nn.Module): model to apply replace operation. | |
| replaced_module_type (Type): module type to be replaced. | |
| new_module_type (Type) | |
| replace_func (function): python function to describe replace logic. Defalut value None. | |
| Returns: | |
| model (nn.Module): module that already been replaced. | |
| """ | |
| def default_replace_func(replaced_module_type, new_module_type): | |
| return new_module_type() | |
| if replace_func is None: | |
| replace_func = default_replace_func | |
| model = module | |
| if isinstance(module, replaced_module_type): | |
| model = replace_func(replaced_module_type, new_module_type) | |
| else: # recurrsively replace | |
| for name, child in module.named_children(): | |
| new_child = replace_module(child, replaced_module_type, new_module_type) | |
| if new_child is not child: # child is already replaced | |
| model.add_module(name, new_child) | |
| return model | |