import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from efficientat.models.utils import collapse_dim


class MultiHeadAttentionPooling(nn.Module):
    """Multi-head attention pooling as used in the PSLA paper (https://arxiv.org/pdf/2102.01243.pdf)."""

    def __init__(self, in_dim, out_dim, att_activation: str = 'sigmoid',
                 clf_activation: str = 'ident', num_heads: int = 4, epsilon: float = 1e-7):
        super(MultiHeadAttentionPooling, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.epsilon = epsilon

        self.att_activation = att_activation
        self.clf_activation = clf_activation

        # projection size: out_dim x 2 (attention and classification branches) x num_heads
        self.subspace_proj = nn.Linear(self.in_dim, self.out_dim * 2 * self.num_heads)
        # learnable per-head weights, initialized uniformly to 1 / num_heads
        self.head_weight = nn.Parameter(torch.tensor([1.0 / self.num_heads] * self.num_heads).view(1, -1, 1))

    def activate(self, x, activation):
        if activation == 'linear':
            return x
        elif activation == 'relu':
            return F.relu(x)
        elif activation == 'sigmoid':
            return torch.sigmoid(x)
        elif activation == 'softmax':
            return F.softmax(x, dim=1)
        elif activation == 'ident':
            return x
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward(self, x) -> Tensor:
        """x: Tensor of size (batch_size, channels, frequency bands, sequence length)
        """
        x = collapse_dim(x, dim=2)  # (batch_size, channels, sequence_length)
        x = x.transpose(1, 2)  # (batch_size, sequence_length, channels)
        b, n, c = x.shape

        # project into 2 x num_heads subspaces of size out_dim,
        # then split into attention and value (classification) branches
        x = self.subspace_proj(x).reshape(b, n, 2, self.num_heads, self.out_dim).permute(2, 0, 3, 1, 4)
        att, val = x[0], x[1]  # each: (batch_size, num_heads, sequence_length, out_dim)
        val = self.activate(val, self.clf_activation)
        att = self.activate(att, self.att_activation)
        att = torch.clamp(att, self.epsilon, 1. - self.epsilon)
        # normalize attention weights over the sequence dimension
        att = att / torch.sum(att, dim=2, keepdim=True)

        # attention-weighted sum over time, then weighted sum over heads
        out = torch.sum(att * val, dim=2) * self.head_weight  # (batch_size, num_heads, out_dim)
        out = torch.sum(out, dim=1)  # (batch_size, out_dim)
        return out
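

# A minimal usage sketch, not part of the original module: the channel count (960),
# class count (527), and input sizes below are illustrative assumptions, e.g. a
# MobileNet-style spectrogram feature map pooled into AudioSet-sized logits.
if __name__ == '__main__':
    pooling = MultiHeadAttentionPooling(in_dim=960, out_dim=527, num_heads=4)
    features = torch.randn(8, 960, 4, 250)  # (batch, channels, frequency bands, sequence length)
    logits = pooling(features)
    print(logits.shape)  # torch.Size([8, 527])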