from typing import Dict, Callable, List

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.ops.misc import ConvNormActivation

from efficientat.models.utils import make_divisible, cnn_out_size


class ConcurrentSEBlock(torch.nn.Module):
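    """
    Applies several SqueezeExcitation blocks to the same input concurrently, one for each
    dimension listed in se_cnf['se_dims'] (1 = channels, 2 = frequency, 3 = time), and merges
    their outputs with the aggregation selected by se_cnf['se_agg'] ("max", "avg", "add" or "min").

    Example (a minimal sketch; the concrete config values are illustrative assumptions,
    not taken from a shipped configuration)::

        se_cnf = {"se_dims": [1], "se_r": 4, "se_agg": "max"}
        block = ConcurrentSEBlock(c_dim=64, f_dim=32, t_dim=250, se_cnf=se_cnf)
        out = block(torch.randn(8, 64, 32, 250))  # output keeps the input shape
    """
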
    def __init__(
            self,
            c_dim: int,
            f_dim: int,
            t_dim: int,
            se_cnf: Dict
    ) -> None:
        super().__init__()
        dims = [c_dim, f_dim, t_dim]
        self.conc_se_layers = nn.ModuleList()
        for d in se_cnf['se_dims']:
            input_dim = dims[d - 1]
            squeeze_dim = make_divisible(input_dim // se_cnf['se_r'], 8)
            self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
        if se_cnf['se_agg'] == "max":
            self.agg_op = lambda x: torch.max(x, dim=0)[0]
        elif se_cnf['se_agg'] == "avg":
            self.agg_op = lambda x: torch.mean(x, dim=0)
        elif se_cnf['se_agg'] == "add":
            self.agg_op = lambda x: torch.sum(x, dim=0)
        elif se_cnf['se_agg'] == "min":
            self.agg_op = lambda x: torch.min(x, dim=0)[0]
        else:
            raise NotImplementedError(f"SE aggregation operation '{se_cnf['se_agg']}' not implemented")

    def forward(self, input: Tensor) -> Tensor:
        # run every SE layer on the same input, then aggregate the excited copies
        se_outs = []
        for se_layer in self.conc_se_layers:
            se_outs.append(se_layer(input))
        out = self.agg_op(torch.stack(se_outs, dim=0))
        return out


class SqueezeExcitation(torch.nn.Module):
    """
    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507,
    generalized so that the excitation can be computed over the channel, frequency or time
    dimension of a (batch, channels, frequency, time) input.

    Args:
        input_dim (int): Size of the dimension the excitation is computed over
        squeeze_dim (int): Size of the bottleneck
        se_dim (int): Dimension to excite (1 = channels, 2 = frequency, 3 = time); the
            remaining two dimensions are averaged out in the squeeze step
        activation (Callable): activation applied to the bottleneck
        scale_activation (Callable): activation applied to the output
    """

    def __init__(
        self,
        input_dim: int,
        squeeze_dim: int,
        se_dim: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
        self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
        assert se_dim in [1, 2, 3]
        # keep the dimensions to average over (all but the excited one)
        self.se_dim = [1, 2, 3]
        self.se_dim.remove(se_dim)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        # squeeze: average over the non-excited dimensions, keeping them as size-1 dims
        scale = torch.mean(input, self.se_dim, keepdim=True)
        shape = scale.size()
        # collapse the size-1 dimensions so the linear layers operate on (batch, input_dim),
        # regardless of which dimension is excited
        scale = self.fc1(scale.flatten(1))
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale).view(shape)

    def forward(self, input: Tensor) -> Tensor:
        scale = self._scale(input)
        return scale * input


class InvertedResidualConfig:
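    """
    Stores the configuration of a single InvertedResidual block. Channel counts are scaled by
    width_mult and rounded to multiples of 8 via make_divisible; activation == "HS" selects
    Hardswish (ReLU otherwise). f_dim and t_dim default to None and can be filled in later by
    the surrounding model once the spectrogram size at this block is known.
    """
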
    def __init__(
            self,
            input_channels: int,
            kernel: int,
            expanded_channels: int,
            out_channels: int,
            use_se: bool,
            activation: str,
            stride: int,
            dilation: int,
            width_mult: float,
    ):
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.kernel = kernel
        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.use_se = use_se
        self.use_hs = activation == "HS"
        self.stride = stride
        self.dilation = dilation
        self.f_dim = None
        self.t_dim = None

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        return make_divisible(channels * width_mult, 8)

    def out_size(self, in_size):
        padding = (self.kernel - 1) // 2 * self.dilation
        return cnn_out_size(in_size, padding, self.dilation, self.kernel, self.stride)


class InvertedResidual(nn.Module):
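    """
    MobileNetV3-style inverted residual block: an optional 1x1 expansion convolution, a depthwise
    convolution, an optional ConcurrentSEBlock, and a 1x1 projection without activation; a
    residual connection is added when stride == 1 and input and output channels match.

    A usage sketch (with assumed, illustrative values) is given at the bottom of this file.
    """
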
    def __init__(
            self,
            cnf: InvertedResidualConfig,
            se_cnf: Dict,
            norm_layer: Callable[..., nn.Module],
            depthwise_norm_layer: Callable[..., nn.Module]
    ):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU

        # expand: 1x1 convolution (only if the block actually expands the channels)
        if cnf.expanded_channels != cnf.input_channels:
            layers.append(
                ConvNormActivation(
                    cnf.input_channels,
                    cnf.expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        # depthwise convolution
        stride = 1 if cnf.dilation > 1 else cnf.stride
        layers.append(
            ConvNormActivation(
                cnf.expanded_channels,
                cnf.expanded_channels,
                kernel_size=cnf.kernel,
                stride=stride,
                dilation=cnf.dilation,
                groups=cnf.expanded_channels,
                norm_layer=depthwise_norm_layer,
                activation_layer=activation_layer,
            )
        )

        # optional squeeze-and-excitation over the configured dimensions
        if cnf.use_se and se_cnf['se_dims'] is not None:
            layers.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))

        # project: 1x1 convolution back to the output channels, no activation
        layers.append(
            ConvNormActivation(
                cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, inp: Tensor) -> Tensor:
        result = self.block(inp)
        if self.use_res_connect:
            result += inp
        return result
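

if __name__ == "__main__":
    # Minimal smoke test. A sketch only: the shapes, config values and plain nn.BatchNorm2d
    # norm layers below are illustrative assumptions, not taken from a shipped configuration.
    se_cnf = {"se_dims": [1], "se_r": 4, "se_agg": "max"}
    cnf = InvertedResidualConfig(
        input_channels=16,
        kernel=3,
        expanded_channels=64,
        out_channels=16,
        use_se=True,
        activation="RE",
        stride=1,
        dilation=1,
        width_mult=1.0,
    )
    block = InvertedResidual(cnf, se_cnf, nn.BatchNorm2d, nn.BatchNorm2d)
    x = torch.randn(2, 16, 32, 250)  # (batch, channels, frequency, time)
    print(block(x).shape)  # stride 1 and matching channels -> shape is preserved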