File size: 6,503 Bytes
4848335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
import torch.nn as nn


# adapted from PANNs (https://github.com/qiuqiangkong/audioset_tagging_cnn)

def count_macs(model, spec_size):
    """Count multiply-accumulate operations (MACs) of a CNN with forward hooks.

    Registers a counting hook on every leaf ``nn.Conv2d`` and ``nn.Linear``
    module (plus any module whose class is named ``Conv2dStaticSamePadding``,
    the EfficientNet-style padded conv, which wraps a conv but is not a leaf),
    runs one forward pass on a random input of shape ``spec_size`` (batch size
    must be 1), prints a per-layer-type breakdown, and returns the total.

    All hooks are removed before returning, so the function can be called
    repeatedly on the same model without double-counting.

    Args:
        model: the ``nn.Module`` to profile (must have at least one parameter;
            the dummy input is created on that parameter's device).
        spec_size: shape of the dummy input tensor, e.g. ``(1, C, H, W)``.

    Returns:
        Total MAC count (conv + linear layers) for one forward pass.
    """
    list_conv2d = []

    def conv2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        assert batch_size == 1
        output_channels, output_height, output_width = output[0].size()

        # per-output-element multiplies: kernel area * input channels per group
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        # overall macs count is:
        # kernel**2 * in_channels/groups * out_channels * out_width * out_height
        macs = batch_size * params * output_height * output_width

        list_conv2d.append(macs)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        assert batch_size == 1
        weight_ops = self.weight.nelement()
        # guard against bias-free layers (nn.Linear(..., bias=False))
        bias_ops = self.bias.nelement() if self.bias is not None else 0

        # overall macs count is equal to the number of parameters in layer
        macs = batch_size * (weight_ops + bias_ops)
        list_linear.append(macs)

    # keep the RemovableHandles so every hook can be detached afterwards
    hook_handles = []

    def register_hooks(net):
        # Conv2dStaticSamePadding wraps a conv but has children (its padding
        # submodule), so it would be missed by the leaf check below.
        if net.__class__.__name__ == 'Conv2dStaticSamePadding':
            hook_handles.append(net.register_forward_hook(conv2d_hook))
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, nn.Conv2d):
                hook_handles.append(net.register_forward_hook(conv2d_hook))
            elif isinstance(net, nn.Linear):
                hook_handles.append(net.register_forward_hook(linear_hook))
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for c in childrens:
            register_hooks(c)

    # Register hook
    register_hooks(model)

    device = next(model.parameters()).device
    dummy_input = torch.rand(spec_size).to(device)
    with torch.no_grad():
        model(dummy_input)

    # detach all hooks so later calls/inference are unaffected
    for handle in hook_handles:
        handle.remove()

    total_macs = sum(list_conv2d) + sum(list_linear)

    print("*************Computational Complexity (multiply-adds) **************")
    print("Number of Convolutional Layers: ", len(list_conv2d))
    print("Number of Linear Layers: ", len(list_linear))
    print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs)))
    print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs))
    print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9))
    print("********************************************************************")
    return total_macs


def count_macs_transformer(model, spec_size):
    """Count MACs of a transformer-style model with forward hooks.

    Like ``count_macs`` but additionally counts the attention-matrix work of
    modules whose class is named ``MultiHeadAttention`` and handles linear
    layers applied position-wise to ``(batch, seq_len, embed)`` inputs.
    Runs one forward pass on a random input of shape ``spec_size`` (batch
    size must be 1), prints a breakdown, and returns the total.

    All hooks are removed before returning, so the function can be called
    repeatedly on the same model without double-counting.

    Args:
        model: the ``nn.Module`` to profile (must have at least one parameter;
            the dummy input is created on that parameter's device).
        spec_size: shape of the dummy input tensor fed to ``model``.

    Returns:
        Total MAC count (conv + linear + attention) for one forward pass.
    """
    list_conv2d = []

    def conv2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        assert batch_size == 1
        output_channels, output_height, output_width = output[0].size()

        # per-output-element multiplies: kernel area * input channels per group
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        # overall macs count is:
        # kernel**2 * in_channels/groups * out_channels * out_width * out_height
        macs = batch_size * params * output_height * output_width

        list_conv2d.append(macs)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() >= 2 else 1
        assert batch_size == 1
        if input[0].dim() == 3:
            # (batch size, sequence length, embeddings size)
            batch_size, seq_len, embed_size = input[0].size()

            weight_ops = self.weight.nelement()
            bias_ops = self.bias.nelement() if self.bias is not None else 0
            # linear layer applied position-wise, multiply with sequence length
            macs = batch_size * (weight_ops + bias_ops) * seq_len
        else:
            # classification head
            # (batch size, embeddings size)
            batch_size, embed_size = input[0].size()
            weight_ops = self.weight.nelement()
            bias_ops = self.bias.nelement() if self.bias is not None else 0
            # overall macs count is equal to the number of parameters in layer
            macs = batch_size * (weight_ops + bias_ops)
        list_linear.append(macs)

    list_att = []

    def attention_hook(self, input, output):
        # here we only calculate the attention macs; linear layers are processed in linear_hook
        batch_size, seq_len, embed_size = input[0].size()

        # 2 times embed_size * seq_len**2
        # - computing the attention matrix: embed_size * seq_len**2
        # - multiply attention matrix with value matrix: embed_size * seq_len**2
        macs = batch_size * embed_size * seq_len * seq_len * 2
        list_att.append(macs)

    # keep the RemovableHandles so every hook can be detached afterwards
    hook_handles = []

    def register_hooks(net):
        childrens = list(net.children())
        # project-specific attention module, matched by class name
        if net.__class__.__name__ == "MultiHeadAttention":
            hook_handles.append(net.register_forward_hook(attention_hook))
        if not childrens:
            if isinstance(net, nn.Conv2d):
                hook_handles.append(net.register_forward_hook(conv2d_hook))
            elif isinstance(net, nn.Linear):
                hook_handles.append(net.register_forward_hook(linear_hook))
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for c in childrens:
            register_hooks(c)

    # Register hook
    register_hooks(model)

    device = next(model.parameters()).device
    dummy_input = torch.rand(spec_size).to(device)

    with torch.no_grad():
        model(dummy_input)

    # detach all hooks so later calls/inference are unaffected
    for handle in hook_handles:
        handle.remove()

    total_macs = sum(list_conv2d) + sum(list_linear) + sum(list_att)

    print("*************Computational Complexity (multiply-adds) **************")
    print("Number of Convolutional Layers: ", len(list_conv2d))
    print("Number of Linear Layers: ", len(list_linear))
    print("Number of Attention Layers: ", len(list_att))
    print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs)))
    print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs))
    print("Relative Share of Attention Layers: {:.2f}".format(sum(list_att) / total_macs))
    print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9))
    print("********************************************************************")
    return total_macs