Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| # adapted from PANNs (https://github.com/qiuqiangkong/audioset_tagging_cnn) | |
def count_macs(model, spec_size):
    """Count multiply-accumulate operations (MACs) of a model for one forward pass.

    Adapted from PANNs (https://github.com/qiuqiangkong/audioset_tagging_cnn).

    Forward hooks are attached to every leaf ``nn.Conv2d`` / ``nn.Linear`` and to
    any module whose class is named ``Conv2dStaticSamePadding``. A random tensor
    of shape ``spec_size`` is pushed through the model under ``torch.no_grad()``.
    The per-layer counts are printed and summed. Every other leaf module gets a
    warning printed and contributes nothing. All hooks are removed before
    returning, even if the forward pass raises, so the model is left unchanged
    and can be used with any batch size afterwards.

    Args:
        model: module to analyse. It must own at least one parameter, because
            the device of its first parameter is used for the dummy input.
        spec_size: shape passed to ``torch.rand`` for the dummy input. The first
            (batch) dimension must be 1; the hooks assert this.

    Returns:
        Total MACs over the counted convolutional and linear layers.
    """
    conv_macs = []
    linear_macs = []
    handles = []  # RemovableHandles, so the hooks can be detached afterwards

    def conv2d_hook(module, inputs, output):
        batch_size, _, _, _ = inputs[0].size()
        assert batch_size == 1
        output_channels, output_height, output_width = output[0].size()
        # in_channels is always divisible by groups, so integer division is exact
        kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels // module.groups)
        bias_ops = 1 if module.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)
        # overall macs count is:
        # kernel**2 * in_channels/groups * out_channels * out_width * out_height
        conv_macs.append(batch_size * params * output_height * output_width)

    def linear_hook(module, inputs, output):
        x = inputs[0]
        batch_size = x.size(0) if x.dim() == 2 else 1
        assert batch_size == 1
        weight_ops = module.weight.nelement()
        # nn.Linear(..., bias=False) stores bias=None
        bias_ops = module.bias.nelement() if module.bias is not None else 0
        # NOTE: counts a single application of the layer (number of parameters);
        # inputs with extra position dims (e.g. (1, T, F)) are not multiplied by T.
        linear_macs.append(batch_size * (weight_ops + bias_ops))

    def register(net):
        if net.__class__.__name__ == 'Conv2dStaticSamePadding':
            handles.append(net.register_forward_hook(conv2d_hook))
        children = list(net.children())
        if not children:
            if isinstance(net, nn.Conv2d):
                handles.append(net.register_forward_hook(conv2d_hook))
            elif isinstance(net, nn.Linear):
                handles.append(net.register_forward_hook(linear_hook))
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for child in children:
            register(child)

    # Resolve the device first so a parameterless model cannot leave hooks behind.
    device = next(model.parameters()).device
    register(model)
    dummy_input = torch.rand(spec_size).to(device)
    try:
        with torch.no_grad():
            model(dummy_input)
    finally:
        # Leftover hooks would assert on any later forward pass with batch > 1.
        for handle in handles:
            handle.remove()

    total_macs = sum(conv_macs) + sum(linear_macs)

    def share(part):
        # Avoid ZeroDivisionError when no layer was counted.
        return part / total_macs if total_macs else 0.0

    print("*************Computational Complexity (multiply-adds) **************")
    print("Number of Convolutional Layers: ", len(conv_macs))
    print("Number of Linear Layers: ", len(linear_macs))
    print("Relative Share of Convolutional Layers: {:.2f}".format(share(sum(conv_macs))))
    print("Relative Share of Linear Layers: {:.2f}".format(share(sum(linear_macs))))
    print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9))
    print("********************************************************************")
    return total_macs
def count_macs_transformer(model, spec_size):
    """Count multiply-accumulate operations (MACs) of a transformer-style model.

    Code modified from others' implementation.

    Forward hooks are attached to every leaf ``nn.Conv2d`` / ``nn.Linear``. An
    extra hook goes on every module whose class is named ``MultiHeadAttention``;
    it counts only the attention-matrix products, because the projections inside
    it are counted by the linear hook. A random tensor of shape ``spec_size`` is
    pushed through the model under ``torch.no_grad()``, and the per-layer counts
    are printed and summed. Every other leaf module gets a warning printed and
    contributes nothing. All hooks are removed before returning, even if the
    forward pass raises.

    Args:
        model: module to analyse. It must own at least one parameter, because
            the device of its first parameter is used for the dummy input.
        spec_size: shape passed to ``torch.rand`` for the dummy input. The first
            (batch) dimension must be 1; the conv/linear hooks assert this.

    Returns:
        Total MACs over the counted convolutional, linear and attention layers.
    """
    conv_macs = []
    linear_macs = []
    att_macs = []
    handles = []  # RemovableHandles, so the hooks can be detached afterwards

    def conv2d_hook(module, inputs, output):
        batch_size, _, _, _ = inputs[0].size()
        assert batch_size == 1
        output_channels, output_height, output_width = output[0].size()
        # in_channels is always divisible by groups, so integer division is exact
        kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels // module.groups)
        bias_ops = 1 if module.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)
        # overall macs count is:
        # kernel**2 * in_channels/groups * out_channels * out_width * out_height
        conv_macs.append(batch_size * params * output_height * output_width)

    def linear_hook(module, inputs, output):
        x = inputs[0]
        batch_size = x.size(0) if x.dim() >= 2 else 1
        assert batch_size == 1
        weight_ops = module.weight.nelement()
        bias_ops = module.bias.nelement() if module.bias is not None else 0
        # nn.Linear acts on the last dim and is applied once per position of the
        # leading dims: 1 for (1, embed), seq_len for (1, seq_len, embed), etc.
        positions = x.numel() // module.in_features
        linear_macs.append(positions * (weight_ops + bias_ops))

    def attention_hook(module, inputs, output):
        # Only the attention products; the projections are counted in linear_hook.
        # Assumes the first positional input is (batch, seq_len, embed) — confirm
        # against the MultiHeadAttention implementation in use.
        batch_size, seq_len, embed_size = inputs[0].size()
        # 2 * embed_size * seq_len**2:
        # - computing the attention matrix: embed_size * seq_len**2
        # - multiplying attention matrix with value matrix: embed_size * seq_len**2
        att_macs.append(batch_size * embed_size * seq_len * seq_len * 2)

    def register(net):
        children = list(net.children())
        if net.__class__.__name__ == "MultiHeadAttention":
            handles.append(net.register_forward_hook(attention_hook))
        if not children:
            if isinstance(net, nn.Conv2d):
                handles.append(net.register_forward_hook(conv2d_hook))
            elif isinstance(net, nn.Linear):
                handles.append(net.register_forward_hook(linear_hook))
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for child in children:
            register(child)

    # Resolve the device first so a parameterless model cannot leave hooks behind.
    device = next(model.parameters()).device
    register(model)
    dummy_input = torch.rand(spec_size).to(device)
    try:
        with torch.no_grad():
            model(dummy_input)
    finally:
        # Leftover hooks would assert on any later forward pass with batch > 1.
        for handle in handles:
            handle.remove()

    total_macs = sum(conv_macs) + sum(linear_macs) + sum(att_macs)

    def share(part):
        # Avoid ZeroDivisionError when no layer was counted.
        return part / total_macs if total_macs else 0.0

    print("*************Computational Complexity (multiply-adds) **************")
    print("Number of Convolutional Layers: ", len(conv_macs))
    print("Number of Linear Layers: ", len(linear_macs))
    print("Number of Attention Layers: ", len(att_macs))
    print("Relative Share of Convolutional Layers: {:.2f}".format(share(sum(conv_macs))))
    print("Relative Share of Linear Layers: {:.2f}".format(share(sum(linear_macs))))
    print("Relative Share of Attention Layers: {:.2f}".format(share(sum(att_macs))))
    print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9))
    print("********************************************************************")
    return total_macs