Jiqing committed
Commit 0329051 (1 parent: a1cddb1)

Create modeling_protst.py

Files changed (1)
  1. modeling_protst.py +214 -0
modeling_protst.py ADDED
@@ -0,0 +1,214 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from transformers.models.esm import EsmPreTrainedModel, EsmModel
from transformers.models.bert import BertPreTrainedModel, BertModel
from .configuration_protst import ProtSTConfig


@dataclass
class EsmProteinRepresentationOutput(ModelOutput):

    protein_feature: torch.FloatTensor = None
    residue_feature: torch.FloatTensor = None


@dataclass
class BertTextRepresentationOutput(ModelOutput):

    text_feature: torch.FloatTensor = None
    word_feature: torch.FloatTensor = None


@dataclass
class ProtSTClassificationOutput(ModelOutput):

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


class ProtSTHead(nn.Module):
    def __init__(self, config, out_dim=512):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, out_dim)

    def forward(self, x):
        x = self.dense(x)
        x = nn.functional.relu(x)
        x = self.out_proj(x)
        return x


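# Illustration (not part of the original commit): ProtSTHead is a two-layer
# projection MLP (Linear -> ReLU -> Linear). The helper below is a minimal
# shape check; the hidden_size value and the _DemoHeadConfig name are made up
# purely for this sketch.
class _DemoHeadConfig:
    hidden_size = 480


def _demo_protst_head():
    head = ProtSTHead(_DemoHeadConfig(), out_dim=512)
    x = torch.randn(2, _DemoHeadConfig.hidden_size)  # [batch_size, hidden_size]
    return head(x).shape  # torch.Size([2, 512])

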
class BertForPubMed(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.pad_token_id = config.pad_token_id
        self.cls_token_id = config.cls_token_id
        self.sep_token_id = config.sep_token_id

        self.bert = BertModel(config, add_pooling_layer=False)
        self.text_mlp = ProtSTHead(config)
        self.word_mlp = ProtSTHead(config)

        self.post_init()  # initialize weights and apply final processing

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], ModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        word_feature = outputs[0]  # last_hidden_state (indexing works whether the backbone returns a tuple or a ModelOutput)
        # Mean-pool over non-special tokens: CLS, SEP and PAD positions are masked out of the average.
        is_special = (input_ids == self.cls_token_id) | (input_ids == self.sep_token_id) | (input_ids == self.pad_token_id)
        special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
        pooled_feature = ((word_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(word_feature.dtype)
        pooled_feature = self.text_mlp(pooled_feature)
        word_feature = self.word_mlp(word_feature)

        if not return_dict:
            return (pooled_feature, word_feature)

        return BertTextRepresentationOutput(text_feature=pooled_feature, word_feature=word_feature)


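# Illustration (not part of the original commit): both BertForPubMed above and
# EsmForProteinRepresentation below pool token features with a mean over
# non-special positions. The token ids in this helper are hypothetical and only
# mark CLS/SEP/PAD positions for the sketch.
def _demo_masked_mean_pooling():
    cls_id, sep_id, pad_id = 101, 102, 0
    input_ids = torch.tensor([[101, 7, 8, 102, 0, 0]])  # [batch_size, seq_len]
    hidden = torch.randn(1, 6, 4)                       # [batch_size, seq_len, hidden_size]
    is_special = (input_ids == cls_id) | (input_ids == sep_id) | (input_ids == pad_id)
    special_mask = (~is_special).to(torch.int64).unsqueeze(-1)  # 1 for real tokens, 0 otherwise
    # Sum over real tokens only, then divide by their count; the epsilon avoids division by zero.
    return (hidden * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)  # [batch_size, hidden_size]

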
class EsmForProteinRepresentation(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.cls_token_id = config.cls_token_id
        self.pad_token_id = config.pad_token_id
        self.eos_token_id = config.eos_token_id

        self.esm = EsmModel(config, add_pooling_layer=False)

        self.post_init()  # initialize weights and apply final processing

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, EsmProteinRepresentationOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        residue_feature = outputs[0]  # last_hidden_state: [batch_size, seq_len, hidden_dim]

        # Mean readout over non-special tokens: CLS, EOS and PAD positions are masked out of the average.
        is_special = (
            (input_ids == self.cls_token_id) | (input_ids == self.eos_token_id) | (input_ids == self.pad_token_id)
        )
        special_mask = (~is_special).to(torch.int64).unsqueeze(-1)
        protein_feature = ((residue_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)).to(residue_feature.dtype)

        return EsmProteinRepresentationOutput(
            protein_feature=protein_feature, residue_feature=residue_feature
        )


class ProtSTPreTrainedModel(PreTrainedModel):
    config_class = ProtSTConfig


class ProtSTForProteinPropertyPrediction(ProtSTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.protein_model = EsmForProteinRepresentation(config.protein_config)
        # CLIP-style learnable temperature (initialized to log(1 / 0.07)); not used by the classification forward pass below.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
        self.classifier = ProtSTHead(config.protein_config, out_dim=config.num_labels)

        self.post_init()  # initialize weights and apply final processing

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ProtSTClassificationOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the protein classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.protein_model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.classifier(outputs.protein_feature)  # [batch_size, hidden_size] -> [batch_size, num_labels]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()

            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ProtSTClassificationOutput(loss=loss, logits=logits)
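
Below is a minimal inference sketch for ProtSTForProteinPropertyPrediction as defined above. It is not part of the committed file: the checkpoint id is a placeholder rather than a verified repository name, and it assumes the checkpoint ships a compatible (ESM-style) tokenizer and a ProtSTConfig with num_labels set.

    import torch
    from transformers import AutoTokenizer

    checkpoint = "<protst-checkpoint>"  # placeholder, not a verified repo id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = ProtSTForProteinPropertyPrediction.from_pretrained(checkpoint)
    model.eval()

    # Protein sequences are tokenized like any other string input.
    inputs = tokenizer("MSKGEELFTGVVPILVELDGD", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = outputs.logits.argmax(dim=-1)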