# -*- coding: utf-8 -*- # PyTorch Implementation of Attention Modules # # Implementation based on: https://github.com/mahmoodlab/CLAM # @ Fabian Hörst, fabian.hoerst@uk-essen.de # Institute for Artifical Intelligence in Medicine, # University Medicine Essen from typing import Tuple import torch import torch.nn as nn class Attention(nn.Module): """Basic Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL Args: in_features (int, optional): Input shape of attention module. Defaults to 1024. attention_features (int, optional): Number of attention features. Defaults to 128. num_classes (int, optional): Number of output classes. Defaults to 2. dropout (bool, optional): If True, dropout is used. Defaults to False. dropout_rate (float, optional): Dropout rate, just applies if dropout parameter is true. Needs to be between 0.0 and 1.0. Defaults to 0.25. """ def __init__( self, in_features: int = 1024, attention_features: int = 128, num_classes: int = 2, dropout: bool = False, dropout_rate: float = 0.25, ): super(Attention, self).__init__() # naming self.model_name = "AttentionModule" # set parameter dimensions for attention self.attention_features = attention_features self.in_features = in_features self.num_classes = num_classes self.dropout = dropout self.d_rate = dropout_rate if self.dropout: assert self.d_rate < 1 self.attention = nn.Sequential( nn.Linear(self.in_features, self.attention_features), nn.Tanh(), nn.Dropout(self.d_rate), nn.Linear(self.attention_features, self.num_classes), ) else: self.attention = nn.Sequential( nn.Linear(self.in_features, self.attention_features), nn.Tanh(), nn.Linear(self.attention_features, self.num_classes), ) def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass, calculating attention scores for given input vector Args: H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions) Returns: Tuple[torch.Tensor, torch.Tensor]: * Attention-Scores * H. Shape: Bag of instances. Shape: (Number of instances, Feature-dimensions) """ A = self.attention(H) return A, H class AttentionGated(nn.Module): """Gated Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL Args: in_features (int, optional): Input shape of attention module. Defaults to 1024. attention_features (int, optional): Number of attention features. Defaults to 128. num_classes (int, optional): Number of output classes. Defaults to 2. dropout (bool, optional): If True, dropout is used. Defaults to False. dropout_rate (float, optional): Dropout rate, just applies if dropout parameter is true. needs to be between 0.0 and 1.0. Defaults to 0.25. """ def __init__( self, in_features: int = 1024, attention_features: int = 128, num_classes: int = 2, dropout: bool = False, dropout_rate: float = 0.25, ): super(AttentionGated, self).__init__() # naming self.model_name = "AttentionModuleGated" # set Parameter dimensions for attention self.attention_features = attention_features self.in_features = in_features self.num_classes = num_classes self.dropout = dropout self.d_rate = dropout_rate if self.dropout: assert self.d_rate < 1 self.attention_V = nn.Sequential( nn.Linear(self.in_features, self.attention_features), nn.Tanh(), nn.Dropout(self.d_rate), ) self.attention_U = nn.Sequential( nn.Linear(self.in_features, self.attention_features), nn.Sigmoid(), nn.Dropout(self.d_rate), ) self.attention_W = nn.Sequential( nn.Linear(self.attention_features, self.num_classes) ) else: self.attention_V = nn.Sequential( nn.Linear(self.in_features, self.attention_features), nn.Tanh() ) self.attention_U = nn.Sequential( nn.Linear(self.in_features, self.attention_features), nn.Sigmoid() ) self.attention_W = nn.Sequential( nn.Linear(self.attention_features, self.num_classes) ) def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass, calculating attention scores for given input vector Args: H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions) Returns: Tuple[torch.Tensor, torch.Tensor]: * Attention-Scores. Shape: (Number of instances) * H. Shape: Bag of instances. Shape: (Number of instances, Feature-dimensions) """ v = self.attention_V(H) u = self.attention_U(H) A = self.attention_W(v * u) return A, H