# -*- coding: utf-8 -*-
# PyTorch Implementation of Attention Modules
#
# Implementation based on: https://github.com/mahmoodlab/CLAM
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

from typing import Tuple
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Basic Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL

    Args:
        in_features (int, optional): Number of input features per instance (feature dimension). Defaults to 1024.
        attention_features (int, optional): Number of attention features. Defaults to 128.
        num_classes (int, optional): Number of output classes. Defaults to 2.
        dropout (bool, optional): If True, dropout is used. Defaults to False.
        dropout_rate (float, optional): Dropout rate; only applied if dropout is True.
            Must be in the range [0.0, 1.0). Defaults to 0.25.
    """

    def __init__(
        self,
        in_features: int = 1024,
        attention_features: int = 128,
        num_classes: int = 2,
        dropout: bool = False,
        dropout_rate: float = 0.25,
    ):
        super(Attention, self).__init__()
        # naming
        self.model_name = "AttentionModule"

        # set parameter dimensions for attention
        self.attention_features = attention_features
        self.in_features = in_features
        self.num_classes = num_classes
        self.dropout = dropout
        self.d_rate = dropout_rate

        if self.dropout:
            assert 0.0 <= self.d_rate < 1.0, "dropout_rate must be in [0.0, 1.0)"
            self.attention = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features),
                nn.Tanh(),
                nn.Dropout(self.d_rate),
                nn.Linear(self.attention_features, self.num_classes),
            )
        else:
            self.attention = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features),
                nn.Tanh(),
                nn.Linear(self.attention_features, self.num_classes),
            )

    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, calculating attention scores for given input vector

        Args:
            H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            * Attention scores. Shape: (Number of instances, Number of classes)
            * H: The unchanged bag of instances. Shape: (Number of instances, Feature-dimensions)
        """
        A = self.attention(H)
        return A, H
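
# Usage sketch (illustrative only, not part of the original module; the bag size
# and feature dimension below are assumptions chosen to match the default in_features):
#
#   bag = torch.randn(50, 1024)              # bag of 50 instances, 1024 features each
#   attention = Attention(in_features=1024, num_classes=2)
#   A, H = attention(bag)                    # A: (50, 2) raw (unnormalized) attention scores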


class AttentionGated(nn.Module):
    """Gated Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL

    Args:
        in_features (int, optional): Number of input features per instance (feature dimension). Defaults to 1024.
        attention_features (int, optional): Number of attention features. Defaults to 128.
        num_classes (int, optional): Number of output classes. Defaults to 2.
        dropout (bool, optional): If True, dropout is used. Defaults to False.
        dropout_rate (float, optional): Dropout rate; only applied if dropout is True.
            Must be in the range [0.0, 1.0). Defaults to 0.25.
    """

    def __init__(
        self,
        in_features: int = 1024,
        attention_features: int = 128,
        num_classes: int = 2,
        dropout: bool = False,
        dropout_rate: float = 0.25,
    ):
        super(AttentionGated, self).__init__()
        # naming
        self.model_name = "AttentionModuleGated"

        # set parameter dimensions for attention
        self.attention_features = attention_features
        self.in_features = in_features
        self.num_classes = num_classes
        self.dropout = dropout
        self.d_rate = dropout_rate

        if self.dropout:
            assert 0.0 <= self.d_rate < 1.0, "dropout_rate must be in [0.0, 1.0)"
            self.attention_V = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features),
                nn.Tanh(),
                nn.Dropout(self.d_rate),
            )
            self.attention_U = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features),
                nn.Sigmoid(),
                nn.Dropout(self.d_rate),
            )
            self.attention_W = nn.Sequential(
                nn.Linear(self.attention_features, self.num_classes)
            )

        else:
            self.attention_V = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features), nn.Tanh()
            )
            self.attention_U = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features), nn.Sigmoid()
            )
            self.attention_W = nn.Sequential(
                nn.Linear(self.attention_features, self.num_classes)
            )

    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, calculating attention scores for given input vector

        Args:
            H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            * Attention scores. Shape: (Number of instances, Number of classes)
            * H: The unchanged bag of instances. Shape: (Number of instances, Feature-dimensions)
        """
        v = self.attention_V(H)
        u = self.attention_U(H)
        A = self.attention_W(v * u)
        return A, H
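

# Minimal self-contained demo: a sketch of how the raw scores returned above are
# typically normalized and used to pool the bag (compare the referenced
# AttentionDeepMIL / CLAM repositories). The bag size, feature dimension, and the
# softmax pooling step are illustrative assumptions, not part of the original module.
if __name__ == "__main__":
    bag = torch.randn(50, 1024)                # bag of 50 instances with 1024 features each

    gated_attention = AttentionGated(in_features=1024, num_classes=2)
    A, H = gated_attention(bag)                # A: (50, 2) raw attention scores

    A = torch.transpose(A, 1, 0)               # (2, 50): one score vector per class
    A = torch.softmax(A, dim=1)                # normalize scores over the instances
    M = torch.mm(A, H)                         # (2, 1024): attention-pooled bag embedding

    print(A.shape, M.shape)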