michael-guenther committed
Commit: 77af1c7 · Parent(s): 1c61b96

add stochastic_depth

Files changed:
- block.py                  +26 -14
- modeling_xlm_roberta.py   +121 -61
- stochastic_depth.py       +97 -0
block.py CHANGED

@@ -10,8 +10,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torchvision.ops import StochasticDepth
 
+from .stochastic_depth import StochasticDepth
 from .mha import MHA
 from .mlp import Mlp
 
@@ -106,7 +106,9 @@ class Block(nn.Module):
                 p._shared_params = True
 
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
 
     def forward(
         self,
@@ -152,7 +154,7 @@ class Block(nn.Module):
                     rowscale=rowscale1,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
-                    is_rms_norm=isinstance(self.norm1, RMSNorm)
+                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
@@ -165,7 +167,9 @@ class Block(nn.Module):
                 if not self.fused_dropout_add_ln:
                     dropped = self.drop_path2(self.dropout2(hidden_states))
                     residual = (dropped + residual) if residual is not None else dropped
-                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                    hidden_states = self.norm2(
+                        residual.to(dtype=self.norm2.weight.dtype)
+                    )
                     if self.residual_in_fp32:
                         residual = residual.to(torch.float32)
                 else:
@@ -189,7 +193,7 @@ class Block(nn.Module):
                         rowscale=rowscale2,
                         prenorm=True,
                         residual_in_fp32=self.residual_in_fp32,
-                        is_rms_norm=isinstance(self.norm2, RMSNorm)
+                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                     )
                 hidden_states = self.mlp(hidden_states)
             return hidden_states, residual
@@ -212,7 +216,9 @@ class Block(nn.Module):
                 else:
                     rowscale1 = self.drop_path1(
                         torch.ones(
-                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
+                            mixer_out.shape[:-1],
+                            device=mixer_out.device,
+                            dtype=mixer_out.dtype,
                         )
                     )
                 hidden_states = layer_norm_fn(
@@ -224,7 +230,7 @@ class Block(nn.Module):
                     dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=False,
-                    is_rms_norm=isinstance(self.norm1, RMSNorm)
+                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
@@ -242,7 +248,9 @@ class Block(nn.Module):
                     else:
                         rowscale2 = self.drop_path2(
                             torch.ones(
-                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
+                                mlp_out.shape[:-1],
+                                device=mlp_out.device,
+                                dtype=mlp_out.dtype,
                             )
                         )
                     hidden_states = layer_norm_fn(
@@ -254,7 +262,7 @@ class Block(nn.Module):
                         dropout_p=self.dropout2.p if self.training else 0.0,
                         rowscale=rowscale2,
                         prenorm=False,
-                        is_rms_norm=isinstance(self.norm2, RMSNorm)
+                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                     )
             return hidden_states
 
@@ -333,7 +341,9 @@ class ParallelBlock(nn.Module):
                 p._shared_params = True
 
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
 
     def forward(
         self,
@@ -373,7 +383,9 @@ class ParallelBlock(nn.Module):
                 residual = residual.to(torch.float32)
         else:
             weight2, bias2 = (
-                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
+                (self.norm2.weight, self.norm2.bias)
+                if not self.tied_norm
+                else (None, None)
             )
             hidden_states1, *rest, residual = layer_norm_fn(
                 hidden_states1,
@@ -387,14 +399,14 @@ class ParallelBlock(nn.Module):
                 dropout_p=self.dropout1.p if self.training else 0.0,
                 prenorm=True,
                 residual_in_fp32=self.residual_in_fp32,
-                is_rms_norm=isinstance(self.norm1, RMSNorm)
+                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
-                hidden_states2, = rest
+                (hidden_states2,) = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
-        return hidden_states1, hidden_states2, residual
+        return hidden_states1, hidden_states2, residual
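For orientation, here is a minimal, self-contained sketch (illustrative only, not code from this repo) of the residual-branch pattern that drop_path1 / drop_path2 implement in Block above: the sub-layer output passes through dropout, then through StochasticDepth, and only then is added to the residual stream and normalized.

import torch
import torch.nn as nn

# Assumes the vendored stochastic_depth.py added by this commit is importable.
from stochastic_depth import StochasticDepth


class ToyResidualBranch(nn.Module):
    """Illustrative module: dropout -> stochastic depth -> residual add -> norm."""

    def __init__(self, dim: int, drop_path_p: float = 0.1, dropout_p: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.dropout = nn.Dropout(dropout_p)
        # "row" mode drops the whole branch independently for each batch row
        self.drop_path = StochasticDepth(drop_path_p, mode="row")
        self.norm = nn.LayerNorm(dim)

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        dropped = self.drop_path(self.dropout(self.mlp(hidden_states)))
        residual = dropped + residual
        return self.norm(residual)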
modeling_xlm_roberta.py CHANGED

@@ -42,6 +42,7 @@ from .block import Block
 from .embedding import XLMRobertaEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
+from .stochastic_depth import StochasticDepth
 
 
 try:
@@ -69,10 +70,16 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     rotary_kwargs = {}
     if config.position_embedding_type == "rotary":
-        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
+        rotary_kwargs["rotary_emb_dim"] = getattr(
+            config, "rotary_emb_dim", config.hidden_size
+        )
         rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
-        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
-        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
+        rotary_kwargs["rotary_emb_scale_base"] = getattr(
+            config, "rotary_emb_scale_base", None
+        )
+        rotary_kwargs["rotary_emb_interleaved"] = getattr(
+            config, "rotary_emb_interleaved", False
+        )
     mixer_cls = partial(
         MHA,
         num_heads=config.num_attention_heads,
@@ -183,7 +190,9 @@ class XLMRobertaEncoder(nn.Module):
         """
         if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = (
-                {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
+                {"key_padding_mask": key_padding_mask.bool()}
+                if key_padding_mask is not None
+                else None
             )
             for layer in self.layers:
                 if self._grad_checkpointing:
@@ -191,7 +200,7 @@ class XLMRobertaEncoder(nn.Module):
                         layer,
                         hidden_states,
                         use_reentrant=False,
-                        mixer_kwargs=mixer_kwargs
+                        mixer_kwargs=mixer_kwargs,
                     )
                 else:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -210,7 +219,7 @@ class XLMRobertaEncoder(nn.Module):
                        layer,
                        hidden_states,
                        use_reentrant=False,
-                        mixer_kwargs=mixer_kwargs
+                        mixer_kwargs=mixer_kwargs,
                    )
                else:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -222,7 +231,7 @@ class XLMRobertaEncoder(nn.Module):
                        layer,
                        hidden_states,
                        use_reentrant=False,
-                        mixer_kwargs=mixer_kwargs
+                        mixer_kwargs=mixer_kwargs,
                    )
                else:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -230,15 +239,19 @@ class XLMRobertaEncoder(nn.Module):
                 subset_idx = torch.nonzero(
                     subset_mask[key_padding_mask], as_tuple=False
                 ).flatten()
-                subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
+                subset_seqlens = (subset_mask & key_padding_mask).sum(
+                    dim=-1, dtype=torch.int32
+                )
                 subset_cu_seqlens = F.pad(
-                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
+                    (1, 0),
                 )
             else:
                 subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                 subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                 subset_cu_seqlens = F.pad(
-                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
+                    (1, 0),
                 )
             hidden_states_subset, hidden_states = index_first_axis_residual(
                 hidden_states, subset_idx
@@ -256,10 +269,12 @@ class XLMRobertaEncoder(nn.Module):
                     self.layers[-1],
                     hidden_states_subset,
                     use_reentrant=False,
-                    mixer_kwargs=mixer_kwargs
+                    mixer_kwargs=mixer_kwargs,
                 )
             else:
-                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                hidden_states = self.layers[-1](
+                    hidden_states_subset, mixer_kwargs=mixer_kwargs
+                )
         return hidden_states
 
 
@@ -308,7 +323,10 @@ class XLMRobertaPredictionHeadTransform(nn.Module):
             hidden_states = self.layer_norm(hidden_states)
         else:
             hidden_states = layer_norm_fn(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
+                hidden_states,
+                self.layer_norm.weight,
+                self.layer_norm.bias,
+                eps=self.layer_norm.eps,
             )
         return hidden_states
 
@@ -349,6 +367,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
     """An abstract class to handle weights initialization and
     a simple interface for dowloading and loading pretrained models.
     """
+
     config_class = XLMRobertaFlashConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
@@ -358,7 +377,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
         module.gradient_checkpointing = value
 
 
-
 class XLMRobertaModel(XLMRobertaPreTrainedModel):
     def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
         super().__init__(config)
@@ -370,7 +388,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln and layer_norm_fn is None:
             raise ImportError("Triton is not installed")
-        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+        assert config.hidden_act in [
+            "gelu",
+            "gelu_new",
+            "gelu_fast",
+            "gelu_pytorch_tanh",
+        ]
 
         self.embeddings = XLMRobertaEmbeddings(
             config.hidden_size,
@@ -386,7 +409,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
 
-
     def forward(
         self,
         input_ids,
@@ -406,9 +428,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         if kwargs:
             for key, value in kwargs.items():
                 if value is not None:
-                    logger.warning('Flash attention implementation does not support kwargs: %s', key)
+                    logger.warning(
+                        'Flash attention implementation does not support kwargs: %s',
+                        key,
+                    )
 
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
@@ -439,17 +466,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
         if masked_tokens_mask is None:
-            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+            pooled_output = (
+                self.pooler(sequence_output) if self.pooler is not None else None
+            )
         else:
             # TD [2022-03-01]: the indexing here is very tricky.
             if attention_mask is not None:
                 subset_idx = subset_mask[attention_mask]
                 pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
-                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
+                sequence_output = sequence_output[
+                    masked_tokens_mask[attention_mask][subset_idx]
+                ]
             else:
                 pool_input = sequence_output[first_col_mask[subset_mask]]
                 sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
-            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
+            pooled_output = (
+                self.pooler(pool_input, pool=False) if self.pooler is not None else None
+            )
 
         if not return_dict:
             return sequence_output, pooled_output
@@ -487,7 +520,6 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
 
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -511,7 +543,9 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
         kwargs (`Dict[str, any]`, optional, defaults to *{}*):
             Used to hide legacy arguments that have been deprecated.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         outputs = self.roberta(
             input_ids,
@@ -534,11 +568,15 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
             # move labels to correct device to enable model parallelism
             labels = labels.to(prediction_scores.device)
             loss_fct = CrossEntropyLoss()
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]
-            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            return (
+                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            )
 
         return MaskedLMOutput(
             loss=masked_lm_loss,
@@ -656,7 +694,9 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
         key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
         return key
 
-    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
+    state_dict = OrderedDict(
+        (key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()
+    )
 
     # Layers
     def key_mapping_layers(key):
@@ -715,12 +755,18 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
             state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                 [Wq, Wk, Wv], dim=0
             )
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
+                [bq, bk, bv], dim=0
+            )
         else:
             state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
+                [Wk, Wv], dim=0
+            )
             state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
+                [bk, bv], dim=0
+            )
 
     def key_mapping_attn(key):
         return re.sub(
@@ -734,7 +780,9 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     def key_mapping_decoder_bias(key):
         return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
 
-    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
+    state_dict = OrderedDict(
+        (key_mapping_decoder_bias(k), v) for k, v in state_dict.items()
+    )
 
     # Word embedding
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
@@ -774,51 +822,59 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
     state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
         : config.orig_vocab_size, :
     ]
-    state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
-    state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
+    state_dict["cls.predictions.decoder.weight"] = decoder_weight[
+        : config.orig_vocab_size, :
+    ]
+    state_dict["cls.predictions.decoder.bias"] = decoder_bias[
+        : config.orig_vocab_size
+    ]
 
     for d in range(config.num_hidden_layers):
         last_layer_subset = getattr(config, "last_layer_subset", False)
         if not last_layer_subset or d != (config.num_hidden_layers - 1):
             Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
             Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
-                : Wqkv_weights.shape[0] // 3, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.weight"
+            ] = Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.weight"
+            ] = Wqkv_weights[
                 Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
             ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
-                2 * Wqkv_weights.shape[0] // 3 :, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
-                : Wqkv_biases.shape[0] // 3
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
-                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
-                2 * Wqkv_biases.shape[0] // 3 :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.weight"
+            ] = Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.bias"
+            ] = Wqkv_biases[: Wqkv_biases.shape[0] // 3]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.bias"
+            ] = Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.bias"
+            ] = Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
         else:
             Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
             Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
             Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
             Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
-                : Wkv_weights.shape[0] // 2, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
-                Wkv_weights.shape[0] // 2 :, :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.weight"
+            ] = Wq_weight
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.weight"
+            ] = Wkv_weights[: Wkv_weights.shape[0] // 2, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.weight"
+            ] = Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
             state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
             state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                 : Wkv_biases.shape[0] // 2
            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
-                Wkv_biases.shape[0] // 2 :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.bias"
+            ] = Wkv_biases[Wkv_biases.shape[0] // 2 :]
 
     def inv_key_mapping_ln(key):
         key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
@@ -870,14 +926,18 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
     def inv_key_mapping_decoder_bias(key):
         return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
 
-    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
+    state_dict = OrderedDict(
+        (inv_key_mapping_ln(key), value) for key, value in state_dict.items()
+    )
     state_dict = OrderedDict(
         (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
     )
     state_dict = OrderedDict(
         (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
     )
-    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
+    state_dict = OrderedDict(
+        (inv_key_mapping_mlp(key), value) for key, value in state_dict.items()
+    )
     state_dict = OrderedDict(
         (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
     )
@@ -885,4 +945,4 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
         (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
     )
 
-    return state_dict
+    return state_dict
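A small standalone illustration (made-up lengths, not repo code) of the cu_seqlens construction used in XLMRobertaEncoder above: F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) turns per-sequence token counts into cumulative offsets with a leading zero, which is the layout that variable-length attention kernels consume.

import torch
import torch.nn.functional as F

# Number of kept tokens per sequence (illustrative values).
subset_seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Prepend a zero to the running sum: offsets into the packed token buffer.
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(subset_cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)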
stochastic_depth.py ADDED

@@ -0,0 +1,97 @@
+# Implementation modified from torchvision:
+# https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
+#
+# License:
+# BSD 3-Clause License
+#
+# Copyright (c) Soumith Chintala 2016,
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+import torch.fx
+from torch import nn, Tensor
+
+
+def stochastic_depth(
+    input: Tensor, p: float, mode: str, training: bool = True
+) -> Tensor:
+    """
+    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
+    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
+    branches of residual architectures.
+
+    Args:
+        input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
+            being its batch i.e. a batch with ``N`` rows.
+        p (float): probability of the input to be zeroed.
+        mode (str): ``"batch"`` or ``"row"``.
+            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
+            randomly selected rows from the batch.
+        training: apply stochastic depth if is ``True``. Default: ``True``
+
+    Returns:
+        Tensor[N, ...]: The randomly zeroed tensor.
+    """
+    if p < 0.0 or p > 1.0:
+        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
+    if mode not in ["batch", "row"]:
+        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
+    if not training or p == 0.0:
+        return input
+
+    survival_rate = 1.0 - p
+    if mode == "row":
+        size = [input.shape[0]] + [1] * (input.ndim - 1)
+    else:
+        size = [1] * input.ndim
+    noise = torch.empty(size, dtype=input.dtype, device=input.device)
+    noise = noise.bernoulli_(survival_rate)
+    if survival_rate > 0.0:
+        noise.div_(survival_rate)
+    return input * noise
+
+
+torch.fx.wrap("stochastic_depth")
+
+
+class StochasticDepth(nn.Module):
+    """
+    See :func:`stochastic_depth`.
+    """
+
+    def __init__(self, p: float, mode: str) -> None:
+        super().__init__()
+        self.p = p
+        self.mode = mode
+
+    def forward(self, input: Tensor) -> Tensor:
+        return stochastic_depth(input, self.p, self.mode, self.training)
+
+    def __repr__(self) -> str:
+        s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
+        return s
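A quick usage sketch for the vendored module above (assumes stochastic_depth.py is on the import path; shapes and probabilities are illustrative):

import torch

from stochastic_depth import StochasticDepth

drop_path = StochasticDepth(p=0.1, mode="row")
x = torch.randn(4, 128, 768)  # (batch, seq, hidden)

drop_path.train()
y = drop_path(x)  # each batch row is zeroed with probability 0.1; survivors are scaled by 1 / 0.9

drop_path.eval()
assert torch.equal(drop_path(x), x)  # identity at inference time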