# EfficientAT — model complexity utilities (commit 4848335, author: aps)
import torch
import torch.nn as nn
# adapted from PANNs (https://github.com/qiuqiangkong/audioset_tagging_cnn)
def count_macs(model, spec_size):
    """Count multiply-accumulate operations (MACs) of a CNN via forward hooks.

    Registers counting hooks on every leaf ``nn.Conv2d`` / ``nn.Linear``
    module (and on ``Conv2dStaticSamePadding`` modules, matched by class
    name), runs a single dummy forward pass, and sums the per-layer MACs.

    Args:
        model: network to profile; must accept a tensor of shape
            ``spec_size``. The batch dimension must be 1.
        spec_size: shape of the dummy input, e.g. ``(1, 1, mels, frames)``.

    Returns:
        Total MAC count summed over all hooked conv and linear layers.

    Note:
        Hooks remain registered on the model after this call.
    """
    list_conv2d = []

    def conv2d_hook(module, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        assert batch_size == 1
        # output[0] drops the batch dimension -> (channels, height, width)
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels / module.groups)
        bias_ops = 1 if module.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)
        # overall macs count is:
        # kernel**2 * in_channels/groups * out_channels * out_width * out_height
        macs = batch_size * params * output_height * output_width
        list_conv2d.append(macs)

    list_linear = []

    def linear_hook(module, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        assert batch_size == 1
        weight_ops = module.weight.nelement()
        # bias may be disabled (nn.Linear(..., bias=False)); guard like
        # count_macs_transformer does instead of crashing on None
        bias_ops = module.bias.nelement() if module.bias is not None else 0
        # overall macs count is equal to the number of parameters in layer
        macs = batch_size * (weight_ops + bias_ops)
        list_linear.append(macs)

    def foo(net):
        # special-cased conv wrapper used by EfficientNet-style models;
        # it has child modules, so it is hooked here by class name
        if net.__class__.__name__ == 'Conv2dStaticSamePadding':
            net.register_forward_hook(conv2d_hook)
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, nn.Conv2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.Linear):
                net.register_forward_hook(linear_hook)
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for c in childrens:
            foo(c)

    # Register hooks, then trigger them once with a dummy forward pass.
    foo(model)
    device = next(model.parameters()).device
    dummy_input = torch.rand(spec_size).to(device)
    with torch.no_grad():
        model(dummy_input)

    total_macs = sum(list_conv2d) + sum(list_linear)
    print("*************Computational Complexity (multiply-adds) **************")
    print("Number of Convolutional Layers: ", len(list_conv2d))
    print("Number of Linear Layers: ", len(list_linear))
    print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs)))
    print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs))
    print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9))
    print("********************************************************************")
    return total_macs
def count_macs_transformer(model, spec_size):
    """Estimate MACs for transformer-style models via forward hooks.

    Counting hooks are attached to leaf ``nn.Conv2d`` / ``nn.Linear``
    modules and, by class name, to ``MultiHeadAttention`` modules; one
    dummy forward pass then populates the per-layer counters, whose sum
    is printed and returned.
    """
    conv_macs = []

    def hook_conv(mod, inp, out):
        n, _, _, _ = inp[0].size()
        assert n == 1
        # out[0] drops the batch dimension -> (channels, height, width)
        out_ch, out_h, out_w = out[0].size()
        ops_per_output = mod.kernel_size[0] * mod.kernel_size[1] * (mod.in_channels / mod.groups)
        ops_per_output += 1 if mod.bias is not None else 0
        # kernel**2 * in_channels/groups * out_channels * out_width * out_height
        conv_macs.append(n * out_ch * ops_per_output * out_h * out_w)

    linear_macs = []

    def hook_linear(mod, inp, out):
        x = inp[0]
        n = x.size(0) if x.dim() >= 2 else 1
        assert n == 1
        weight_ops = mod.weight.nelement()
        bias_ops = mod.bias.nelement() if mod.bias is not None else 0
        if x.dim() == 3:
            # (batch, seq_len, embed): position-wise linear, so the
            # parameter count is multiplied by the sequence length
            n, seq_len, _ = x.size()
            linear_macs.append(n * (weight_ops + bias_ops) * seq_len)
        else:
            # classification head on (batch, embed) input:
            # macs equal the layer's parameter count
            n, _ = x.size()
            linear_macs.append(n * (weight_ops + bias_ops))

    att_macs = []

    def hook_attention(mod, inp, out):
        # attention-only cost; the surrounding linear projections are
        # counted separately by hook_linear
        n, seq_len, embed = inp[0].size()
        # attention matrix (embed * seq_len**2) + attention @ value (same)
        att_macs.append(n * embed * seq_len * seq_len * 2)

    def attach(module):
        kids = list(module.children())
        if module.__class__.__name__ == "MultiHeadAttention":
            module.register_forward_hook(hook_attention)
        if kids:
            for kid in kids:
                attach(kid)
        else:
            if isinstance(module, nn.Conv2d):
                module.register_forward_hook(hook_conv)
            elif isinstance(module, nn.Linear):
                module.register_forward_hook(hook_linear)
            else:
                print('Warning: flop of module {} is not counted!'.format(module))

    # Attach hooks, then trigger them once with a dummy forward pass.
    attach(model)
    device = next(model.parameters()).device
    dummy = torch.rand(spec_size).to(device)
    with torch.no_grad():
        model(dummy)

    total_macs = sum(conv_macs) + sum(linear_macs) + sum(att_macs)
    print("*************Computational Complexity (multiply-adds) **************")
    print("Number of Convolutional Layers: ", len(conv_macs))
    print("Number of Linear Layers: ", len(linear_macs))
    print("Number of Attention Layers: ", len(att_macs))
    print("Relative Share of Convolutional Layers: {:.2f}".format((sum(conv_macs) / total_macs)))
    print("Relative Share of Linear Layers: {:.2f}".format(sum(linear_macs) / total_macs))
    print("Relative Share of Attention Layers: {:.2f}".format(sum(att_macs) / total_macs))
    print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9))
    print("********************************************************************")
    return total_macs