Jackmin108 committed
Commit 76fc218
Parent(s): c1736a8
feat: adapter masking finished
Signed-off-by: Meow <[email protected]>
- block.py +11 -11
- embedding.py +20 -39
- mha.py +28 -37
- mlp.py +25 -19
- modeling_xlm_roberta.py +23 -28
block.py
CHANGED
@@ -233,17 +233,17 @@ class Block(nn.Module):
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
-               …
-               if …
-               …
-               else:
-               …
+               mlp_out = self.mlp(hidden_states, cu_adapter_mask=mixer_kwargs.get('cu_adapter_mask'))
+               # if cu_adapter_mask:
+               #     if isinstance(task_type, tuple):
+               #         assert mixer_kwargs['cu_seqlens'].shape[0] % 9 == 1
+               #         split_index = int((mixer_kwargs['cu_seqlens'].shape[0] - 1) / 9)
+               #         split = mixer_kwargs['cu_seqlens'][split_index]
+               #         mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'), split=split)
+               #     else:
+               #         mlp_out = self.mlp(hidden_states, task_type=task_type)
+               # else:
+               #     mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
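The block.py change only forwards the unpadded adapter mask from mixer_kwargs into the MLP; the substance of the commit is the routing pattern repeated in every file below: select the rows that share an adapter id, run them through the task-specific projection, and scatter the results back into a preallocated buffer in the original order. A minimal, self-contained sketch of that pattern (TaskAwareLinear and route_by_adapter are hypothetical stand-ins; the repo's LoRA layers simply accept a task_type keyword, as the hunks show):

import torch
import torch.nn as nn


class TaskAwareLinear(nn.Module):
    """Hypothetical stand-in: one plain nn.Linear per adapter/task id."""

    def __init__(self, in_features, out_features, num_tasks):
        super().__init__()
        self.out_features = out_features
        self.experts = nn.ModuleList([nn.Linear(in_features, out_features) for _ in range(num_tasks)])

    def forward(self, x, task_type=0):
        return self.experts[task_type](x)


def route_by_adapter(layer, x, adapter_mask):
    """Apply `layer` group by group, given one task id per row of `x`."""
    out = torch.empty(x.shape[0], layer.out_features, dtype=x.dtype, device=x.device)
    for task_id in torch.unique(adapter_mask).tolist():
        task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
        out[task_indices] = layer(x[task_indices], task_type=task_id)  # scatter back in original row order
    return out


layer = TaskAwareLinear(8, 4, num_tasks=3)
x = torch.randn(6, 8)
adapter_mask = torch.tensor([0, 2, 0, 1, 2, 2])
routed = route_by_adapter(layer, x, adapter_mask)
# Row 3 is exactly what the task-1 projection alone would have produced for that row.
assert torch.allclose(routed[3], layer(x[3:4], task_type=1).squeeze(0))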
embedding.py
CHANGED
@@ -40,40 +40,25 @@ class XLMRobertaEmbeddings(nn.Module):
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

-   def forward(self, input_ids, position_ids=None, token_type_ids=None, …
+   def forward(self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
-       if …
-           assert input_ids.shape[0] % 9 == 0
-           split = int(input_ids.shape[0] / 9)
-           tensor1 = input_ids[:split, :]
-           tensor2 = input_ids[split:, :]
-           emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
-           emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
-           embeddings = torch.cat((emb1, emb2), dim=0)
-       …
+       if adapter_mask is not None:
            unique_tasks = torch.unique(adapter_mask).tolist()
-           …
-           embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim, …
-           …
-           embeddings[…
-           …
-           exit(0)
+           embedding_dtype = next(self.word_embeddings.parameters()).dtype
+           embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
+                                    dtype=embedding_dtype).to(input_ids.device)
+           for task_id in unique_tasks:
+               task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
+               task_input_ids = input_ids[task_indices]
+               task_embeddings = self.word_embeddings(task_input_ids, task_type=task_id)
+               embeddings[task_indices] = task_embeddings
        else:
-           …
-           task1_indices = (adapter_mask == unique_task).nonzero(as_tuple=True)[0]
-           input1 = input_ids[task1_indices]
-           lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
-           embeddings = self.word_embeddings(input1, **lora_kwargs)
-           …
+           embeddings = self.word_embeddings(input_ids)

        if self.max_position_embeddings > 0:
            if position_ids is None:
@@ -84,19 +69,15 @@ class XLMRobertaEmbeddings(nn.Module):
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
-           …
-           emb1 = emb1 + token_type_embs1
-           emb2 = emb2 + token_type_embs2
-           embeddings = torch.cat((emb1, emb2), dim=0)
+
+           if adapter_mask is not None:
+               unique_tasks = torch.unique(adapter_mask).tolist()
+               for task_id in unique_tasks:
+                   task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_type=task_id)
+                   task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
+                   embeddings[task_indices] = embeddings[task_indices] + task_token_type_embeddings
            else:
-               …
-               lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
-               token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
+               token_type_embeddings = self.token_type_embeddings(token_type_ids)
                embeddings = embeddings + token_type_embeddings
+
        return embeddings
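At this stage the mask is one adapter id per batch row, so each sequence is embedded with its own task's word embedding, and the scatter through task_indices reassembles the batch in its original order. A small check of that invariant, with two plain nn.Embedding tables standing in for the task-specific word embeddings (the repo's real layer instead takes a task_type keyword):

import torch
import torch.nn as nn

torch.manual_seed(0)
vocab, dim = 32, 4
tables = {0: nn.Embedding(vocab, dim), 1: nn.Embedding(vocab, dim)}  # stand-ins, one per task id

input_ids = torch.randint(0, vocab, (5, 7))    # (batch, seqlen)
adapter_mask = torch.tensor([0, 1, 1, 0, 1])   # one task id per batch row

embeddings = torch.empty(*input_ids.shape, dim)
for task_id in torch.unique(adapter_mask).tolist():
    task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
    embeddings[task_indices] = tables[task_id](input_ids[task_indices])

# Row order is preserved: row 3 was embedded with task 0's table.
assert torch.equal(embeddings[3], tables[0](input_ids[3]))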
mha.py
CHANGED
@@ -590,7 +590,7 @@ class MHA(nn.Module):
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
-       …
+       cu_adapter_mask=None,
        **kwargs,
    ):
        """
@@ -643,39 +643,27 @@ class MHA(nn.Module):
            inference_params.max_sequence_len if inference_params is not None else max_seqlen
        )
        batch, seqlen = x.shape[:2]
-       lora_kwargs = {}
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None

-           …
-               qkv2 = self.Wqkv(tensor2, task_type=task_type[1])
-               qkv = torch.cat((qkv1, qkv2), dim=0)
-           else:
-               qkv = self.Wqkv(x, **lora_kwargs)
+           if cu_adapter_mask is not None:
+               unique_tasks = torch.unique(cu_adapter_mask).tolist()
+               qkv_dtype = next(self.Wqkv.parameters()).dtype
+               qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
+                                 dtype=qkv_dtype).to(x.device)
+               for task_id in unique_tasks:
+                   task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
+                   task_tensor = x[task_indices]
+                   if not self.return_residual:
+                       task_qkv = self.Wqkv(task_tensor, task_type=task_id)
+                   else:
+                       task_qkv, _ = self.Wqkv(task_tensor, task_type=task_id, residual=True)
+                   qkv[task_indices] = task_qkv
            else:
-               if …
-               …
-               tensor2 = x[split:, :]
-               qkv1, tensor1 = self.Wqkv(tensor1, task_type=task_type[0], residual=True)
-               qkv2, tensor2 = self.Wqkv(tensor2, task_type=task_type[1], residual=True)
-               qkv = torch.cat((qkv1, qkv2), dim=0)
-               x = torch.cat((tensor1, tensor2), dim=0)
+               if not self.return_residual:
+                   qkv = self.Wqkv(x)
                else:
-                   …
-                   lora_kwargs['residual'] = True
-                   qkv, x = self.Wqkv(x, **lora_kwargs)
+                   qkv, x = self.Wqkv(x)

        if self.dwconv:
            qkv = rearrange(
@@ -762,14 +750,17 @@ class MHA(nn.Module):
        else:
            context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)

-       lora_kwargs.pop('residual', None)
        inp = rearrange(context, "... h d -> ... (h d)")
-       if …
-       …
+       if cu_adapter_mask is not None:
+           unique_tasks = torch.unique(cu_adapter_mask).tolist()
+           out_dtype = next(self.out_proj.parameters()).dtype
+           out = torch.empty(inp.shape[0], self.out_proj.out_features,
+                             dtype=out_dtype).to(inp.device)
+           for task_id in unique_tasks:
+               task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
+               task_tensor = inp[task_indices]
+               task_out = self.out_proj(task_tensor, task_type=task_id)
+               out[task_indices] = task_out
        else:
-           out = self.out_proj(inp …
+           out = self.out_proj(inp)
        return out if not self.return_residual else (out, x)
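Inside MHA the flash-attention path works on an unpadded, packed (total_tokens, hidden) tensor, so the mask reaching Wqkv and out_proj has to be per token rather than per sequence; that is what the cu_adapter_mask returned by the modified unpad_input provides. The sketch below shows one plausible relationship between the per-sequence adapter_mask, the attention mask, and such a per-token mask (an illustration of the idea, not this repo's unpad_input):

import torch

# Per-sequence adapter ids and padding mask for a batch of 3 sequences.
adapter_mask = torch.tensor([0, 1, 0])             # (batch,)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0],
                               [1, 1, 1, 1]])      # (batch, seqlen)

# Unpadding keeps only the real tokens, concatenated along one packed dimension;
# repeating each sequence's adapter id over its own tokens yields a per-token mask.
seqlens = attention_mask.sum(dim=1)                # tensor([3, 2, 4])
cu_adapter_mask = torch.repeat_interleave(adapter_mask, seqlens)
print(cu_adapter_mask)                             # tensor([0, 0, 0, 1, 1, 0, 0, 0, 0])

# This matches the shape MHA indexes: x is (total_tokens, hidden), so
# (cu_adapter_mask == task_id) selects the packed token rows of one task.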
mlp.py
CHANGED
@@ -47,30 +47,36 @@ class Mlp(nn.Module):
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

-   def forward(self, x, …
-       …
+   def forward(self, x, cu_adapter_mask=None):
+       if cu_adapter_mask is not None:
+           unique_tasks = torch.unique(cu_adapter_mask).tolist()
+           fc1_dtype = next(self.fc1.parameters()).dtype
+           y = torch.empty(x.shape[0], self.fc1.out_features,
+                           dtype=fc1_dtype).to(x.device)
+           for task_id in unique_tasks:
+               task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
+               task_tensor = x[task_indices]
+               task_y = self.fc1(task_tensor, task_type=task_id)
+               y[task_indices] = task_y
        else:
-           y = self.fc1(x …
+           y = self.fc1(x)

        y = self.activation(y)

-       if …
-       …
+       if cu_adapter_mask is not None:
+           unique_tasks = torch.unique(cu_adapter_mask).tolist()
+           fc2_dtype = next(self.fc2.parameters()).dtype
+           out = torch.empty(y.shape[0], self.fc2.out_features,
+                             dtype=fc2_dtype).to(y.device)
+           for task_id in unique_tasks:
+               task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
+               task_tensor = y[task_indices]
+               task_out = self.fc2(task_tensor, task_type=task_id)
+               out[task_indices] = task_out
        else:
-           …
+           out = self.fc1(y)
+
+       return out if not self.return_residual else (out, x)


class ParallelMLP(nn.Module):
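One detail worth noting in the new code: the masked branch projects with self.fc2, while the unmasked fallback ends with out = self.fc1(y). For comparison, a conventional two-layer MLP forward applies fc1, then the activation, then fc2, roughly as in the reference sketch below (illustrative, not this repo's code):

import torch.nn as nn


class PlainMlp(nn.Module):
    """Conventional two-layer MLP forward, for comparison with the fallback branch above."""

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        return self.fc2(y)  # second projection uses fc2, not fc1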
modeling_xlm_roberta.py
CHANGED
@@ -204,18 +204,16 @@ class XLMRobertaEncoder(nn.Module):
    def gradient_checkpointing(self, value):
        self._grad_checkpointing = value

-   def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, …
+   def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, adapter_mask=None):
        """If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.
        subset_mask: (batch, seqlen), dtype=torch.bool
        """
        if key_padding_mask is None or not self.use_flash_attn:
-           mixer_kwargs = …
-           …
-           )
-           mixer_kwargs['task_type'] = task_type
+           mixer_kwargs = {'adapter_mask': adapter_mask}
+           if key_padding_mask is not None:
+               mixer_kwargs['key_padding_mask'] = key_padding_mask.bool()
+
            for layer in self.layers:
                if self._grad_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
@@ -233,7 +231,8 @@ class XLMRobertaEncoder(nn.Module):
            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
                hidden_states, key_padding_mask, adapter_mask
            )
-           mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "…
+           mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "cu_adapter_mask": cu_adapter_mask}
+
            if subset_mask is None:
                for layer in self.layers:
                    if self._grad_checkpointing:
@@ -310,24 +309,22 @@ class XLMRobertaPooler(nn.Module):
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

-   def forward(self, hidden_states, pool=True, …
+   def forward(self, hidden_states, pool=True, adapter_mask=None):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
-       lora_kwargs = {'task_type': task_type} if task_type is not None else {}
-       …
        first_token_tensor = hidden_states[:, 0] if pool else hidden_states
-       …
+       if adapter_mask is not None:
+           unique_tasks = torch.unique(adapter_mask).tolist()
+           pool_dtype = next(self.dense.parameters()).dtype
+           pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
+                                       dtype=pool_dtype).to(first_token_tensor.device)
+           for task_id in unique_tasks:
+               task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
+               task_first_token_tensor = first_token_tensor[task_indices]
+               task_pooled_output = self.dense(task_first_token_tensor, task_type=task_id)
+               pooled_output[task_indices] = task_pooled_output
        else:
-           pooled_output = self.dense(first_token_tensor …
-           …
+           pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

@@ -440,7 +437,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
            "gelu_fast",
            "gelu_pytorch_tanh",
        ]
-       …
        self.embeddings = XLMRobertaEmbeddings(
            config.hidden_size,
            config.vocab_size,
@@ -648,7 +644,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
        layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
        """
-       task_type = kwargs.pop('task_type', None)
        adapter_mask = kwargs.pop('adapter_mask', None)
        if kwargs:
            for key, value in kwargs.items():
@@ -663,7 +658,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
        )

        hidden_states = self.embeddings(
-           input_ids, position_ids=position_ids, token_type_ids=token_type_ids, …
+           input_ids, position_ids=position_ids, token_type_ids=token_type_ids, adapter_mask=adapter_mask
        )
        # TD [2022-12:18]: Don't need to force residual in fp32
        # BERT puts embedding LayerNorm before embedding dropout.
@@ -687,12 +682,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
        subset_mask = None

        sequence_output = self.encoder(
-           hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, …
+           hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, adapter_mask=adapter_mask
        )

        if masked_tokens_mask is None:
            pooled_output = (
-               self.pooler(sequence_output, …
+               self.pooler(sequence_output, adapter_mask=adapter_mask) if self.pooler is not None else None
            )
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
@@ -706,7 +701,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
            pool_input = sequence_output[first_col_mask[subset_mask]]
            sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
            pooled_output = (
-               self.pooler(pool_input, pool=False, …
+               self.pooler(pool_input, pool=False, adapter_mask=adapter_mask) if self.pooler is not None else None
            )

            if not return_dict:
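Because XLMRobertaModel.forward pops adapter_mask out of **kwargs and threads it through the embeddings, the encoder, and the pooler, a caller only needs to supply one adapter id per batch row. A hedged usage sketch (model is assumed to be an already-instantiated XLMRobertaModel from this repo; the token ids and task ids are purely illustrative, and valid ids depend on the LoRA adapters actually configured):

import torch

input_ids = torch.tensor([[0, 100, 2], [0, 200, 2]])   # (batch, seqlen), illustrative token ids
attention_mask = torch.ones_like(input_ids)
adapter_mask = torch.tensor([0, 1])                    # one adapter/task id per batch row

outputs = model(
    input_ids,
    attention_mask=attention_mask,
    adapter_mask=adapter_mask,   # picked up via kwargs.pop('adapter_mask', None) in forward
)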