|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class SelectiveKernelFeatureFusion(nn.Module):
    """
    Merges outputs of several resolution branches through self-attention
    (SKFF-style fusion).

    All inputs are summed -> global average pooling -> channel downscaling ->
    the pooled signal is passed through one conv per branch to produce one
    descriptor per branch; softmax across branches turns the descriptors into
    attention weights used to recalibrate and fuse the input feature maps.
    """

    def __init__(self, in_channels, reduction_ratio, bias=False, n_branches=3):
        """
        Args:
            in_channels: number of channels in each input feature map.
            reduction_ratio: channel reduction factor for the bottleneck conv
                (floored, but never below 4 output channels).
            bias: whether the 1x1 convolutions use a bias term.
            n_branches: number of input branches to fuse (default 3, matching
                the original behavior).
        """
        super().__init__()
        self.n_branches = n_branches
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Bottleneck width; clamped so very large ratios still leave capacity.
        conv_out_channels = max(int(in_channels / reduction_ratio), 4)
        self.convolution = nn.Sequential(
            nn.Conv2d(
                in_channels, conv_out_channels, kernel_size=1, padding=0, bias=bias
            ),
            nn.PReLU(),
        )

        # One descriptor-producing conv per branch.
        self.attention_convs = nn.ModuleList(
            nn.Conv2d(
                conv_out_channels, in_channels, kernel_size=1, stride=1, bias=bias
            )
            for _ in range(n_branches)
        )

        # Softmax over the branch dimension: per (batch, channel) position the
        # attention weights of all branches sum to 1.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        Args:
            x: sequence of `n_branches` tensors, each (B, in_channels, H, W).

        Returns:
            Fused tensor of shape (B, in_channels, H, W).
        """
        if len(x) != self.n_branches:
            raise ValueError(
                f"expected {self.n_branches} input branches, got {len(x)}"
            )

        # (B, n_branches, C, H, W) — stack replaces the cat+view round-trip.
        x = torch.stack(x, dim=1)

        # Global descriptor from the summed branches.
        z = torch.sum(x, dim=1)
        z = self.avg_pool(z)
        z = self.convolution(z)

        # One (B, C, 1, 1) descriptor per branch -> (B, n_branches, C, 1, 1).
        attention_activations = torch.stack(
            [atn(z) for atn in self.attention_convs], dim=1
        )
        attention_activations = self.softmax(attention_activations)

        # Weighted sum over branches; weights broadcast over H and W.
        return torch.sum(x * attention_activations, dim=1)
|
|