GinnM commited on
Commit
31c04ef
1 Parent(s): e0b90aa

Upload ProSSTForMaskedLM

Browse files
Files changed (3) hide show
  1. config.json +6 -1
  2. model.safetensors +3 -0
  3. modeling_prosst.py +1413 -0
config.json CHANGED
@@ -1,7 +1,11 @@
1
  {
 
 
 
2
  "attention_probs_dropout_prob": 0.1,
3
  "auto_map": {
4
- "AutoConfig": "configuration_prosst.ProSSTConfig"
 
5
  },
6
  "hidden_act": "gelu",
7
  "hidden_dropout_prob": 0.1,
@@ -33,6 +37,7 @@
33
  "scale_hidden": 1,
34
  "ss_vocab_size": 11,
35
  "token_dropout": true,
 
36
  "transformers_version": "4.38.2",
37
  "type_vocab_size": 0,
38
  "vocab_size": 25
 
1
  {
2
+ "architectures": [
3
+ "ProSSTForMaskedLM"
4
+ ],
5
  "attention_probs_dropout_prob": 0.1,
6
  "auto_map": {
7
+ "AutoConfig": "configuration_prosst.ProSSTConfig",
8
+ "AutoModelForMaskedLM": "modeling_prosst.ProSSTForMaskedLM"
9
  },
10
  "hidden_act": "gelu",
11
  "hidden_dropout_prob": 0.1,
 
37
  "scale_hidden": 1,
38
  "ss_vocab_size": 11,
39
  "token_dropout": true,
40
+ "torch_dtype": "float32",
41
  "transformers_version": "4.38.2",
42
  "type_vocab_size": 0,
43
  "vocab_size": 25
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7864d390347ae8a37280a8168bc1d8d5bab38349a0388c9f55f046789c82a53c
3
+ size 462353760
modeling_prosst.py ADDED
@@ -0,0 +1,1413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Sequence
2
+ from typing import Optional, Tuple, Union
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
7
+ from transformers.activations import ACT2FN
8
+ from transformers.modeling_outputs import (
9
+ BaseModelOutput,
10
+ MaskedLMOutput,
11
+ SequenceClassifierOutput,
12
+ TokenClassifierOutput,
13
+ )
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from .configuration_prosst import ProSSTConfig
16
+ import torch.nn.functional as F
17
+
18
+
19
+ def build_relative_position(query_size, key_size, device):
20
+ """
21
+ Build relative position according to the query and key
22
+
23
+ We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
24
+ \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
25
+ P_k\\)
26
+
27
+ Args:
28
+ query_size (int): the length of query
29
+ key_size (int): the length of key
30
+
31
+ Return:
32
+ `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
33
+
34
+ """
35
+
36
+ q_ids = torch.arange(query_size, dtype=torch.long, device=device)
37
+ k_ids = torch.arange(key_size, dtype=torch.long, device=device)
38
+ rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
39
+ rel_pos_ids = rel_pos_ids[:query_size, :]
40
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
41
+ return rel_pos_ids
42
+
43
+
44
+ @torch.jit.script
45
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
46
+ return c2p_pos.expand(
47
+ [
48
+ query_layer.size(0),
49
+ query_layer.size(1),
50
+ query_layer.size(2),
51
+ relative_pos.size(-1),
52
+ ]
53
+ )
54
+
55
+
56
+ @torch.jit.script
57
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
58
+ return c2p_pos.expand(
59
+ [
60
+ query_layer.size(0),
61
+ query_layer.size(1),
62
+ key_layer.size(-2),
63
+ key_layer.size(-2),
64
+ ]
65
+ )
66
+
67
+
68
+ @torch.jit.script
69
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
70
+ return pos_index.expand(
71
+ p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))
72
+ )
73
+
74
+
75
+ def rotate_half(x):
76
+ x1, x2 = x.chunk(2, dim=-1)
77
+ return torch.cat((-x2, x1), dim=-1)
78
+
79
+
80
+ def apply_rotary_pos_emb(x, cos, sin):
81
+ cos = cos[:, :, : x.shape[-2], :]
82
+ sin = sin[:, :, : x.shape[-2], :]
83
+
84
+ return (x * cos) + (rotate_half(x) * sin)
85
+
86
+
87
+ class RotaryEmbedding(torch.nn.Module):
88
+ """
89
+ Rotary position embeddings based on those in
90
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
91
+ matrices which depend on their relative positions.
92
+ """
93
+
94
+ def __init__(self, dim: int):
95
+ super().__init__()
96
+ # Generate and save the inverse frequency buffer (non trainable)
97
+ inv_freq = 1.0 / (
98
+ 10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
99
+ )
100
+ inv_freq = inv_freq
101
+ self.register_buffer("inv_freq", inv_freq)
102
+
103
+ self._seq_len_cached = None
104
+ self._cos_cached = None
105
+ self._sin_cached = None
106
+
107
+ def _update_cos_sin_tables(self, x, seq_dimension=2):
108
+ seq_len = x.shape[seq_dimension]
109
+
110
+ # Reset the tables if the sequence length has changed,
111
+ # or if we're on a new device (possibly due to tracing for instance)
112
+ if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
113
+ self._seq_len_cached = seq_len
114
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
115
+ self.inv_freq
116
+ )
117
+ freqs = torch.outer(t, self.inv_freq)
118
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
119
+
120
+ self._cos_cached = emb.cos()[None, None, :, :]
121
+ self._sin_cached = emb.sin()[None, None, :, :]
122
+
123
+ return self._cos_cached, self._sin_cached
124
+
125
+ def forward(
126
+ self, q: torch.Tensor, k: torch.Tensor
127
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
128
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
129
+ k, seq_dimension=-2
130
+ )
131
+
132
+ return (
133
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
134
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
135
+ )
136
+
137
+
138
+ class MaskedConv1d(nn.Conv1d):
139
+ """A masked 1-dimensional convolution layer.
140
+
141
+ Takes the same arguments as torch.nn.Conv1D, except that the padding is set automatically.
142
+
143
+ Shape:
144
+ Input: (N, L, in_channels)
145
+ input_mask: (N, L, 1), optional
146
+ Output: (N, L, out_channels)
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ in_channels: int,
152
+ out_channels: int,
153
+ kernel_size: int,
154
+ stride: int = 1,
155
+ dilation: int = 1,
156
+ groups: int = 1,
157
+ bias: bool = True,
158
+ ):
159
+ """
160
+ :param in_channels: input channels
161
+ :param out_channels: output channels
162
+ :param kernel_size: the kernel width
163
+ :param stride: filter shift
164
+ :param dilation: dilation factor
165
+ :param groups: perform depth-wise convolutions
166
+ :param bias: adds learnable bias to output
167
+ """
168
+ padding = dilation * (kernel_size - 1) // 2
169
+ super().__init__(
170
+ in_channels,
171
+ out_channels,
172
+ kernel_size,
173
+ stride=stride,
174
+ dilation=dilation,
175
+ groups=groups,
176
+ bias=bias,
177
+ padding=padding,
178
+ )
179
+
180
+ def forward(self, x, input_mask=None):
181
+ if input_mask is not None:
182
+ x = x * input_mask
183
+ return super().forward(x.transpose(1, 2)).transpose(1, 2)
184
+
185
+
186
+ class Attention1dPooling(nn.Module):
187
+ def __init__(self, config):
188
+ super().__init__()
189
+ self.layer = MaskedConv1d(config.hidden_size, 1, 1)
190
+
191
+ def forward(self, x, input_mask=None):
192
+ batch_szie = x.shape[0]
193
+ attn = self.layer(x)
194
+ attn = attn.view(batch_szie, -1)
195
+ if input_mask is not None:
196
+ attn = attn.masked_fill_(
197
+ ~input_mask.view(batch_szie, -1).bool(), float("-inf")
198
+ )
199
+ attn = F.softmax(attn, dim=-1).view(batch_szie, -1, 1)
200
+ out = (attn * x).sum(dim=1)
201
+ return out
202
+
203
+
204
+ class MeanPooling(nn.Module):
205
+ """Mean Pooling for sentence-level classification tasks."""
206
+
207
+ def __init__(self):
208
+ super().__init__()
209
+
210
+ def forward(self, features, input_mask=None):
211
+ if input_mask is not None:
212
+ # Applying input_mask to zero out masked values
213
+ masked_features = features * input_mask.unsqueeze(2)
214
+ sum_features = torch.sum(masked_features, dim=1)
215
+ mean_pooled_features = sum_features / input_mask.sum(dim=1, keepdim=True)
216
+ else:
217
+ mean_pooled_features = torch.mean(features, dim=1)
218
+ return mean_pooled_features
219
+
220
+
221
+ class ContextPooler(nn.Module):
222
+ def __init__(self, config):
223
+ super().__init__()
224
+ scale_hidden = getattr(config, "scale_hidden", 1)
225
+ if config.pooling_head == "mean":
226
+ self.mean_pooling = MeanPooling()
227
+ elif config.pooling_head == "attention":
228
+ self.mean_pooling = Attention1dPooling(config)
229
+ self.dense = nn.Linear(
230
+ config.pooler_hidden_size, scale_hidden * config.pooler_hidden_size
231
+ )
232
+ self.dropout = nn.Dropout(config.pooler_dropout)
233
+ self.config = config
234
+
235
+ def forward(self, hidden_states, input_mask=None):
236
+ # We "pool" the model by simply taking the hidden state corresponding
237
+ # to the first token.
238
+
239
+ context_token = self.mean_pooling(hidden_states, input_mask)
240
+ context_token = self.dropout(context_token)
241
+ pooled_output = self.dense(context_token)
242
+ pooled_output = torch.tanh(pooled_output)
243
+ return pooled_output
244
+
245
+ @property
246
+ def output_dim(self):
247
+ return self.config.hidden_size
248
+
249
+
250
+ class ProSSTLayerNorm(nn.Module):
251
+ """LayerNorm module in the TF style (epsilon inside the square root)."""
252
+
253
+ def __init__(self, size, eps=1e-12):
254
+ super().__init__()
255
+ self.weight = nn.Parameter(torch.ones(size))
256
+ self.bias = nn.Parameter(torch.zeros(size))
257
+ self.variance_epsilon = eps
258
+
259
+ def forward(self, hidden_states):
260
+ input_type = hidden_states.dtype
261
+ hidden_states = hidden_states.float()
262
+ mean = hidden_states.mean(-1, keepdim=True)
263
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
264
+ hidden_states = (hidden_states - mean) / torch.sqrt(
265
+ variance + self.variance_epsilon
266
+ )
267
+ hidden_states = hidden_states.to(input_type)
268
+ y = self.weight * hidden_states + self.bias
269
+ return y
270
+
271
+
272
+ class DisentangledSelfAttention(nn.Module):
273
+
274
+ def __init__(self, config: ProSSTConfig):
275
+ super().__init__()
276
+ if config.hidden_size % config.num_attention_heads != 0:
277
+ raise ValueError(
278
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
279
+ f"heads ({config.num_attention_heads})"
280
+ )
281
+ self.num_attention_heads = config.num_attention_heads
282
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
283
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
284
+
285
+ # Q, K, V projection layers
286
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
287
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
288
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
289
+
290
+ # AA->SS, AA->POS, SS->AA, POS->AA and AA->AA attention layers
291
+ self.pos_att_type = (
292
+ config.pos_att_type if config.pos_att_type is not None else []
293
+ )
294
+
295
+ self.relative_attention = getattr(config, "relative_attention", False)
296
+ self.position_embedding_type = getattr(
297
+ config, "position_embedding_type", "relative"
298
+ )
299
+ if self.position_embedding_type == "rotary":
300
+ self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
301
+ if self.relative_attention:
302
+
303
+ if "aa2ss" in self.pos_att_type:
304
+ self.ss_proj = nn.Linear(
305
+ config.hidden_size, self.all_head_size, bias=False
306
+ )
307
+
308
+ if "ss2aa" in self.pos_att_type:
309
+ self.ss_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
310
+
311
+ elif self.position_embedding_type == "relative":
312
+ if self.relative_attention:
313
+ self.max_relative_positions = getattr(
314
+ config, "max_relative_positions", -1
315
+ )
316
+ if self.max_relative_positions < 1:
317
+ self.max_relative_positions = config.max_position_embeddings
318
+ self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
319
+
320
+ # amino acid to position
321
+ if "aa2pos" in self.pos_att_type:
322
+ self.pos_proj = nn.Linear(
323
+ config.hidden_size, self.all_head_size, bias=False
324
+ ) # Key
325
+
326
+ if "pos2aa" in self.pos_att_type:
327
+ self.pos_q_proj = nn.Linear(
328
+ config.hidden_size, self.all_head_size
329
+ ) # Query
330
+
331
+ if "aa2ss" in self.pos_att_type:
332
+ self.ss_proj = nn.Linear(
333
+ config.hidden_size, self.all_head_size, bias=False
334
+ )
335
+
336
+ if "ss2aa" in self.pos_att_type:
337
+ self.ss_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
338
+
339
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
340
+
341
+ def transpose_for_scores(self, x):
342
+ # x [batch_size, seq_len, all_head_size]
343
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
344
+ # x [batch_size, seq_len, num_attention_heads, attention_head_size]
345
+ x = x.view(new_x_shape)
346
+ # x [batch_size, num_attention_heads, seq_len, attention_head_size]
347
+ return x.permute(0, 2, 1, 3)
348
+
349
+ def forward(
350
+ self,
351
+ hidden_states,
352
+ attention_mask,
353
+ output_attentions=False,
354
+ query_states=None,
355
+ relative_pos=None,
356
+ rel_embeddings=None,
357
+ ss_hidden_states=None,
358
+ ):
359
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
360
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
361
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
362
+
363
+ if self.position_embedding_type == "rotary":
364
+ query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
365
+
366
+ rel_att = None
367
+ scale_factor = 1 + len(self.pos_att_type)
368
+ scale = torch.sqrt(
369
+ torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor
370
+ )
371
+ query_layer = query_layer / scale.to(dtype=query_layer.dtype)
372
+
373
+ # [batch_size, num_attention_heads, seq_len, seq_len]
374
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
375
+
376
+ if self.relative_attention:
377
+ if self.position_embedding_type == "relative":
378
+ rel_embeddings = self.pos_dropout(rel_embeddings)
379
+ rel_att = self.disentangled_att_bias(
380
+ query_layer,
381
+ key_layer,
382
+ relative_pos,
383
+ rel_embeddings,
384
+ scale_factor,
385
+ ss_hidden_states,
386
+ )
387
+
388
+ if rel_att is not None:
389
+ attention_scores = attention_scores + rel_att
390
+
391
+ rmask = ~(attention_mask.to(torch.bool))
392
+ attention_probs = attention_scores.masked_fill(rmask, float("-inf"))
393
+ attention_probs = torch.softmax(attention_probs, -1)
394
+ attention_probs = attention_probs.masked_fill(rmask, 0.0)
395
+ # attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
396
+ attention_probs = self.dropout(attention_probs)
397
+
398
+ context_layer = torch.matmul(attention_probs, value_layer)
399
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
400
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
401
+ context_layer = context_layer.view(new_context_layer_shape)
402
+ if output_attentions:
403
+ return (context_layer, attention_probs)
404
+ else:
405
+ return context_layer
406
+
407
+ def disentangled_att_bias(
408
+ self,
409
+ query_layer,
410
+ key_layer,
411
+ relative_pos,
412
+ rel_embeddings,
413
+ scale_factor,
414
+ ss_hidden_states,
415
+ ):
416
+ if self.position_embedding_type == "relative":
417
+ if relative_pos is None:
418
+ q = query_layer.size(-2)
419
+ relative_pos = build_relative_position(
420
+ q, key_layer.size(-2), query_layer.device
421
+ )
422
+ if relative_pos.dim() == 2:
423
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
424
+ elif relative_pos.dim() == 3:
425
+ relative_pos = relative_pos.unsqueeze(1)
426
+ # bxhxqxk
427
+ elif relative_pos.dim() != 4:
428
+ raise ValueError(
429
+ f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}"
430
+ )
431
+
432
+ att_span = min(
433
+ max(query_layer.size(-2), key_layer.size(-2)),
434
+ self.max_relative_positions,
435
+ )
436
+ relative_pos = relative_pos.long().to(query_layer.device)
437
+ rel_embeddings = rel_embeddings[
438
+ self.max_relative_positions
439
+ - att_span : self.max_relative_positions
440
+ + att_span,
441
+ :,
442
+ ].unsqueeze(0)
443
+
444
+ score = 0
445
+
446
+ if "aa2pos" in self.pos_att_type:
447
+ pos_key_layer = self.pos_proj(rel_embeddings)
448
+ pos_key_layer = self.transpose_for_scores(pos_key_layer)
449
+ aa2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
450
+ aa2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
451
+ aa2p_att = torch.gather(
452
+ aa2p_att,
453
+ dim=-1,
454
+ index=c2p_dynamic_expand(aa2p_pos, query_layer, relative_pos),
455
+ )
456
+ score += aa2p_att
457
+
458
+ if "pos2aa" in self.pos_att_type:
459
+ pos_query_layer = self.pos_q_proj(rel_embeddings)
460
+ pos_query_layer = self.transpose_for_scores(pos_query_layer)
461
+ pos_query_layer /= torch.sqrt(
462
+ torch.tensor(pos_query_layer.size(-1), dtype=torch.float)
463
+ * scale_factor
464
+ )
465
+ if query_layer.size(-2) != key_layer.size(-2):
466
+ r_pos = build_relative_position(
467
+ key_layer.size(-2), key_layer.size(-2), query_layer.device
468
+ )
469
+ else:
470
+ r_pos = relative_pos
471
+ p2aa_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
472
+ p2aa_att = torch.matmul(
473
+ key_layer,
474
+ pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype),
475
+ )
476
+ p2aa_att = torch.gather(
477
+ p2aa_att,
478
+ dim=-1,
479
+ index=p2c_dynamic_expand(p2aa_pos, query_layer, key_layer),
480
+ ).transpose(-1, -2)
481
+
482
+ if query_layer.size(-2) != key_layer.size(-2):
483
+ pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
484
+ p2aa_att = torch.gather(
485
+ p2aa_att,
486
+ dim=-2,
487
+ index=pos_dynamic_expand(pos_index, p2aa_att, key_layer),
488
+ )
489
+ score += p2aa_att
490
+
491
+ # content -> structure
492
+ if "aa2ss" in self.pos_att_type:
493
+ assert ss_hidden_states is not None
494
+ ss_key_layer = self.ss_proj(ss_hidden_states)
495
+ ss_key_layer = self.transpose_for_scores(ss_key_layer)
496
+ # [batch_size, num_attention_heads, seq_len, seq_len]
497
+ aa2ss_att = torch.matmul(query_layer, ss_key_layer.transpose(-1, -2))
498
+ score += aa2ss_att
499
+
500
+ if "ss2aa" in self.pos_att_type:
501
+ assert ss_hidden_states is not None
502
+ ss_query_layer = self.ss_q_proj(ss_hidden_states)
503
+ ss_query_layer = self.transpose_for_scores(ss_query_layer)
504
+ ss_query_layer /= torch.sqrt(
505
+ torch.tensor(ss_query_layer.size(-1), dtype=torch.float)
506
+ * scale_factor
507
+ )
508
+ ss2aa_att = torch.matmul(
509
+ key_layer, query_layer.transpose(-1, -2).to(dtype=key_layer.dtype)
510
+ )
511
+ score += ss2aa_att
512
+ return score
513
+ elif self.position_embedding_type == "rotary":
514
+ score = 0
515
+ if "aa2ss" in self.pos_att_type:
516
+ assert ss_hidden_states is not None
517
+ ss_key_layer = self.ss_proj(ss_hidden_states)
518
+ ss_key_layer = self.transpose_for_scores(ss_key_layer)
519
+ aa2ss_att = torch.matmul(query_layer, ss_key_layer.transpose(-1, -2))
520
+ score += aa2ss_att
521
+
522
+ if "ss2aa" in self.pos_att_type:
523
+ assert ss_hidden_states is not None
524
+ ss_query_layer = self.ss_q_proj(ss_hidden_states)
525
+ ss_query_layer = self.transpose_for_scores(ss_query_layer)
526
+ ss_query_layer /= torch.sqrt(
527
+ torch.tensor(ss_query_layer.size(-1), dtype=torch.float)
528
+ * scale_factor
529
+ )
530
+ ss2aa_att = torch.matmul(
531
+ key_layer, query_layer.transpose(-1, -2).to(dtype=key_layer.dtype)
532
+ )
533
+ score += ss2aa_att
534
+ return score
535
+
536
+
537
+ class ProSSTSelfOutput(nn.Module):
538
+ def __init__(self, config):
539
+ super().__init__()
540
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
541
+ self.LayerNorm = ProSSTLayerNorm(config.hidden_size, config.layer_norm_eps)
542
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
543
+
544
+ def forward(self, hidden_states, input_tensor):
545
+ hidden_states = self.dense(hidden_states)
546
+ hidden_states = self.dropout(hidden_states)
547
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
548
+ return hidden_states
549
+
550
+
551
+ class ProSSTAttention(nn.Module):
552
+ def __init__(self, config):
553
+ super().__init__()
554
+ self.self = DisentangledSelfAttention(config)
555
+ self.output = ProSSTSelfOutput(config)
556
+ self.config = config
557
+
558
+ def forward(
559
+ self,
560
+ hidden_states,
561
+ attention_mask,
562
+ output_attentions=False,
563
+ query_states=None,
564
+ relative_pos=None,
565
+ rel_embeddings=None,
566
+ ss_hidden_states=None,
567
+ ):
568
+ self_output = self.self(
569
+ hidden_states,
570
+ attention_mask,
571
+ output_attentions,
572
+ query_states=query_states,
573
+ relative_pos=relative_pos,
574
+ rel_embeddings=rel_embeddings,
575
+ ss_hidden_states=ss_hidden_states,
576
+ )
577
+ if output_attentions:
578
+ self_output, att_matrix = self_output
579
+ if query_states is None:
580
+ query_states = hidden_states
581
+ attention_output = self.output(self_output, query_states)
582
+
583
+ if output_attentions:
584
+ return (attention_output, att_matrix)
585
+ else:
586
+ return attention_output
587
+
588
+
589
+ class ProSSTIntermediate(nn.Module):
590
+ def __init__(self, config):
591
+ super().__init__()
592
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
593
+ if isinstance(config.hidden_act, str):
594
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
595
+ else:
596
+ self.intermediate_act_fn = config.hidden_act
597
+
598
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
599
+ hidden_states = self.dense(hidden_states)
600
+ hidden_states = self.intermediate_act_fn(hidden_states)
601
+ return hidden_states
602
+
603
+
604
+ class ProSSTOutput(nn.Module):
605
+ def __init__(self, config):
606
+ super().__init__()
607
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
608
+ self.LayerNorm = ProSSTLayerNorm(config.hidden_size, config.layer_norm_eps)
609
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
610
+ self.config = config
611
+
612
+ def forward(self, hidden_states, input_tensor):
613
+ hidden_states = self.dense(hidden_states)
614
+ hidden_states = self.dropout(hidden_states)
615
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
616
+ return hidden_states
617
+
618
+
619
+ class ProSSTLayer(nn.Module):
620
+ def __init__(self, config):
621
+ super().__init__()
622
+ self.attention = ProSSTAttention(config)
623
+ self.intermediate = ProSSTIntermediate(config)
624
+ self.output = ProSSTOutput(config)
625
+
626
+ def forward(
627
+ self,
628
+ hidden_states,
629
+ attention_mask,
630
+ query_states=None,
631
+ relative_pos=None,
632
+ rel_embeddings=None,
633
+ output_attentions=False,
634
+ ss_hidden_states=None,
635
+ ):
636
+ attention_output = self.attention(
637
+ hidden_states,
638
+ attention_mask,
639
+ output_attentions=output_attentions,
640
+ query_states=query_states,
641
+ relative_pos=relative_pos,
642
+ rel_embeddings=rel_embeddings,
643
+ ss_hidden_states=ss_hidden_states,
644
+ )
645
+ if output_attentions:
646
+ attention_output, att_matrix = attention_output
647
+ intermediate_output = self.intermediate(attention_output)
648
+ layer_output = self.output(intermediate_output, attention_output)
649
+ if output_attentions:
650
+ return (layer_output, att_matrix)
651
+ else:
652
+ return layer_output
653
+
654
+
655
+ class ProSSTEncoder(nn.Module):
656
+ """Modified BertEncoder with relative position bias support"""
657
+
658
+ def __init__(self, config):
659
+ super().__init__()
660
+ self.layer = nn.ModuleList(
661
+ [ProSSTLayer(config) for _ in range(config.num_hidden_layers)]
662
+ )
663
+ self.relative_attention = getattr(config, "relative_attention", False)
664
+ if self.relative_attention:
665
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
666
+ if self.max_relative_positions < 1:
667
+ self.max_relative_positions = config.max_position_embeddings
668
+ self.rel_embeddings = nn.Embedding(
669
+ self.max_relative_positions * 2, config.hidden_size
670
+ )
671
+ self.gradient_checkpointing = False
672
+
673
+ def get_rel_embedding(self):
674
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
675
+ return rel_embeddings
676
+
677
+ def get_attention_mask(self, attention_mask):
678
+ if attention_mask.dim() <= 2:
679
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
680
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(
681
+ -2
682
+ ).unsqueeze(-1)
683
+ elif attention_mask.dim() == 3:
684
+ attention_mask = attention_mask.unsqueeze(1)
685
+
686
+ return attention_mask
687
+
688
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
689
+ if self.relative_attention and relative_pos is None:
690
+ q = (
691
+ query_states.size(-2)
692
+ if query_states is not None
693
+ else hidden_states.size(-2)
694
+ )
695
+ relative_pos = build_relative_position(
696
+ q, hidden_states.size(-2), hidden_states.device
697
+ )
698
+ return relative_pos
699
+
700
+ def forward(
701
+ self,
702
+ hidden_states,
703
+ attention_mask,
704
+ output_hidden_states=True,
705
+ output_attentions=False,
706
+ query_states=None,
707
+ relative_pos=None,
708
+ ss_hidden_states=None,
709
+ return_dict=True,
710
+ ):
711
+ attention_mask = self.get_attention_mask(attention_mask)
712
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
713
+
714
+ all_hidden_states = () if output_hidden_states else None
715
+ all_attentions = () if output_attentions else None
716
+
717
+ if isinstance(hidden_states, Sequence):
718
+ next_kv = hidden_states[0]
719
+ else:
720
+ next_kv = hidden_states
721
+ rel_embeddings = self.get_rel_embedding()
722
+ for i, layer_module in enumerate(self.layer):
723
+ if output_hidden_states:
724
+ all_hidden_states = all_hidden_states + (hidden_states,)
725
+
726
+ if self.gradient_checkpointing and self.training:
727
+
728
+ def create_custom_forward(module):
729
+ def custom_forward(*inputs):
730
+ return module(*inputs, output_attentions)
731
+
732
+ return custom_forward
733
+
734
+ hidden_states = torch.utils.checkpoint.checkpoint(
735
+ create_custom_forward(layer_module),
736
+ next_kv,
737
+ attention_mask,
738
+ query_states,
739
+ relative_pos,
740
+ rel_embeddings,
741
+ ss_hidden_states,
742
+ )
743
+ else:
744
+ hidden_states = layer_module(
745
+ next_kv,
746
+ attention_mask,
747
+ query_states=query_states,
748
+ relative_pos=relative_pos,
749
+ rel_embeddings=rel_embeddings,
750
+ output_attentions=output_attentions,
751
+ ss_hidden_states=ss_hidden_states,
752
+ )
753
+
754
+ if output_attentions:
755
+ hidden_states, att_m = hidden_states
756
+
757
+ if query_states is not None:
758
+ query_states = hidden_states
759
+ if isinstance(hidden_states, Sequence):
760
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
761
+ else:
762
+ next_kv = hidden_states
763
+
764
+ if output_attentions:
765
+ all_attentions = all_attentions + (att_m,)
766
+
767
+ if output_hidden_states:
768
+ all_hidden_states = all_hidden_states + (hidden_states,)
769
+
770
+ if not return_dict:
771
+ return tuple(
772
+ v
773
+ for v in [hidden_states, all_hidden_states, all_attentions]
774
+ if v is not None
775
+ )
776
+ return BaseModelOutput(
777
+ last_hidden_state=hidden_states,
778
+ hidden_states=all_hidden_states,
779
+ attentions=all_attentions,
780
+ )
781
+
782
+
783
+ class ProSSTEmbeddings(nn.Module):
784
+ """Construct the embeddings from word, position and token_type embeddings."""
785
+
786
+ def __init__(self, config):
787
+ super().__init__()
788
+ pad_token_id = getattr(config, "pad_token_id", 0)
789
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
790
+ self.word_embeddings = nn.Embedding(
791
+ config.vocab_size, self.embedding_size, padding_idx=pad_token_id
792
+ )
793
+
794
+ self.position_biased_input = getattr(config, "position_biased_input", False)
795
+ if not self.position_biased_input:
796
+ self.position_embeddings = None
797
+ else:
798
+ # assert getattr(config, "position_embedding_type", "relative") == "absolute"
799
+ self.position_embeddings = nn.Embedding(
800
+ config.max_position_embeddings, self.embedding_size
801
+ )
802
+
803
+ if config.type_vocab_size > 0:
804
+ self.token_type_embeddings = nn.Embedding(
805
+ config.type_vocab_size, self.embedding_size
806
+ )
807
+
808
+ if config.ss_vocab_size > 0:
809
+ self.ss_embeddings = nn.Embedding(config.ss_vocab_size, self.embedding_size)
810
+ self.ss_layer_norm = ProSSTLayerNorm(
811
+ config.hidden_size, config.layer_norm_eps
812
+ )
813
+
814
+ if self.embedding_size != config.hidden_size:
815
+ self.embed_proj = nn.Linear(
816
+ self.embedding_size, config.hidden_size, bias=False
817
+ )
818
+ self.LayerNorm = ProSSTLayerNorm(config.hidden_size, config.layer_norm_eps)
819
+
820
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
821
+ self.config = config
822
+
823
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
824
+ if self.position_biased_input:
825
+ self.register_buffer(
826
+ "position_ids",
827
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
828
+ persistent=False,
829
+ )
830
+
831
+ def forward(
832
+ self,
833
+ input_ids=None,
834
+ ss_input_ids=None,
835
+ token_type_ids=None,
836
+ position_ids=None,
837
+ mask=None,
838
+ inputs_embeds=None,
839
+ ):
840
+ if input_ids is not None:
841
+ input_shape = input_ids.size()
842
+ else:
843
+ input_shape = inputs_embeds.size()[:-1]
844
+
845
+ seq_length = input_shape[1]
846
+
847
+ if position_ids is None and self.position_biased_input:
848
+ position_ids = self.position_ids[:, :seq_length]
849
+ if seq_length > position_ids.size(1):
850
+ zero_padding = (
851
+ torch.zeros(
852
+ (input_shape[0], seq_length - position_ids.size(1)),
853
+ dtype=torch.long,
854
+ device=position_ids.device,
855
+ )
856
+ + 2047
857
+ )
858
+ position_ids = torch.cat([position_ids, zero_padding], dim=1)
859
+
860
+ if token_type_ids is None:
861
+ token_type_ids = torch.zeros(
862
+ input_shape, dtype=torch.long, device=self.position_ids.device
863
+ )
864
+
865
+ if inputs_embeds is None:
866
+ if self.config.token_dropout:
867
+ inputs_embeds = self.word_embeddings(input_ids)
868
+ inputs_embeds.masked_fill_(
869
+ (input_ids == self.config.mask_token_id).unsqueeze(-1), 0.0
870
+ )
871
+ mask_ratio_train = self.config.mlm_probability * 0.8
872
+ src_lengths = mask.sum(dim=-1)
873
+ mask_ratio_observed = (input_ids == self.config.mask_token_id).sum(
874
+ -1
875
+ ).to(inputs_embeds.dtype) / src_lengths
876
+ inputs_embeds = (
877
+ inputs_embeds
878
+ * (1 - mask_ratio_train)
879
+ / (1 - mask_ratio_observed)[:, None, None]
880
+ )
881
+ else:
882
+ inputs_embeds = self.word_embeddings(input_ids)
883
+
884
+ if self.position_embeddings is not None and self.position_biased_input:
885
+ position_embeddings = self.position_embeddings(position_ids.long())
886
+ else:
887
+ position_embeddings = torch.zeros_like(inputs_embeds)
888
+
889
+ embeddings = inputs_embeds
890
+ if self.position_biased_input:
891
+ embeddings += position_embeddings
892
+ if self.config.type_vocab_size > 0:
893
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
894
+ embeddings += token_type_embeddings
895
+
896
+ if self.embedding_size != self.config.hidden_size:
897
+ embeddings = self.embed_proj(embeddings)
898
+
899
+ embeddings = self.LayerNorm(embeddings)
900
+
901
+ if mask is not None:
902
+ if mask.dim() != embeddings.dim():
903
+ if mask.dim() == 4:
904
+ mask = mask.squeeze(1).squeeze(1)
905
+ mask = mask.unsqueeze(2)
906
+ mask = mask.to(embeddings.dtype)
907
+ embeddings = embeddings * mask
908
+
909
+ embeddings = self.dropout(embeddings)
910
+
911
+ if self.config.ss_vocab_size > 0:
912
+ ss_embeddings = self.ss_embeddings(ss_input_ids)
913
+ ss_embeddings = self.ss_layer_norm(ss_embeddings)
914
+ if mask is not None:
915
+ if mask.dim() != ss_embeddings.dim():
916
+ if mask.dim() == 4:
917
+ mask = mask.squeeze(1).squeeze(1)
918
+ mask = mask.unsqueeze(2)
919
+ mask = mask.to(ss_embeddings.dtype)
920
+ ss_embeddings = ss_embeddings * mask
921
+ ss_embeddings = self.dropout(ss_embeddings)
922
+ return embeddings, ss_embeddings
923
+
924
+ return embeddings, None
925
+
926
+
927
+ class ProSSTPreTrainedModel(PreTrainedModel):
928
+ """
929
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
930
+ models.
931
+ """
932
+
933
+ config_class = ProSSTConfig
934
+ base_model_prefix = "ProSST"
935
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
936
+ supports_gradient_checkpointing = True
937
+
938
+ def _init_weights(self, module):
939
+ """Initialize the weights."""
940
+ if isinstance(module, nn.Linear):
941
+ # Slightly different from the TF version which uses truncated_normal for initialization
942
+ # cf https://github.com/pytorch/pytorch/pull/5617
943
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
944
+ if module.bias is not None:
945
+ module.bias.data.zero_()
946
+ elif isinstance(module, nn.Embedding):
947
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
948
+ if module.padding_idx is not None:
949
+ module.weight.data[module.padding_idx].zero_()
950
+
951
+ def _set_gradient_checkpointing(self, module, value=False):
952
+ if isinstance(module, ProSSTEncoder):
953
+ module.gradient_checkpointing = value
954
+
955
+
956
+ class ProSSTModel(ProSSTPreTrainedModel):
957
+ def __init__(self, config):
958
+ super().__init__(config)
959
+
960
+ self.embeddings = ProSSTEmbeddings(config)
961
+ self.encoder = ProSSTEncoder(config)
962
+ self.config = config
963
+ # Initialize weights and apply final processing
964
+ self.post_init()
965
+
966
+ def get_input_embeddings(self):
967
+ return self.embeddings.word_embeddings
968
+
969
+ def set_input_embeddings(self, new_embeddings):
970
+ self.embeddings.word_embeddings = new_embeddings
971
+
972
+ def _prune_heads(self, heads_to_prune):
973
+ """
974
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
975
+ class PreTrainedModel
976
+ """
977
+ raise NotImplementedError(
978
+ "The prune function is not implemented in DeBERTa model."
979
+ )
980
+
981
+ def forward(
982
+ self,
983
+ input_ids: Optional[torch.Tensor] = None,
984
+ ss_input_ids: Optional[torch.Tensor] = None,
985
+ attention_mask: Optional[torch.Tensor] = None,
986
+ token_type_ids: Optional[torch.Tensor] = None,
987
+ position_ids: Optional[torch.Tensor] = None,
988
+ inputs_embeds: Optional[torch.Tensor] = None,
989
+ output_attentions: Optional[bool] = None,
990
+ output_hidden_states: Optional[bool] = None,
991
+ return_dict: Optional[bool] = None,
992
+ ) -> Union[Tuple, BaseModelOutput]:
993
+ output_attentions = (
994
+ output_attentions
995
+ if output_attentions is not None
996
+ else self.config.output_attentions
997
+ )
998
+ output_hidden_states = (
999
+ output_hidden_states
1000
+ if output_hidden_states is not None
1001
+ else self.config.output_hidden_states
1002
+ )
1003
+ return_dict = (
1004
+ return_dict if return_dict is not None else self.config.use_return_dict
1005
+ )
1006
+
1007
+ if input_ids is not None and inputs_embeds is not None:
1008
+ raise ValueError(
1009
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1010
+ )
1011
+ elif input_ids is not None:
1012
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1013
+ input_shape = input_ids.size()
1014
+ elif inputs_embeds is not None:
1015
+ input_shape = inputs_embeds.size()[:-1]
1016
+ else:
1017
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1018
+
1019
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1020
+
1021
+ if attention_mask is None:
1022
+ attention_mask = torch.ones(input_shape, device=device)
1023
+ if token_type_ids is None:
1024
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1025
+
1026
+ embedding_output, ss_embeddings = self.embeddings(
1027
+ input_ids=input_ids,
1028
+ ss_input_ids=ss_input_ids,
1029
+ token_type_ids=token_type_ids,
1030
+ position_ids=position_ids,
1031
+ mask=attention_mask,
1032
+ inputs_embeds=inputs_embeds,
1033
+ )
1034
+
1035
+ encoder_outputs = self.encoder(
1036
+ embedding_output,
1037
+ attention_mask,
1038
+ output_hidden_states=True,
1039
+ output_attentions=output_attentions,
1040
+ return_dict=return_dict,
1041
+ ss_hidden_states=ss_embeddings,
1042
+ )
1043
+ encoded_layers = encoder_outputs[1]
1044
+
1045
+ sequence_output = encoded_layers[-1]
1046
+
1047
+ if not return_dict:
1048
+ return (sequence_output,) + encoder_outputs[
1049
+ (1 if output_hidden_states else 2) :
1050
+ ]
1051
+
1052
+ return BaseModelOutput(
1053
+ last_hidden_state=sequence_output,
1054
+ hidden_states=(
1055
+ encoder_outputs.hidden_states if output_hidden_states else None
1056
+ ),
1057
+ attentions=encoder_outputs.attentions,
1058
+ )
1059
+
1060
+
1061
+ class ProSSTPredictionHeadTransform(nn.Module):
1062
+ def __init__(self, config):
1063
+ super().__init__()
1064
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
1065
+
1066
+ self.dense = nn.Linear(config.hidden_size, self.embedding_size)
1067
+ if isinstance(config.hidden_act, str):
1068
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1069
+ else:
1070
+ self.transform_act_fn = config.hidden_act
1071
+ self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)
1072
+
1073
+ def forward(self, hidden_states):
1074
+ hidden_states = self.dense(hidden_states)
1075
+ hidden_states = self.transform_act_fn(hidden_states)
1076
+ hidden_states = self.LayerNorm(hidden_states)
1077
+ return hidden_states
1078
+
1079
+
1080
+ class ProSSTLMPredictionHead(nn.Module):
1081
+ def __init__(self, config):
1082
+ super().__init__()
1083
+ self.transform = ProSSTPredictionHeadTransform(config)
1084
+
1085
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
1086
+ # The output weights are the same as the input embeddings, but there is
1087
+ # an output-only bias for each token.
1088
+ self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)
1089
+
1090
+ # self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1091
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1092
+ # self.decoder.bias = self.bias
1093
+
1094
+ def forward(self, hidden_states):
1095
+ hidden_states = self.transform(hidden_states)
1096
+ hidden_states = self.decoder(hidden_states)
1097
+ return hidden_states
1098
+
1099
+
1100
+ class ProSSTOnlyMLMHead(nn.Module):
1101
+ def __init__(self, config):
1102
+ super().__init__()
1103
+ self.predictions = ProSSTLMPredictionHead(config)
1104
+
1105
+ def forward(self, sequence_output):
1106
+ prediction_scores = self.predictions(sequence_output)
1107
+ return prediction_scores
1108
+
1109
+
1110
+ class ProSSTPreTrainedModel(PreTrainedModel):
1111
+ """
1112
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1113
+ models.
1114
+ """
1115
+
1116
+ config_class = ProSSTConfig
1117
+ base_model_prefix = "ProSST"
1118
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
1119
+ supports_gradient_checkpointing = True
1120
+
1121
+ def _init_weights(self, module):
1122
+ """Initialize the weights."""
1123
+ if isinstance(module, nn.Linear):
1124
+ # Slightly different from the TF version which uses truncated_normal for initialization
1125
+ # cf https://github.com/pytorch/pytorch/pull/5617
1126
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1127
+ if module.bias is not None:
1128
+ module.bias.data.zero_()
1129
+ elif isinstance(module, nn.Embedding):
1130
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1131
+ if module.padding_idx is not None:
1132
+ module.weight.data[module.padding_idx].zero_()
1133
+
1134
+ def _set_gradient_checkpointing(self, module, value=False):
1135
+ if isinstance(module, ProSSTEncoder):
1136
+ module.gradient_checkpointing = value
1137
+
1138
+
1139
+ class ProSSTForMaskedLM(ProSSTPreTrainedModel):
1140
+ _tied_weights_keys = [
1141
+ "cls.predictions.decoder.weight",
1142
+ "cls.predictions.decoder.bias",
1143
+ ]
1144
+
1145
+ def __init__(self, config):
1146
+ super().__init__(config)
1147
+
1148
+ self.prosst = ProSSTModel(config)
1149
+ self.cls = ProSSTOnlyMLMHead(config)
1150
+
1151
+ # Initialize weights and apply final processing
1152
+ self.post_init()
1153
+
1154
+ def get_input_embeddings(self):
1155
+ return self.prosst.embeddings.word_embeddings
1156
+
1157
+ def get_output_embeddings(self):
1158
+ return self.cls.predictions.decoder
1159
+
1160
+ def set_output_embeddings(self, new_embeddings):
1161
+ self.cls.predictions.decoder = new_embeddings
1162
+
1163
+ def forward(
1164
+ self,
1165
+ input_ids: Optional[torch.Tensor] = None,
1166
+ ss_input_ids: Optional[torch.Tensor] = None,
1167
+ attention_mask: Optional[torch.Tensor] = None,
1168
+ token_type_ids: Optional[torch.Tensor] = None,
1169
+ position_ids: Optional[torch.Tensor] = None,
1170
+ inputs_embeds: Optional[torch.Tensor] = None,
1171
+ labels: Optional[torch.Tensor] = None,
1172
+ output_attentions: Optional[bool] = None,
1173
+ output_hidden_states: Optional[bool] = None,
1174
+ return_dict: Optional[bool] = None,
1175
+ ) -> Union[Tuple, MaskedLMOutput]:
1176
+ r"""
1177
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1178
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1179
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1180
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1181
+ """
1182
+
1183
+ return_dict = (
1184
+ return_dict if return_dict is not None else self.config.use_return_dict
1185
+ )
1186
+
1187
+ outputs = self.prosst(
1188
+ input_ids,
1189
+ ss_input_ids=ss_input_ids,
1190
+ attention_mask=attention_mask,
1191
+ token_type_ids=token_type_ids,
1192
+ position_ids=position_ids,
1193
+ inputs_embeds=inputs_embeds,
1194
+ output_attentions=output_attentions,
1195
+ output_hidden_states=output_hidden_states,
1196
+ return_dict=return_dict,
1197
+ )
1198
+
1199
+ sequence_output = outputs[0]
1200
+ prediction_scores = self.cls(sequence_output)
1201
+
1202
+ masked_lm_loss = None
1203
+ if labels is not None:
1204
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1205
+ masked_lm_loss = loss_fct(
1206
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1207
+ )
1208
+
1209
+ if not return_dict:
1210
+ output = (prediction_scores,) + outputs[1:]
1211
+ return (
1212
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1213
+ )
1214
+
1215
+ return MaskedLMOutput(
1216
+ loss=masked_lm_loss,
1217
+ logits=prediction_scores,
1218
+ hidden_states=outputs.hidden_states,
1219
+ attentions=outputs.attentions,
1220
+ )
1221
+
1222
+
1223
+ class ProSSTForSequenceClassification(ProSSTPreTrainedModel):
1224
+ def __init__(self, config):
1225
+ super().__init__(config)
1226
+
1227
+ num_labels = getattr(config, "num_labels", 2)
1228
+ self.num_labels = num_labels
1229
+ self.scale_hidden = getattr(config, "scale_hidden", 1)
1230
+ self.prosst = ProSSTModel(config)
1231
+ self.pooler = ContextPooler(config)
1232
+ output_dim = self.pooler.output_dim * self.scale_hidden
1233
+
1234
+ self.classifier = nn.Linear(output_dim, num_labels)
1235
+ drop_out = getattr(config, "cls_dropout", None)
1236
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1237
+ self.dropout = nn.Dropout(drop_out)
1238
+
1239
+ # Initialize weights and apply final processing
1240
+ self.post_init()
1241
+
1242
+ def get_input_embeddings(self):
1243
+ return self.prosst.get_input_embeddings()
1244
+
1245
+ def set_input_embeddings(self, new_embeddings):
1246
+ self.prosst.set_input_embeddings(new_embeddings)
1247
+
1248
+ def forward(
1249
+ self,
1250
+ input_ids: Optional[torch.Tensor] = None,
1251
+ ss_input_ids: Optional[torch.Tensor] = None,
1252
+ attention_mask: Optional[torch.Tensor] = None,
1253
+ token_type_ids: Optional[torch.Tensor] = None,
1254
+ position_ids: Optional[torch.Tensor] = None,
1255
+ inputs_embeds: Optional[torch.Tensor] = None,
1256
+ labels: Optional[torch.Tensor] = None,
1257
+ output_attentions: Optional[bool] = None,
1258
+ output_hidden_states: Optional[bool] = None,
1259
+ return_dict: Optional[bool] = None,
1260
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1261
+ r"""
1262
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1263
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1264
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1265
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1266
+ """
1267
+ return_dict = (
1268
+ return_dict if return_dict is not None else self.config.use_return_dict
1269
+ )
1270
+
1271
+ outputs = self.prosst(
1272
+ input_ids,
1273
+ ss_input_ids=ss_input_ids,
1274
+ token_type_ids=token_type_ids,
1275
+ attention_mask=attention_mask,
1276
+ position_ids=position_ids,
1277
+ inputs_embeds=inputs_embeds,
1278
+ output_attentions=output_attentions,
1279
+ output_hidden_states=output_hidden_states,
1280
+ return_dict=return_dict,
1281
+ )
1282
+
1283
+ encoder_layer = outputs[0]
1284
+ pooled_output = self.pooler(encoder_layer, attention_mask)
1285
+ pooled_output = self.dropout(pooled_output)
1286
+ logits = self.classifier(pooled_output)
1287
+
1288
+ loss = None
1289
+ if labels is not None:
1290
+ if self.config.problem_type is None:
1291
+ if self.num_labels == 1:
1292
+ # regression task
1293
+ loss_fn = nn.MSELoss()
1294
+ logits = logits.view(-1).to(labels.dtype)
1295
+ loss = loss_fn(logits, labels.view(-1))
1296
+ elif labels.dim() == 1 or labels.size(-1) == 1:
1297
+ label_index = (labels >= 0).nonzero()
1298
+ labels = labels.long()
1299
+ if label_index.size(0) > 0:
1300
+ labeled_logits = torch.gather(
1301
+ logits,
1302
+ 0,
1303
+ label_index.expand(label_index.size(0), logits.size(1)),
1304
+ )
1305
+ labels = torch.gather(labels, 0, label_index.view(-1))
1306
+ loss_fct = CrossEntropyLoss()
1307
+ loss = loss_fct(
1308
+ labeled_logits.view(-1, self.num_labels).float(),
1309
+ labels.view(-1),
1310
+ )
1311
+ else:
1312
+ loss = torch.tensor(0).to(logits)
1313
+ else:
1314
+ log_softmax = nn.LogSoftmax(-1)
1315
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
1316
+ elif self.config.problem_type == "regression":
1317
+ loss_fct = MSELoss()
1318
+ if self.num_labels == 1:
1319
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1320
+ else:
1321
+ loss = loss_fct(logits, labels)
1322
+ elif self.config.problem_type == "binary_classification":
1323
+ loss_fct = BCEWithLogitsLoss()
1324
+ loss = loss_fct(logits.squeeze(), labels.squeeze().to(logits.dtype))
1325
+ elif self.config.problem_type == "single_label_classification":
1326
+ loss_fct = CrossEntropyLoss()
1327
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1328
+ elif self.config.problem_type == "multi_label_classification":
1329
+ loss_fct = BCEWithLogitsLoss()
1330
+ loss = loss_fct(logits, labels.to(logits.dtype))
1331
+ if not return_dict:
1332
+ output = (logits,) + outputs[1:]
1333
+ return ((loss,) + output) if loss is not None else output
1334
+
1335
+ return SequenceClassifierOutput(
1336
+ loss=loss,
1337
+ logits=logits,
1338
+ hidden_states=outputs.hidden_states,
1339
+ attentions=outputs.attentions,
1340
+ )
1341
+
1342
+
1343
+ class ProSSTForTokenClassification(ProSSTPreTrainedModel):
1344
+ def __init__(self, config):
1345
+ super().__init__(config)
1346
+ self.num_labels = config.num_labels
1347
+
1348
+ self.prosst = ProSSTModel(config)
1349
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1350
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1351
+
1352
+ # Initialize weights and apply final processing
1353
+ self.post_init()
1354
+
1355
+ def forward(
1356
+ self,
1357
+ input_ids: Optional[torch.Tensor] = None,
1358
+ attention_mask: Optional[torch.Tensor] = None,
1359
+ token_type_ids: Optional[torch.Tensor] = None,
1360
+ position_ids: Optional[torch.Tensor] = None,
1361
+ inputs_embeds: Optional[torch.Tensor] = None,
1362
+ labels: Optional[torch.Tensor] = None,
1363
+ output_attentions: Optional[bool] = None,
1364
+ output_hidden_states: Optional[bool] = None,
1365
+ return_dict: Optional[bool] = None,
1366
+ ) -> Union[Tuple, TokenClassifierOutput]:
1367
+ r"""
1368
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1369
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1370
+ """
1371
+ return_dict = (
1372
+ return_dict if return_dict is not None else self.config.use_return_dict
1373
+ )
1374
+
1375
+ outputs = self.prosst(
1376
+ input_ids,
1377
+ attention_mask=attention_mask,
1378
+ token_type_ids=token_type_ids,
1379
+ position_ids=position_ids,
1380
+ inputs_embeds=inputs_embeds,
1381
+ output_attentions=output_attentions,
1382
+ output_hidden_states=output_hidden_states,
1383
+ return_dict=return_dict,
1384
+ )
1385
+
1386
+ sequence_output = outputs[0]
1387
+
1388
+ sequence_output = self.dropout(sequence_output)
1389
+ logits = self.classifier(sequence_output)
1390
+
1391
+ loss = None
1392
+ if labels is not None:
1393
+ loss_fct = CrossEntropyLoss()
1394
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1395
+
1396
+ if not return_dict:
1397
+ output = (logits,) + outputs[1:]
1398
+ return ((loss,) + output) if loss is not None else output
1399
+
1400
+ return TokenClassifierOutput(
1401
+ loss=loss,
1402
+ logits=logits,
1403
+ hidden_states=outputs.hidden_states,
1404
+ attentions=outputs.attentions,
1405
+ )
1406
+
1407
+
1408
+ ProSSTModel.register_for_auto_class("AutoModel")
1409
+ ProSSTForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")
1410
+ ProSSTForSequenceClassification.register_for_auto_class(
1411
+ "AutoModelForSequenceClassification"
1412
+ )
1413
+ ProSSTForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")