Logic123456789 committed
Commit e16b4bd
1 Parent(s): 915e12e

add the models.py

Files changed (1)
  1. models.py +524 -0
models.py ADDED
@@ -0,0 +1,524 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+
+ from simcse.modeling_glm import GLMModel, GLMPreTrainedModel
+ import simcse.mse_loss
+
+ import transformers
+ from transformers import RobertaTokenizer, AutoModel, PreTrainedModel
+ from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead
+ from transformers.activations import gelu
+ from transformers.file_utils import (
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+ from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
+
+ glm_model = None
+
+ def init_glm(path):
+     global glm_model
+     glm_model = GLMModel.from_pretrained(path, trust_remote_code=True).to("cuda:0")
+     for param in glm_model.parameters():
+         param.requires_grad = False
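+ # NOTE: the GLM encoder loaded here stays frozen and is shared module-wide via
+ # `glm_model`; BertForCL/RobertaForCL below run it to build `inputs_embeds`
+ # for the trainable encoder.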
+
+
+
+ class MLPLayer(nn.Module):
+     """
+     Head for getting sentence representations over RoBERTa/BERT's CLS representation.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         # project to 1536 dims so the sentence embedding matches the external
+         # (OpenAI) teacher embeddings passed in as left_emb / right_emb
+         self.fc = nn.Linear(config.hidden_size, 1536)
+         self.activation = nn.Tanh()
+
+     def forward(self, features, **kwargs):
+         x = self.dense(features)
+         x = self.fc(x)
+         x = self.activation(x)
+
+         return x
+
+ class Similarity(nn.Module):
+     """
+     Dot product or cosine similarity
+     """
+
+     def __init__(self, temp):
+         super().__init__()
+         self.temp = temp
+         self.cos = nn.CosineSimilarity(dim=-1)
+
+     def forward(self, x, y):
+         return self.cos(x, y) / self.temp
+
+
+ class Pooler(nn.Module):
+     """
+     Parameter-free poolers to get the sentence embedding
+     'cls': [CLS] representation with BERT/RoBERTa's MLP pooler.
+     'cls_before_pooler': [CLS] representation without the original MLP pooler.
+     'avg': average of the last layer's hidden states at each token.
+     'avg_top2': average of the last two layers.
+     'avg_first_last': average of the first and the last layers.
+     """
+
+     def __init__(self, pooler_type):
+         super().__init__()
+         self.pooler_type = pooler_type
+         assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2",
+                                     "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type
+
+     def forward(self, attention_mask, outputs):
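+         # attention_mask: (batch, len); last_hidden: (batch, len, hidden).
+         # Every branch below reduces this to one (batch, hidden) sentence embedding.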
+         last_hidden = outputs.last_hidden_state
+         # pooler_output = outputs.pooler_output
+         hidden_states = outputs.hidden_states
+
+         if self.pooler_type in ['cls_before_pooler', 'cls']:
+             return last_hidden[:, 0]
+         elif self.pooler_type == "avg":
+             return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1))
+         elif self.pooler_type == "avg_first_last":
+             first_hidden = hidden_states[1]
+             last_hidden = hidden_states[-1]
+             pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(
+                 1) / attention_mask.sum(-1).unsqueeze(-1)
+             return pooled_result
+         elif self.pooler_type == "avg_top2":
+             second_last_hidden = hidden_states[-2]
+             last_hidden = hidden_states[-1]
+             pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(
+                 1) / attention_mask.sum(-1).unsqueeze(-1)
+             return pooled_result
+         else:
+             raise NotImplementedError
+
+
+ def cl_init(cls, config):
+     """
+     Contrastive learning class init function.
+     """
+     cls.pooler_type = cls.model_args.pooler_type
+     cls.pooler = Pooler(cls.model_args.pooler_type)
+     if cls.model_args.pooler_type == "cls":
+         cls.mlp = MLPLayer(config)
+     cls.sim = Similarity(temp=cls.model_args.temp)
+     cls.init_weights()
+
+
+ def cl_forward(cls,
+     encoder,
+     input_ids=None,
+     attention_mask=None,
+     token_type_ids=None,
+     position_ids=None,
+     head_mask=None,
+     inputs_embeds=None,
+     labels=None,
+     output_attentions=None,
+     output_hidden_states=None,
+     return_dict=None,
+     mlm_input_ids=None,
+     mlm_labels=None,
+     left_emb=None,
+     right_emb=None,
+     kl_loss=False
+ ):
+     return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
+     ori_input_ids = input_ids
+     batch_size = input_ids.size(0)
+     # Number of sentences in one instance
+     # 2: pair instance; 3: pair instance with a hard negative
+     num_sent = input_ids.size(1)
+
+     mlm_outputs = None
+     # Flatten input for encoding
+     input_ids = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
+     attention_mask = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
+     if token_type_ids is not None:
+         token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)
+
+     if inputs_embeds is not None:
+         input_ids = None
+
+     # Get raw embeddings
+     outputs = encoder(
+         input_ids,
+         attention_mask=attention_mask,
+         token_type_ids=token_type_ids,
+         position_ids=position_ids,
+         head_mask=head_mask,
+         inputs_embeds=inputs_embeds,
+         output_attentions=output_attentions,
+         output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+         return_dict=True,
+     )
+
+     # MLM auxiliary objective
+     if mlm_input_ids is not None:
+         mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
+         mlm_outputs = encoder(
+             mlm_input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+             return_dict=True,
+         )
+
+     # Pooling
+     pooler_output = cls.pooler(attention_mask, outputs)
+     pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)
+     # If using "cls", we add an extra MLP layer
+     # (same as BERT's original implementation) over the representation.
+     if cls.pooler_type == "cls":
+         pooler_output = cls.mlp(pooler_output)
+
+     # Separate representation
+     z1, z2 = pooler_output[:, 0], pooler_output[:, 1]
+
+     tensor_left = left_emb
+     tensor_right = right_emb
+
+     # Hard negative
+     if num_sent == 3:
+         z3 = pooler_output[:, 2]
+
+     # Gather all embeddings if using distributed training
+     if dist.is_initialized() and cls.training:
+         # Gather hard negative
+         if num_sent >= 3:
+             z3_list = [torch.zeros_like(z3) for _ in range(dist.get_world_size())]
+             dist.all_gather(tensor_list=z3_list, tensor=z3.contiguous())
+             z3_list[dist.get_rank()] = z3
+             z3 = torch.cat(z3_list, 0)
+
+         # Dummy vectors for allgather
+         z1_list = [torch.zeros_like(z1) for _ in range(dist.get_world_size())]
+         z2_list = [torch.zeros_like(z2) for _ in range(dist.get_world_size())]
+         # Allgather
+         dist.all_gather(tensor_list=z1_list, tensor=z1.contiguous())
+         dist.all_gather(tensor_list=z2_list, tensor=z2.contiguous())
+
+         # Since allgather results do not have gradients, we replace the
+         # current process's corresponding embeddings with original tensors
+         z1_list[dist.get_rank()] = z1
+         z2_list[dist.get_rank()] = z2
+         # Get full batch embeddings: (bs x N, hidden)
+         z1 = torch.cat(z1_list, 0)
+         z2 = torch.cat(z2_list, 0)
+
+     mse_loss = F.mse_loss(z1, tensor_left) + F.mse_loss(z2, tensor_right)
+
+     # softmax_row, softmax_col = simcse.mse_loss.giveMeMatrix(tensor_left, tensor_right)
+     # softmax_row_model, softmax_col_model = simcse.mse_loss.giveMeMatrix(z1, z2)
+     # ziang_labels = torch.tensor([i for i in range(8)], device='cuda:0')
+
+     # KL divergence loss between the teacher (OpenAI) and model similarity distributions
+     KL_loss = nn.KLDivLoss(reduction="batchmean")
+     beta = 5
+
+     # OpenAI (teacher) embeddings: giveMeMatrix returns the matrix product of the
+     # normalized left/right vectors, i.e. their cosine-similarity matrix.
+     cos_sim_matrix_openai = simcse.mse_loss.giveMeMatrix(tensor_left, tensor_right)
+     beta_scaled_cos_sim_matrix_openai = beta * cos_sim_matrix_openai
+
+     # Our embeddings: giveMeMatrix returns the same normalized similarity matrix,
+     # computed from z1/z2.
+     cos_sim_matrix_data = simcse.mse_loss.giveMeMatrix(z1, z2)
+     beta_scaled_cos_sim_matrix_data = beta * cos_sim_matrix_data
+
+     beta_scaled_cos_sim_matrix_openai_vertical = beta_scaled_cos_sim_matrix_openai.softmax(dim=1)
+     beta_scaled_cos_sim_matrix_openai_horizontal = beta_scaled_cos_sim_matrix_openai.softmax(dim=0)
+
+     beta_scaled_cos_sim_matrix_data_vertical = beta_scaled_cos_sim_matrix_data.softmax(dim=1)
+     beta_scaled_cos_sim_matrix_data_horizontal = beta_scaled_cos_sim_matrix_data.softmax(dim=0)
+
+     # remove reduction="batchmean"
+     KL_vertical_loss = KL_loss(beta_scaled_cos_sim_matrix_data_vertical.log(), beta_scaled_cos_sim_matrix_openai_vertical)
+     KL_horizontal_loss = KL_loss(beta_scaled_cos_sim_matrix_data_horizontal.log(), beta_scaled_cos_sim_matrix_openai_horizontal)
+
+     KL_loss = (KL_vertical_loss + KL_horizontal_loss) / 2
+
+     # KL_row_loss = F.kl_div(softmax_row_model.log(), softmax_row, reduction='batchmean')
+     # KL_col_loss = F.kl_div(softmax_col_model.log(), softmax_col, reduction='batchmean')
+     # KL_loss = (KL_row_loss + KL_col_loss) / 2
+
+     ziang_loss = KL_loss + mse_loss
+
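+     # Summary of the objective carried by the returned SequenceClassifierOutput:
+     #   ziang_loss = MSE(z1, left_emb) + MSE(z2, right_emb)
+     #              + 0.5 * [ KL(row-softmax(beta*S_teacher) || row-softmax(beta*S_model))
+     #                      + KL(col-softmax(beta*S_teacher) || col-softmax(beta*S_model)) ]
+     # where S_* are the normalized left-vs-right similarity matrices and beta = 5.
+     # The contrastive cross-entropy `loss` computed below is only returned on the
+     # non-return_dict path; the return_dict path returns ziang_loss.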
+     cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
+
+     # Hard negative
+     if num_sent >= 3:
+         z1_z3_cos = cls.sim(z1.unsqueeze(1), z3.unsqueeze(0))
+         cos_sim = torch.cat([cos_sim, z1_z3_cos], 1)
+
+     labels = torch.arange(cos_sim.size(0)).long().to(cls.device)
+     loss_fct = nn.CrossEntropyLoss()
+
+     # Calculate loss with hard negatives
+     if num_sent == 3:
+         # Note that weights are actually logits of weights: z3_weight is added to the
+         # diagonal of the z1-z3 block, i.e. to each example's own hard negative.
+         z3_weight = cls.model_args.hard_negative_weight
+         weights = torch.tensor(
+             [[0.0] * (cos_sim.size(-1) - z1_z3_cos.size(-1)) + [0.0] * i + [z3_weight] + [0.0] * (
+                     z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))]
+         ).to(cls.device)
+         cos_sim = cos_sim + weights
+
+     loss = loss_fct(cos_sim, labels)
+
+     # Calculate loss for MLM
+     if mlm_outputs is not None and mlm_labels is not None:
+         mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
+         prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
+         masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
+         loss = loss + cls.model_args.mlm_weight * masked_lm_loss
+
+     if not return_dict:
+         output = (cos_sim,) + outputs[2:]
+         return ((loss,) + output) if loss is not None else output
+
+     return SequenceClassifierOutput(
+         loss=ziang_loss,
+         logits=cos_sim,
+         hidden_states=outputs.hidden_states,
+     )
+
+
+ def sentemb_forward(
+     cls,
+     encoder,
+     input_ids=None,
+     attention_mask=None,
+     token_type_ids=None,
+     position_ids=None,
+     head_mask=None,
+     inputs_embeds=None,
+     labels=None,
+     output_attentions=None,
+     output_hidden_states=None,
+     return_dict=None,
+ ):
+     return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
+
+     if inputs_embeds is not None:
+         input_ids = None
+
+     outputs = encoder(
+         input_ids,
+         attention_mask=attention_mask,
+         token_type_ids=token_type_ids,
+         position_ids=position_ids,
+         head_mask=head_mask,
+         inputs_embeds=inputs_embeds,
+         output_attentions=output_attentions,
+         output_hidden_states=True if cls.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+         return_dict=True,
+     )
+
+     pooler_output = cls.pooler(attention_mask, outputs)
+     if cls.pooler_type == "cls" and not cls.model_args.mlp_only_train:
+         pooler_output = cls.mlp(pooler_output)
+
+     if not return_dict:
+         return (outputs[0], pooler_output) + outputs[2:]
+
+     return BaseModelOutputWithPoolingAndCrossAttentions(
+         pooler_output=pooler_output,
+         last_hidden_state=outputs.last_hidden_state,
+         hidden_states=outputs.hidden_states,
+     )
+
+
+ class BertForCL(BertPreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+     def __init__(self, config, *model_args, **model_kargs):
+         super().__init__(config)
+         self.model_args = model_kargs["model_args"]
+         self.bert = BertModel(config, add_pooling_layer=False)
+
+         if self.model_args.do_mlm:
+             self.lm_head = BertLMPredictionHead(config)
+
+         if self.model_args.init_embeddings_model:
+             if "glm" in self.model_args.init_embeddings_model:
+                 init_glm(self.model_args.init_embeddings_model)
+                 self.fc = nn.Linear(glm_model.config.hidden_size, config.hidden_size)
+             else:
+                 raise NotImplementedError
+
+         cl_init(self, config)
+
+     def forward(self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         sent_emb=False,
+         mlm_input_ids=None,
+         mlm_labels=None,
+         left_emb=None,
+         right_emb=None,
+     ):
+         if self.model_args.init_embeddings_model:
+             input_ids_for_glm = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
+             attention_mask_for_glm = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
+             if token_type_ids is not None:
+                 token_type_ids_for_glm = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)
+             else:
+                 token_type_ids_for_glm = None
+
+             outputs_from_glm = glm_model(input_ids_for_glm,
+                 attention_mask=attention_mask_for_glm,
+                 token_type_ids=token_type_ids_for_glm,
+                 position_ids=position_ids,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+
+             inputs_embeds = self.fc(outputs_from_glm.last_hidden_state)
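+             # The frozen GLM's last hidden states are projected to this encoder's
+             # hidden size and handed on as `inputs_embeds`; cl_forward/sentemb_forward
+             # then drop `input_ids` and feed these embeddings to BERT instead.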
+
+         if sent_emb:
+             return sentemb_forward(self, self.bert,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 token_type_ids=token_type_ids,
+                 position_ids=position_ids,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         else:
+             return cl_forward(self, self.bert,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 token_type_ids=token_type_ids,
+                 position_ids=position_ids,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+                 mlm_input_ids=mlm_input_ids,
+                 mlm_labels=mlm_labels,
+                 left_emb=left_emb,
+                 right_emb=right_emb,
+             )
+
+
+ class RobertaForCL(RobertaPreTrainedModel):
+     _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+     def __init__(self, config, *model_args, **model_kargs):
+         super().__init__(config)
+         self.model_args = model_kargs["model_args"]
+         self.roberta = RobertaModel(config, add_pooling_layer=False)
+
+         if self.model_args.do_mlm:
+             self.lm_head = RobertaLMHead(config)
+
+         if self.model_args.init_embeddings_model:
+             if "glm" in self.model_args.init_embeddings_model:
+                 init_glm(self.model_args.init_embeddings_model)
+                 self.fc = nn.Linear(glm_model.config.hidden_size, config.hidden_size)
+             else:
+                 raise NotImplementedError
+
+         cl_init(self, config)
+
+     def forward(self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         sent_emb=False,
+         mlm_input_ids=None,
+         mlm_labels=None,
+         left_emb=None,
+         right_emb=None,
+     ):
+
+         if self.model_args.init_embeddings_model and not sent_emb:
+             input_ids_for_glm = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
+             attention_mask_for_glm = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
+             if token_type_ids is not None:
+                 token_type_ids_for_glm = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)
+             else:
+                 token_type_ids_for_glm = None
+
+             outputs_from_glm = glm_model(input_ids_for_glm,
+                 attention_mask=attention_mask_for_glm,
+                 token_type_ids=token_type_ids_for_glm,
+                 position_ids=position_ids,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+
+             inputs_embeds = self.fc(outputs_from_glm.last_hidden_state)
+
+         if sent_emb:
+             return sentemb_forward(self, self.roberta,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 token_type_ids=token_type_ids,
+                 position_ids=position_ids,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         else:
+             return cl_forward(self, self.roberta,
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 token_type_ids=token_type_ids,
+                 position_ids=position_ids,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+                 mlm_input_ids=mlm_input_ids,
+                 mlm_labels=mlm_labels,
+                 left_emb=left_emb,
+                 right_emb=right_emb,
+             )
+
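
For reference, a minimal usage sketch of the classes added in this commit (not part of the diff). It assumes a SimCSE-style model_args object exposing the fields this file reads (pooler_type, temp, do_mlm, mlm_weight, hard_negative_weight, init_embeddings_model, mlp_only_train), that the file is importable as simcse.models, and that simcse.mse_loss.giveMeMatrix accepts (bs, 1536) tensors; the checkpoint name and the random 1536-d stand-ins for the teacher embeddings are illustrative only.

import torch
from types import SimpleNamespace
from transformers import AutoTokenizer

from simcse.models import BertForCL  # assumed import path for this file

# Stand-in for the training-time model_args namespace (assumption).
model_args = SimpleNamespace(
    pooler_type="cls",            # enables the MLPLayer head (hidden -> 1536)
    temp=0.05,                    # temperature of the contrastive softmax
    do_mlm=False,
    mlm_weight=0.1,
    hard_negative_weight=0.0,
    init_embeddings_model=None,   # set to a GLM checkpoint path to route through init_glm()
    mlp_only_train=False,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForCL.from_pretrained("bert-base-uncased", model_args=model_args)

# One training instance with two views (num_sent = 2), shaped (bs, num_sent, len).
views = ["a cat sits on the mat", "there is a cat on the mat"]
batch = tokenizer(views, padding=True, return_tensors="pt")
input_ids = batch["input_ids"].unsqueeze(0)
attention_mask = batch["attention_mask"].unsqueeze(0)

# 1536-d teacher embeddings for the left/right views; random placeholders here.
left_emb, right_emb = torch.randn(1, 1536), torch.randn(1, 1536)

out = model(input_ids=input_ids, attention_mask=attention_mask,
            left_emb=left_emb, right_emb=right_emb)
print(out.loss, out.logits.shape)   # ziang_loss (MSE + KL) and the (bs, bs) cosine logits

In actual training, left_emb/right_emb would presumably be the precomputed OpenAI embeddings referenced in the comments above rather than random tensors.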