# efficientat/models/block_types.py
from typing import Dict, Callable, List
import torch
import torch.nn as nn
from torch import Tensor
from torchvision.ops.misc import ConvNormActivation
from efficientat.models.utils import make_divisible, cnn_out_size
class ConcurrentSEBlock(torch.nn.Module):
def __init__(
self,
c_dim: int,
f_dim: int,
t_dim: int,
se_cnf: Dict
) -> None:
super().__init__()
dims = [c_dim, f_dim, t_dim]
self.conc_se_layers = nn.ModuleList()
for d in se_cnf['se_dims']:
input_dim = dims[d-1]
squeeze_dim = make_divisible(input_dim // se_cnf['se_r'], 8)
self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
if se_cnf['se_agg'] == "max":
self.agg_op = lambda x: torch.max(x, dim=0)[0]
elif se_cnf['se_agg'] == "avg":
self.agg_op = lambda x: torch.mean(x, dim=0)
elif se_cnf['se_agg'] == "add":
self.agg_op = lambda x: torch.sum(x, dim=0)
elif se_cnf['se_agg'] == "min":
self.agg_op = lambda x: torch.min(x, dim=0)[0]
        else:
            raise NotImplementedError(f"SE aggregation operation '{se_cnf['se_agg']}' not implemented")
def forward(self, input: Tensor) -> Tensor:
# apply all concurrent se layers
se_outs = []
for se_layer in self.conc_se_layers:
se_outs.append(se_layer(input))
out = self.agg_op(torch.stack(se_outs, dim=0))
return out
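

# Usage sketch (not part of the original file; assumes a (batch, channel, frequency, time)
# feature layout and the `se_cnf` keys read above: 'se_dims', 'se_r', 'se_agg'):
#
#   se_cnf = {"se_dims": [1, 3], "se_r": 4, "se_agg": "max"}
#   block = ConcurrentSEBlock(c_dim=64, f_dim=32, t_dim=100, se_cnf=se_cnf)
#   x = torch.randn(8, 64, 32, 100)   # (batch, channels, frequency bands, time frames)
#   y = block(x)                      # channel-wise and time-wise SE branches, merged element-wise by max; y.shape == x.shape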
class SqueezeExcitation(torch.nn.Module):
"""
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507.
Args:
input_dim (int): Input dimension
squeeze_dim (int): Size of Bottleneck
activation (Callable): activation applied to bottleneck
scale_activation (Callable): activation applied to the output
"""
def __init__(
self,
input_dim: int,
squeeze_dim: int,
se_dim: int,
activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
        assert se_dim in [1, 2, 3]
        # keep the excitation along `se_dim`; the remaining dims are averaged (squeezed) out
        self.se_dim = [1, 2, 3]
        self.se_dim.remove(se_dim)
self.activation = activation()
self.scale_activation = scale_activation()
def _scale(self, input: Tensor) -> Tensor:
        # squeeze: average over the dims not selected by se_dim, keeping a broadcastable shape
        scale = torch.mean(input, self.se_dim, keepdim=True)
        shape = scale.size()
        # excite: bottleneck MLP followed by the scale activation (sigmoid by default)
        scale = self.fc1(scale.squeeze(2).squeeze(2))
        scale = self.activation(scale)
        scale = self.fc2(scale)
        return self.scale_activation(scale).view(shape)
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input
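

# Example (a sketch, not part of the original module): a plain channel-wise SE layer,
# i.e. se_dim=1 on a (batch, channel, frequency, time) tensor; the squeeze averages
# over frequency and time and the learned scale is broadcast back over them.
#
#   se = SqueezeExcitation(input_dim=64, squeeze_dim=16, se_dim=1)
#   x = torch.randn(2, 64, 32, 100)
#   y = se(x)        # y.shape == (2, 64, 32, 100), each channel rescaled by a value in (0, 1)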
class InvertedResidualConfig:
    # Stores information listed in Tables 1 and 2 of the MobileNetV3 paper
def __init__(
self,
input_channels: int,
kernel: int,
expanded_channels: int,
out_channels: int,
use_se: bool,
activation: str,
stride: int,
dilation: int,
width_mult: float,
):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
self.out_channels = self.adjust_channels(out_channels, width_mult)
self.use_se = use_se
self.use_hs = activation == "HS"
self.stride = stride
self.dilation = dilation
self.f_dim = None
self.t_dim = None
@staticmethod
def adjust_channels(channels: int, width_mult: float):
return make_divisible(channels * width_mult, 8)
def out_size(self, in_size):
padding = (self.kernel - 1) // 2 * self.dilation
return cnn_out_size(in_size, padding, self.dilation, self.kernel, self.stride)
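

# Example (a sketch): one MobileNetV3-style block configuration.  With width_mult=1.0 and
# channel counts that are already multiples of 8, adjust_channels() leaves them unchanged.
#
#   cnf = InvertedResidualConfig(
#       input_channels=16, kernel=3, expanded_channels=64, out_channels=24,
#       use_se=True, activation="RE", stride=2, dilation=1, width_mult=1.0,
#   )
#   cnf.use_hs        # False: "RE" selects ReLU, "HS" would select Hardswish
#   cnf.out_size(64)  # spatial size after this block's depthwise conv (stride 2 halves it)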
class InvertedResidual(nn.Module):
def __init__(
self,
cnf: InvertedResidualConfig,
se_cnf: Dict,
norm_layer: Callable[..., nn.Module],
depthwise_norm_layer: Callable[..., nn.Module]
):
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
layers: List[nn.Module] = []
activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(
ConvNormActivation(
cnf.input_channels,
cnf.expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(
ConvNormActivation(
cnf.expanded_channels,
cnf.expanded_channels,
kernel_size=cnf.kernel,
stride=stride,
dilation=cnf.dilation,
groups=cnf.expanded_channels,
norm_layer=depthwise_norm_layer,
activation_layer=activation_layer,
)
)
if cnf.use_se and se_cnf['se_dims'] is not None:
layers.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))
# project
layers.append(
ConvNormActivation(
cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
self._is_cn = cnf.stride > 1
def forward(self, inp: Tensor) -> Tensor:
result = self.block(inp)
if self.use_res_connect:
result += inp
return result
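

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): build one
    # inverted-residual block with a channel-wise concurrent SE stage and run a
    # forward pass on a (batch, channels, frequency, time) spectrogram-like tensor.
    from functools import partial

    cnf = InvertedResidualConfig(
        input_channels=16, kernel=3, expanded_channels=64, out_channels=16,
        use_se=True, activation="RE", stride=1, dilation=1, width_mult=1.0,
    )
    cnf.f_dim, cnf.t_dim = 32, 100  # normally filled in by the surrounding network
    se_cnf = {"se_dims": [1], "se_r": 4, "se_agg": "avg"}
    norm = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
    block = InvertedResidual(cnf, se_cnf, norm_layer=norm, depthwise_norm_layer=norm)

    x = torch.randn(2, 16, 32, 100)
    y = block(x)
    assert y.shape == x.shape  # stride 1 and matching channels -> residual connection is used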