from typing import Dict, Callable, List
import torch
import torch.nn as nn
from torch import Tensor
from torchvision.ops.misc import ConvNormActivation
from efficientat.models.utils import make_divisible, cnn_out_size
class ConcurrentSEBlock(torch.nn.Module):
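    """
    Runs several SqueezeExcitation blocks concurrently, one per entry in
    ``se_cnf['se_dims']`` (1: channel, 2: frequency, 3: time), and merges their
    outputs with the aggregation operation named by ``se_cnf['se_agg']``
    (one of "max", "avg", "add", "min").
    """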
def __init__(
self,
c_dim: int,
f_dim: int,
t_dim: int,
se_cnf: Dict
) -> None:
super().__init__()
dims = [c_dim, f_dim, t_dim]
self.conc_se_layers = nn.ModuleList()
for d in se_cnf['se_dims']:
input_dim = dims[d-1]
squeeze_dim = make_divisible(input_dim // se_cnf['se_r'], 8)
self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
if se_cnf['se_agg'] == "max":
self.agg_op = lambda x: torch.max(x, dim=0)[0]
elif se_cnf['se_agg'] == "avg":
self.agg_op = lambda x: torch.mean(x, dim=0)
elif se_cnf['se_agg'] == "add":
self.agg_op = lambda x: torch.sum(x, dim=0)
elif se_cnf['se_agg'] == "min":
self.agg_op = lambda x: torch.min(x, dim=0)[0]
else:
            raise NotImplementedError(f"SE aggregation operation '{se_cnf['se_agg']}' not implemented")
def forward(self, input: Tensor) -> Tensor:
# apply all concurrent se layers
se_outs = []
for se_layer in self.conc_se_layers:
se_outs.append(se_layer(input))
out = self.agg_op(torch.stack(se_outs, dim=0))
return out
class SqueezeExcitation(torch.nn.Module):
"""
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507.
Args:
input_dim (int): Input dimension
squeeze_dim (int): Size of Bottleneck
activation (Callable): activation applied to bottleneck
scale_activation (Callable): activation applied to the output
"""
def __init__(
self,
input_dim: int,
squeeze_dim: int,
se_dim: int,
activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
assert se_dim in [1, 2, 3]
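        # self.se_dim holds the dims to average out during the squeeze step:
        # all of (1, 2, 3) except the dimension that is being recalibrated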
self.se_dim = [1, 2, 3]
self.se_dim.remove(se_dim)
self.activation = activation()
self.scale_activation = scale_activation()
    def _scale(self, input: Tensor) -> Tensor:
        # squeeze: global average over every dim except the one selected by se_dim
        scale = torch.mean(input, self.se_dim, keepdim=True)
        shape = scale.size()
        # excitation: bottleneck MLP over the remaining dimension, flattened to (batch, input_dim)
        scale = self.fc1(torch.flatten(scale, 1))
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale).view(shape)
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input
class InvertedResidualConfig:
# Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
def __init__(
self,
input_channels: int,
kernel: int,
expanded_channels: int,
out_channels: int,
use_se: bool,
activation: str,
stride: int,
dilation: int,
width_mult: float,
):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
self.out_channels = self.adjust_channels(out_channels, width_mult)
self.use_se = use_se
self.use_hs = activation == "HS"
self.stride = stride
self.dilation = dilation
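        # f_dim and t_dim (frequency and time sizes of the feature map) are expected
        # to be set externally before the SE blocks are built (see ConcurrentSEBlock)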
self.f_dim = None
self.t_dim = None
@staticmethod
def adjust_channels(channels: int, width_mult: float):
return make_divisible(channels * width_mult, 8)
def out_size(self, in_size):
padding = (self.kernel - 1) // 2 * self.dilation
return cnn_out_size(in_size, padding, self.dilation, self.kernel, self.stride)
class InvertedResidual(nn.Module):
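    """
    MobileNetV3-style inverted residual block: 1x1 expansion, depthwise convolution,
    optional ConcurrentSEBlock, and 1x1 projection, with a residual connection when
    the stride is 1 and the number of channels is unchanged.
    """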
def __init__(
self,
cnf: InvertedResidualConfig,
se_cnf: Dict,
norm_layer: Callable[..., nn.Module],
depthwise_norm_layer: Callable[..., nn.Module]
):
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
layers: List[nn.Module] = []
activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(
ConvNormActivation(
cnf.input_channels,
cnf.expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(
ConvNormActivation(
cnf.expanded_channels,
cnf.expanded_channels,
kernel_size=cnf.kernel,
stride=stride,
dilation=cnf.dilation,
groups=cnf.expanded_channels,
norm_layer=depthwise_norm_layer,
activation_layer=activation_layer,
)
)
if cnf.use_se and se_cnf['se_dims'] is not None:
layers.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))
# project
layers.append(
ConvNormActivation(
cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
self._is_cn = cnf.stride > 1
def forward(self, inp: Tensor) -> Tensor:
result = self.block(inp)
if self.use_res_connect:
result += inp
return result
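

# Illustrative usage sketch: recalibrate the channel dimension of a
# (batch, channel, frequency, time) feature map with a ConcurrentSEBlock.
# The se_cnf values below are assumptions for the example, not a configuration
# taken from the repository.
if __name__ == "__main__":
    se_cnf = {"se_dims": [1], "se_r": 4, "se_agg": "max"}
    block = ConcurrentSEBlock(c_dim=64, f_dim=32, t_dim=100, se_cnf=se_cnf)
    x = torch.randn(2, 64, 32, 100)
    y = block(x)
    print(y.shape)  # torch.Size([2, 64, 32, 100])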