# -*- coding: utf-8 -*-
# PyTorch Implementation of Attention Modules
#
# Implementation based on: https://github.com/mahmoodlab/CLAM
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
from typing import Tuple
import torch
import torch.nn as nn


class Attention(nn.Module):
"""Basic Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL
Args:
in_features (int, optional): Input shape of attention module. Defaults to 1024.
attention_features (int, optional): Number of attention features. Defaults to 128.
num_classes (int, optional): Number of output classes. Defaults to 2.
dropout (bool, optional): If True, dropout is used. Defaults to False.
        dropout_rate (float, optional): Dropout rate, only applied if dropout is True.
            Must be between 0.0 and 1.0. Defaults to 0.25.
"""
def __init__(
self,
in_features: int = 1024,
attention_features: int = 128,
num_classes: int = 2,
dropout: bool = False,
dropout_rate: float = 0.25,
):
super(Attention, self).__init__()
# naming
self.model_name = "AttentionModule"
# set parameter dimensions for attention
self.attention_features = attention_features
self.in_features = in_features
self.num_classes = num_classes
self.dropout = dropout
self.d_rate = dropout_rate
if self.dropout:
            assert 0.0 <= self.d_rate < 1.0, "dropout_rate must be in [0.0, 1.0)"
self.attention = nn.Sequential(
nn.Linear(self.in_features, self.attention_features),
nn.Tanh(),
nn.Dropout(self.d_rate),
nn.Linear(self.attention_features, self.num_classes),
)
else:
self.attention = nn.Sequential(
nn.Linear(self.in_features, self.attention_features),
nn.Tanh(),
nn.Linear(self.attention_features, self.num_classes),
)

    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass, calculating attention scores for given input vector
Args:
H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                * Attention scores. Shape: (Number of instances, num_classes)
                * H, the unchanged bag of instances. Shape: (Number of instances, Feature-dimensions)
"""
A = self.attention(H)
return A, H
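

# Usage sketch (illustrative, not part of the original module): applying the
# Attention block to a bag of instance embeddings and turning the scores into
# a bag-level representation. The bag size, feature sizes, and the softmax
# pooling step are assumptions in the style of CLAM-like multiple-instance
# learning; they are not defined in this file.
#
#   attention = Attention(in_features=1024, attention_features=128, num_classes=2)
#   bag = torch.rand(50, 1024)                          # 50 instances, 1024-dim features
#   A, H = attention(bag)                               # A: (50, 2) scores, H: (50, 1024)
#   weights = torch.softmax(A.transpose(0, 1), dim=1)   # (2, 50), normalized over instances
#   bag_embedding = weights @ H                         # (2, 1024) class-wise bag embedding

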
class AttentionGated(nn.Module):
"""Gated Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL
Args:
in_features (int, optional): Input shape of attention module. Defaults to 1024.
attention_features (int, optional): Number of attention features. Defaults to 128.
num_classes (int, optional): Number of output classes. Defaults to 2.
dropout (bool, optional): If True, dropout is used. Defaults to False.
        dropout_rate (float, optional): Dropout rate, only applied if dropout is True.
            Must be between 0.0 and 1.0. Defaults to 0.25.
"""
def __init__(
self,
in_features: int = 1024,
attention_features: int = 128,
num_classes: int = 2,
dropout: bool = False,
dropout_rate: float = 0.25,
):
super(AttentionGated, self).__init__()
# naming
self.model_name = "AttentionModuleGated"
        # set parameter dimensions for attention
self.attention_features = attention_features
self.in_features = in_features
self.num_classes = num_classes
self.dropout = dropout
self.d_rate = dropout_rate
if self.dropout:
            assert 0.0 <= self.d_rate < 1.0, "dropout_rate must be in [0.0, 1.0)"
self.attention_V = nn.Sequential(
nn.Linear(self.in_features, self.attention_features),
nn.Tanh(),
nn.Dropout(self.d_rate),
)
self.attention_U = nn.Sequential(
nn.Linear(self.in_features, self.attention_features),
nn.Sigmoid(),
nn.Dropout(self.d_rate),
)
self.attention_W = nn.Sequential(
nn.Linear(self.attention_features, self.num_classes)
)
else:
self.attention_V = nn.Sequential(
nn.Linear(self.in_features, self.attention_features), nn.Tanh()
)
self.attention_U = nn.Sequential(
nn.Linear(self.in_features, self.attention_features), nn.Sigmoid()
)
self.attention_W = nn.Sequential(
nn.Linear(self.attention_features, self.num_classes)
)

    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass, calculating attention scores for given input vector
Args:
H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                * Attention scores. Shape: (Number of instances, num_classes)
                * H, the unchanged bag of instances. Shape: (Number of instances, Feature-dimensions)
"""
v = self.attention_V(H)
u = self.attention_U(H)
A = self.attention_W(v * u)
return A, H
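

# Minimal runnable sketch: exercising AttentionGated on random data. The bag
# size, feature dimensions, and the softmax-based attention pooling step below
# are illustrative assumptions and not part of the original module.
if __name__ == "__main__":
    gated_attention = AttentionGated(
        in_features=1024, attention_features=128, num_classes=2
    )
    bag = torch.rand(50, 1024)  # bag of 50 instances with 1024-dim features
    A, H = gated_attention(bag)  # A: (50, 2) attention scores, H is returned unchanged
    # Normalize scores over the instances and pool the bag into one embedding
    # per class (CLAM-style pooling, assumed here for illustration).
    weights = torch.softmax(A.transpose(0, 1), dim=1)  # (2, 50)
    bag_embedding = weights @ H  # (2, 1024)
    print(A.shape, bag_embedding.shape)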