# NOTE: removed non-Python artifact lines ("Spaces:", "Runtime error") left
# over from a copy/paste; they were never part of this module.
from typing import Dict, Callable, List | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
from torchvision.ops.misc import ConvNormActivation | |
from efficientat.models.utils import make_divisible, cnn_out_size | |
class ConcurrentSEBlock(torch.nn.Module):
    """Runs several Squeeze-and-Excitation blocks concurrently — one per
    dimension index listed in ``se_cnf['se_dims']`` — and aggregates their
    outputs element-wise.

    Args:
        c_dim (int): channel size of the input (tensor dim 1)
        f_dim (int): frequency size of the input (tensor dim 2)
        t_dim (int): time size of the input (tensor dim 3)
        se_cnf (Dict): config with keys 'se_dims' (iterable of dim indices
            in {1, 2, 3}), 'se_r' (squeeze ratio) and 'se_agg' (one of
            'max', 'avg', 'add', 'min').

    Raises:
        NotImplementedError: if ``se_cnf['se_agg']`` names an unknown op.
    """

    # Dispatch table for aggregating the stacked SE outputs along dim 0.
    _AGG_OPS: Dict[str, Callable[[Tensor], Tensor]] = {
        "max": lambda x: torch.max(x, dim=0)[0],
        "avg": lambda x: torch.mean(x, dim=0),
        "add": lambda x: torch.sum(x, dim=0),
        "min": lambda x: torch.min(x, dim=0)[0],
    }

    def __init__(
            self,
            c_dim: int,
            f_dim: int,
            t_dim: int,
            se_cnf: Dict
    ) -> None:
        super().__init__()
        dims = [c_dim, f_dim, t_dim]
        self.conc_se_layers = nn.ModuleList()
        for d in se_cnf['se_dims']:
            input_dim = dims[d - 1]
            squeeze_dim = make_divisible(input_dim // se_cnf['se_r'], 8)
            self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
        try:
            self.agg_op = self._AGG_OPS[se_cnf['se_agg']]
        except KeyError:
            # Bug fix: the original message interpolated self.agg_op, which is
            # unassigned in this branch, so an unknown op raised AttributeError
            # instead of this intended, informative error.
            raise NotImplementedError(
                f"SE aggregation operation '{se_cnf['se_agg']}' not implemented"
            ) from None

    def forward(self, input: Tensor) -> Tensor:
        # Apply all concurrent SE layers to the same input, then reduce the
        # stack of gated outputs with the configured aggregation op.
        se_outs = []
        for se_layer in self.conc_se_layers:
            se_outs.append(se_layer(input))
        out = self.agg_op(torch.stack(se_outs, dim=0))
        return out
class SqueezeExcitation(torch.nn.Module):
    """
    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507.

    The "squeeze" averages the 4-D input over the two dimensions NOT selected
    by ``se_dim``; the excitation is the usual two-layer bottleneck MLP whose
    (sigmoid-)activated output gates the input multiplicatively.

    Args:
        input_dim (int): Input dimension (size of tensor dim ``se_dim``)
        squeeze_dim (int): Size of Bottleneck
        se_dim (int): tensor dimension (1, 2 or 3) the excitation acts on
        activation (Callable): activation applied to bottleneck
        scale_activation (Callable): activation applied to the output
    """

    def __init__(
        self,
        input_dim: int,
        squeeze_dim: int,
        se_dim: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
        self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
        assert se_dim in [1, 2, 3]
        # The two dims that get averaged out (all of 1..3 except se_dim),
        # kept in ascending order.
        self.se_dim = [d for d in (1, 2, 3) if d != se_dim]
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        # Assumes a 4-D input (batch plus dims 1..3) — TODO confirm with callers.
        # Keepdim mean over the two non-SE dims leaves singletons there.
        scale = torch.mean(input, self.se_dim, keepdim=True)
        shape = scale.size()
        # Bug fix: squeeze the actual reduced dims, highest index first so the
        # lower index stays valid. The old squeeze(2).squeeze(2) was a no-op
        # whenever se_dim == 2 (singletons sit at dims 1 and 3), feeding fc1
        # a tensor with the wrong last dimension. Also dropped the dead
        # `scale = scale` statement.
        scale = scale.squeeze(self.se_dim[1]).squeeze(self.se_dim[0])
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        # Restore the keepdim shape so the gate broadcasts over the input.
        return self.scale_activation(scale).view(shape)

    def forward(self, input: Tensor) -> Tensor:
        scale = self._scale(input)
        return scale * input
class InvertedResidualConfig:
    """Stores the per-block hyperparameters listed in Tables 1 and 2 of the
    MobileNetV3 paper, with channel counts scaled by ``width_mult``.

    Args:
        input_channels (int): block input channels (before width scaling)
        kernel (int): depthwise kernel size
        expanded_channels (int): expansion width (before width scaling)
        out_channels (int): block output channels (before width scaling)
        use_se (bool): whether the block uses squeeze-excitation
        activation (str): "HS" selects Hardswish, anything else ReLU
        stride (int): depthwise stride
        dilation (int): depthwise dilation
        width_mult (float): global width multiplier
    """

    def __init__(
        self,
        input_channels: int,
        kernel: int,
        expanded_channels: int,
        out_channels: int,
        use_se: bool,
        activation: str,
        stride: int,
        dilation: int,
        width_mult: float,
    ):
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.kernel = kernel
        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.use_se = use_se
        self.use_hs = activation == "HS"
        self.stride = stride
        self.dilation = dilation
        # Spectrogram f/t sizes; filled in later by the model builder —
        # TODO confirm against caller.
        self.f_dim = None
        self.t_dim = None

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        # Bug fix: missing @staticmethod — the method has no `self` parameter,
        # so `self.adjust_channels(a, b)` in __init__ passed the instance as
        # `channels` and raised TypeError on every construction.
        return make_divisible(channels * width_mult, 8)

    def out_size(self, in_size):
        # "same"-style padding for the depthwise conv of this block.
        padding = (self.kernel - 1) // 2 * self.dilation
        return cnn_out_size(in_size, padding, self.dilation, self.kernel, self.stride)
class InvertedResidual(nn.Module):
    """MobileNetV3 inverted-residual block: 1x1 expand -> depthwise conv
    -> optional concurrent SE -> 1x1 linear projection, with an identity
    shortcut when stride is 1 and the channel count is unchanged."""

    def __init__(
        self,
        cnf: InvertedResidualConfig,
        se_cnf: Dict,
        norm_layer: Callable[..., nn.Module],
        depthwise_norm_layer: Callable[..., nn.Module]
    ):
        super().__init__()
        if cnf.stride not in (1, 2):
            raise ValueError("illegal stride value")

        # Identity shortcut is only valid when shape is preserved end to end.
        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.Hardswish if cnf.use_hs else nn.ReLU
        stages: List[nn.Module] = []

        # 1x1 expansion, skipped when no expansion is configured.
        if cnf.expanded_channels != cnf.input_channels:
            stages.append(
                ConvNormActivation(
                    cnf.input_channels,
                    cnf.expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=act,
                )
            )

        # Depthwise conv; when dilation is used, striding is disabled.
        stages.append(
            ConvNormActivation(
                cnf.expanded_channels,
                cnf.expanded_channels,
                kernel_size=cnf.kernel,
                stride=1 if cnf.dilation > 1 else cnf.stride,
                dilation=cnf.dilation,
                groups=cnf.expanded_channels,
                norm_layer=depthwise_norm_layer,
                activation_layer=act,
            )
        )

        # Optional concurrent squeeze-excitation on the expanded features.
        if cnf.use_se and se_cnf['se_dims'] is not None:
            stages.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))

        # 1x1 linear projection — deliberately no activation.
        stages.append(
            ConvNormActivation(
                cnf.expanded_channels,
                cnf.out_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=None,
            )
        )

        self.block = nn.Sequential(*stages)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, inp: Tensor) -> Tensor:
        out = self.block(inp)
        if self.use_res_connect:
            out = out + inp
        return out