NohTow committed on
Commit 2571cc4
1 Parent(s): 63f779e

Adding base modeling

__init__.py ADDED
@@ -0,0 +1,7 @@
+
+ import os
+ import sys
+
+ # Add src folder root to path to allow us to use relative imports regardless of what directory the script is run from
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+ from modeling_flexbert import FlexBertModel
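For context, a minimal usage sketch of what this `__init__.py` enables: appending the package directory to `sys.path` lets the flat module files be imported directly, and the re-exported `FlexBertModel` can then be built from the `FlexBertConfig` defined in `configuration_bert.py`. The constructor call and config values below are illustrative assumptions, not part of the commit.

```python
# Hypothetical usage sketch (not part of this commit). Assumes FlexBertModel follows
# the usual Hugging Face pattern of being constructed directly from its config object.
from configuration_bert import FlexBertConfig
from modeling_flexbert import FlexBertModel

config = FlexBertConfig(hidden_size=768, num_attention_heads=12)  # illustrative values
model = FlexBertModel(config)
print(sum(p.numel() for p in model.parameters()))  # rough parameter-count sanity check
```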
__pycache__/__init__.cpython-311.pyc ADDED
Binary file (517 Bytes)

__pycache__/activation.cpython-311.pyc ADDED
Binary file (3.3 kB)

__pycache__/attention.cpython-311.pyc ADDED
Binary file (65.2 kB)

__pycache__/bert_padding.cpython-311.pyc ADDED
Binary file (7.68 kB)

__pycache__/configuration_bert.cpython-311.pyc ADDED
Binary file (15.5 kB)

__pycache__/embeddings.cpython-311.pyc ADDED
Binary file (14 kB)

__pycache__/initialization.cpython-311.pyc ADDED
Binary file (24.8 kB)

__pycache__/layers.cpython-311.pyc ADDED
Binary file (39.6 kB)

__pycache__/mlp.cpython-311.pyc ADDED
Binary file (13.8 kB)

__pycache__/modeling_flexbert.cpython-311.pyc ADDED
Binary file (77.6 kB)

__pycache__/normalization.cpython-311.pyc ADDED
Binary file (5.52 kB)

__pycache__/padding.cpython-311.pyc ADDED
Binary file (4.45 kB)

__pycache__/rotary.cpython-311.pyc ADDED
Binary file (12.4 kB)

__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.89 kB)

activation.py ADDED
@@ -0,0 +1,60 @@
+ # Copyright 2024 **AUTHORS_TODO**
+ # License: Apache-2.0
+
+ # Copyright 2020 The HuggingFace Team.
+ # License: Apache-2.0
+
+ from collections import OrderedDict
+ from typing import Union
+ import torch.nn as nn
+ from configuration_bert import FlexBertConfig
+
+
+ class ClassInstantier(OrderedDict):
+     def __getitem__(self, key):
+         content = super().__getitem__(key)
+         cls, kwargs = content if isinstance(content, tuple) else (content, {})
+         return cls(**kwargs)
+
+
+ ACT2CLS = {
+     "celu": nn.CELU,
+     "elu": nn.ELU,
+     "gelu": nn.GELU,
+     "gelu_tanh": (nn.GELU, {"approximate": "tanh"}),
+     "hardtanh": nn.Hardtanh,
+     "hardsigmoid": nn.Hardsigmoid,
+     "hardshrink": nn.Hardshrink,
+     "hardswish": nn.Hardswish,
+     "leaky_relu": nn.LeakyReLU,
+     "logsigmoid": nn.LogSigmoid,
+     "mish": nn.Mish,
+     "prelu": nn.PReLU,
+     "relu": nn.ReLU,
+     "relu6": nn.ReLU6,
+     "rrelu": nn.RReLU,
+     "selu": nn.SELU,
+     "sigmoid": nn.Sigmoid,
+     "silu": nn.SiLU,
+     "softmin": nn.Softmin,
+     "softplus": nn.Softplus,
+     "softshrink": nn.Softshrink,
+     "softsign": nn.Softsign,
+     "swish": nn.SiLU,
+     "tanh": nn.Tanh,
+     "tanhshrink": nn.Tanhshrink,
+     "threshold": nn.Threshold,
+ }
+ ACT2FN = ClassInstantier(ACT2CLS)
+
+
+ def get_act_fn(config: Union[FlexBertConfig, str]) -> nn.Module:
+     try:
+         if isinstance(config, str):
+             return ACT2FN[config]
+         return ACT2FN[config.hidden_act]
+     except KeyError:
+         if isinstance(config, str):
+             raise ValueError(f"Invalid activation function type: {config}, must be one of {ACT2FN.keys()}.")
+         else:
+             raise ValueError(f"Invalid activation function type: {config.hidden_act=}, must be one of {ACT2FN.keys()}.")
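A short sketch of how the `ACT2FN` registry and `get_act_fn` above are typically used (assuming the path setup from `__init__.py`, so `activation` is importable directly). Note that `ClassInstantier.__getitem__` instantiates the activation class on every lookup, so each access returns a fresh module.

```python
import torch
from activation import ACT2FN, get_act_fn

act = get_act_fn("gelu_tanh")   # nn.GELU(approximate="tanh"), instantiated on lookup
x = torch.randn(2, 4)
print(act(x).shape)             # torch.Size([2, 4])

silu = ACT2FN["silu"]           # plain entries are instantiated with no kwargs
print(type(silu).__name__)      # SiLU

# An unrecognized name raises a ValueError listing the supported keys:
# get_act_fn("not_an_activation")
```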
attention.py ADDED
@@ -0,0 +1,1563 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2022 MosaicML Examples authors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # Copyright 2023 MosaicML Examples authors
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
11
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import warnings
19
+ from typing import Optional
20
+ import importlib.metadata
21
+ import logging
22
+ import math
23
+
24
+ import bert_padding
25
+ from configuration_bert import FlexBertConfig, maybe_add_padding
26
+ from normalization import get_norm_layer
27
+ from initialization import ModuleType, init_weights
28
+ import utils # noqa: F401
29
+
30
+ IMPL_USE_FLASH3 = False
31
+ IMPL_USE_FLASH2 = False
32
+ try:
33
+ from flash_attn_interface import flash_attn_varlen_func
34
+
35
+ IMPL_USE_FLASH3 = True
36
+ except ImportError:
37
+ pass
38
+ # Import Flash Attention 2, which supports ALiBi https://github.com/Dao-AILab/flash-attention
39
+ try:
40
+ from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func # type: ignore
41
+
42
+ installed_version = importlib.metadata.version("flash_attn") # type: ignore
43
+ if installed_version < "2.5.7":
44
+ raise ImportError("newer version of flash_attn required (>= 2.5.7)")
45
+ IMPL_USE_FLASH2 = True
46
+ except ImportError:
47
+ pass
48
+
49
+ try:
50
+ from flash_attn.layers.rotary import RotaryEmbedding # type: ignore
51
+ from rotary import UnpaddedRotaryEmbedding # type: ignore
52
+
53
+ except ImportError:
54
+ RotaryEmbedding = None
55
+ UnpaddedRotaryEmbedding = None
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+
60
+ class BertAlibiUnpadSelfAttention(nn.Module):
61
+ """Performs multi-headed self attention on a batch of unpadded sequences.
62
+
63
+ If Flash Attention 2 is installed, this module uses Flash Attention to greatly improve throughput.
64
+ The Flash Attention implementation used in MosaicBERT supports arbitrary attention biases (which
65
+ we use to implement ALiBi). If Flash Attention 2 is not installed, the implementation will
66
+ default to a math-equivalent PyTorch version, which is much slower.
67
+
68
+ See `forward` method for additional details.
69
+ """
70
+
71
+ def __init__(self, config):
72
+ super().__init__()
73
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
74
+ raise ValueError(
75
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
76
+ f"heads ({config.num_attention_heads})"
77
+ )
78
+
79
+ self.num_attention_heads = config.num_attention_heads
80
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
81
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
82
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
83
+ self.p_dropout = config.attention_probs_dropout_prob
84
+ self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
85
+ self.deterministic_fa2 = getattr(config, "deterministic_fa2", False)
86
+
87
+ # Warn if defaulting to pytorch because of import issues
88
+ if not IMPL_USE_FLASH2:
89
+ warnings.warn(
90
+ "Unable to import flash_attn; defaulting MosaicBERT attention implementation to "
91
+ "vanilla PyTorch (this will reduce throughput when using this model)."
92
+ )
93
+
94
+ def forward(
95
+ self,
96
+ hidden_states: torch.Tensor,
97
+ cu_seqlens: torch.Tensor,
98
+ max_seqlen: int,
99
+ indices: torch.Tensor,
100
+ attn_mask: torch.Tensor,
101
+ bias: torch.Tensor,
102
+ slopes: torch.Tensor,
103
+ ) -> torch.Tensor:
104
+ """Perform self-attention.
105
+
106
+ There are two attention implementations: vanilla attention with ALiBi, and Flash Attention 2 with ALiBi
107
+
108
+ The arguments are unpadded. The vanilla implementation of attention requires padded arguments while the
109
+ Flash Attention implementation does not. If using vanilla we first call `pad_input`. Once we compute
110
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
111
+ sending pad tokens through the feed-forward layers saves compute.
112
+
113
+ Args:
114
+ hidden_states: (total_nnz, dim)
115
+ cu_seqlens: (batch + 1,)
116
+ max_seqlen: int
117
+ indices: (total_nnz,)
118
+ attn_mask: (batch, max_seqlen)
119
+ bias: (batch, heads, max_seqlen, max_seqlen)
120
+ slopes: (heads) or (batch, heads)
121
+
122
+ Returns:
123
+ attention: (total_nnz, dim)
124
+ """
125
+ bs, dim = hidden_states.shape
126
+ qkv = self.Wqkv(hidden_states)
127
+
128
+ # Option 1: Flash Attention with ALiBi
129
+ if IMPL_USE_FLASH2:
130
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attention_head_size)
131
+ assert 1 <= len(slopes.shape) <= 2, f"{slopes=}"
132
+ assert slopes.shape[-1] == self.num_attention_heads, f"{slopes=}"
133
+
134
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
135
+ if convert_dtype:
136
+ # FA2 implementation only supports fp16 and bf16
137
+ # If FA2 is supported, bfloat16 must be supported
138
+ # as of FA2 2.4.2. (Turing GPUs not supported)
139
+ orig_dtype = qkv.dtype
140
+ qkv = qkv.to(torch.bfloat16)
141
+
142
+ attention = flash_attn_varlen_qkvpacked_func(
143
+ qkv,
144
+ cu_seqlens=cu_seqlens,
145
+ max_seqlen=max_seqlen,
146
+ dropout_p=self.p_dropout,
147
+ deterministic=self.deterministic_fa2,
148
+ alibi_slopes=slopes,
149
+ )
150
+ attention = attention.to(orig_dtype) # type: ignore
151
+ else:
152
+ attention = flash_attn_varlen_qkvpacked_func(
153
+ qkv,
154
+ cu_seqlens=cu_seqlens,
155
+ max_seqlen=max_seqlen,
156
+ dropout_p=self.p_dropout,
157
+ deterministic=self.deterministic_fa2,
158
+ alibi_slopes=slopes,
159
+ )
160
+ else:
161
+ qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
162
+ unpad_bs, *_ = qkv.shape
163
+ qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
164
+ # Fall back to the math-equivalent PyTorch attention implementation (used when Flash Attention 2 is unavailable)
165
+ q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
166
+ k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
167
+ v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
168
+ attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
169
+ attention_scores = attention_scores + bias
170
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
171
+ attention_probs = self.dropout(attention_probs)
172
+ attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3) # b s h d
173
+
174
+ attention = bert_padding.unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
175
+
176
+ return attention.view(bs, dim)
177
+
178
+
179
+ # Copy of the transformers library's BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
180
+ class BertSelfOutput(nn.Module):
181
+ """Computes the output of the attention layer.
182
+
183
+ This module is modeled after the Hugging Face BERT's
184
+ :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`.
185
+ The implementation is identical. Rather than use the original module
186
+ directly, we re-implement it here so that Mosaic BERT's modules will not
187
+ be affected by any Composer surgery algorithm that modifies Hugging Face
188
+ BERT modules.
189
+ """
190
+
191
+ def __init__(self, config):
192
+ super().__init__()
193
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
194
+ self.LayerNorm = get_norm_layer(config)
195
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
196
+
197
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
198
+ hidden_states = self.dense(hidden_states)
199
+ hidden_states = self.dropout(hidden_states)
200
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
201
+ return hidden_states
202
+
203
+
204
+ class BertAlibiUnpadAttention(nn.Module):
205
+ """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
206
+
207
+ def __init__(self, config):
208
+ super().__init__()
209
+ self.self = BertAlibiUnpadSelfAttention(config)
210
+ self.output = BertSelfOutput(config)
211
+
212
+ def forward(
213
+ self,
214
+ input_tensor: torch.Tensor,
215
+ cu_seqlens: torch.Tensor,
216
+ max_s: int,
217
+ subset_idx: Optional[torch.Tensor] = None,
218
+ indices: Optional[torch.Tensor] = None,
219
+ attn_mask: Optional[torch.Tensor] = None,
220
+ bias: Optional[torch.Tensor] = None,
221
+ slopes: Optional[torch.Tensor] = None,
222
+ ) -> torch.Tensor:
223
+ """Forward pass for scaled self-attention without padding.
224
+
225
+ Arguments:
226
+ input_tensor: (total_nnz, dim)
227
+ cu_seqlens: (batch + 1,)
228
+ max_s: int
229
+ subset_idx: () set of indices whose values we care about at the end of the layer
230
+ (e.g., the masked tokens, if this is the final layer).
231
+ indices: None or (total_nnz,)
232
+ attn_mask: None or (batch, max_seqlen)
233
+ bias: None or (batch, heads, max_seqlen, max_seqlen)
234
+ slopes: None or (batch, heads) or (heads,)
235
+ """
236
+ assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
237
+ self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
238
+ if subset_idx is not None:
239
+ return self.output(
240
+ bert_padding.index_first_axis(self_output, subset_idx),
241
+ bert_padding.index_first_axis(input_tensor, subset_idx),
242
+ )
243
+ else:
244
+ return self.output(self_output, input_tensor)
245
+
246
+
247
+ class FlexBertAttentionBase(nn.Module):
248
+ """A FlexBERT attention base class for type hints."""
249
+
250
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
251
+ super().__init__()
252
+ self.config = config
253
+ self.layer_id = layer_id
254
+
255
+ def _init_weights(self, reset_params: bool = False):
256
+ raise NotImplementedError("This is a base class and should not be used directly.")
257
+
258
+ def forward(self, hidden_states: torch.Tensor, attn_mask: torch.Tensor, **kwargs) -> torch.Tensor:
259
+ raise NotImplementedError("This is a base class and should not be used directly.")
260
+
261
+ def extra_repr(self) -> str:
262
+ repr = ""
263
+ if hasattr(self, "num_attention_heads"):
264
+ repr += f"num_attention_heads={self.num_attention_heads}"
265
+ if hasattr(self, "attn_head_size"):
266
+ repr += f", attn_head_size={self.attn_head_size}"
267
+ if hasattr(self, "sliding_window"):
268
+ repr += f", sliding_window={self.sliding_window if self.sliding_window != (-1, -1) else 'False'}"
269
+ if hasattr(self, "use_fa2"):
270
+ repr += f", use_fa2={self.use_fa2}"
271
+ if hasattr(self, "deterministic_fa2"):
272
+ repr += f", deterministic_fa2={self.deterministic_fa2}"
273
+ return repr
274
+
275
+
276
+ class FlexBertUnpadAttention(FlexBertAttentionBase):
277
+ """Performs multi-headed self attention on a batch of unpadded sequences.
278
+
279
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
280
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
281
+ which requires padding and unpadding inputs, adding some overhead.
282
+
283
+ See `forward` method for additional detail.
284
+ """
285
+
286
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
287
+ super().__init__(config=config, layer_id=layer_id)
288
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
289
+ raise ValueError(
290
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
291
+ f"heads ({config.num_attention_heads})"
292
+ )
293
+
294
+ self.num_attention_heads = config.num_attention_heads
295
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
296
+ self.all_head_size = self.num_attention_heads * self.attn_head_size
297
+ self.p_dropout = config.attention_probs_dropout_prob
298
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
299
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
300
+ self.out_drop = (
301
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
302
+ )
303
+ self.use_fa2 = config.use_fa2
304
+ self.deterministic_fa2 = config.deterministic_fa2
305
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
306
+
307
+ if config.global_attn_every_n_layers > 0:
308
+ if config.sliding_window == -1:
309
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
310
+ if layer_id % config.global_attn_every_n_layers != 0:
311
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
312
+ else:
313
+ self.sliding_window = (-1, -1)
314
+ else:
315
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
316
+
317
+ # Warn if defaulting to pytorch because of import issues
318
+ if not IMPL_USE_FLASH2 and self.use_fa2:
319
+ logger.warn_once(
320
+ "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
321
+ " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
322
+ )
323
+ self.use_fa2 = False
324
+ if not self.use_fa2:
325
+ if not self.use_sdpa_attn_mask:
326
+ logger.warn_once(
327
+ "SDPA attention is being used without an attention mask. Including padding in the "
328
+ " attention calculation may cause differences from the Flash Attention implementation."
329
+ )
330
+ else:
331
+ logger.warn_once(
332
+ "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
333
+ " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
334
+ " with sequence length."
335
+ )
336
+ if self.sliding_window[0] > 0:
337
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
338
+
339
+ def _init_weights(self, reset_params: bool = False):
340
+ init_weights(
341
+ self.config,
342
+ self.Wqkv,
343
+ layer_dim=self.config.hidden_size,
344
+ layer_id=None,
345
+ type_of_module=ModuleType.in_module,
346
+ )
347
+ init_weights(
348
+ self.config,
349
+ self.Wo,
350
+ layer_dim=self.config.hidden_size,
351
+ layer_id=self.layer_id,
352
+ type_of_module=ModuleType.out_module,
353
+ )
354
+
355
+ def forward(
356
+ self,
357
+ hidden_states: torch.Tensor,
358
+ cu_seqlens: torch.Tensor,
359
+ max_seqlen: int,
360
+ indices: torch.Tensor,
361
+ attn_mask: torch.Tensor,
362
+ ) -> torch.Tensor:
363
+ """Perform self-attention.
364
+
365
+ There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
366
+
367
+ The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
368
+ Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
369
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
370
+ sending pad tokens through the feed-forward layers saves compute.
371
+
372
+ Args:
373
+ hidden_states: (total_nnz, dim)
374
+ cu_seqlens: (batch + 1,)
375
+ max_seqlen: int
376
+ indices: (total_nnz,)
377
+ attn_mask: (batch, max_seqlen)
378
+
379
+ Returns:
380
+ attention: (total_nnz, dim)
381
+ """
382
+ bs, dim = hidden_states.shape
383
+ qkv = self.Wqkv(hidden_states)
384
+
385
+ if self.use_fa2:
386
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
387
+
388
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
389
+ if convert_dtype:
390
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
391
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
392
+ orig_dtype = qkv.dtype
393
+ qkv = qkv.to(torch.bfloat16)
394
+
395
+ attn = flash_attn_varlen_qkvpacked_func(
396
+ qkv,
397
+ cu_seqlens=cu_seqlens,
398
+ max_seqlen=max_seqlen,
399
+ dropout_p=self.p_dropout,
400
+ deterministic=self.deterministic_fa2,
401
+ window_size=self.sliding_window,
402
+ )
403
+ attn = attn.to(orig_dtype) # type: ignore
404
+ else:
405
+ attn = flash_attn_varlen_qkvpacked_func(
406
+ qkv,
407
+ cu_seqlens=cu_seqlens,
408
+ max_seqlen=max_seqlen,
409
+ dropout_p=self.p_dropout,
410
+ deterministic=self.deterministic_fa2,
411
+ window_size=self.sliding_window,
412
+ )
413
+ attn = attn.view(bs, dim)
414
+ else:
415
+ qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
416
+ unpad_bs, seqlen, _ = qkv.shape
417
+
418
+ qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)
419
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
420
+ attn = F.scaled_dot_product_attention(
421
+ q,
422
+ k,
423
+ v,
424
+ dropout_p=self.p_dropout,
425
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
426
+ if self.use_sdpa_attn_mask
427
+ else None,
428
+ )
429
+ attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
430
+ attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
431
+
432
+ return self.out_drop(self.Wo(attn))
433
+
434
+
435
+ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
436
+ """Computes the output of multi-headed parallel self attention on a batch of unpadded sequences.
437
+
438
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
439
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
440
+ which requires padding and unpadding inputs, adding some overhead.
441
+
442
+ See `forward` method for additional detail.
443
+ """
444
+
445
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
446
+ super().__init__(config=config, layer_id=layer_id)
447
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
448
+ raise ValueError(
449
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
450
+ f"heads ({config.num_attention_heads})"
451
+ )
452
+
453
+ self.num_attention_heads = config.num_attention_heads
454
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
455
+ self.hidden_size = config.hidden_size
456
+ self.p_dropout = config.attention_probs_dropout_prob
457
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
458
+ self.out_drop = (
459
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
460
+ )
461
+ self.use_fa2 = config.use_fa2
462
+ self.deterministic_fa2 = config.deterministic_fa2
463
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
464
+
465
+ if config.global_attn_every_n_layers > 0:
466
+ if config.sliding_window == -1:
467
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
468
+ if layer_id % config.global_attn_every_n_layers != 0:
469
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
470
+ else:
471
+ self.sliding_window = (-1, -1)
472
+ else:
473
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
474
+
475
+ # Warn if defaulting to pytorch because of import issues
476
+ if not IMPL_USE_FLASH2 and self.use_fa2:
477
+ logger.warn_once(
478
+ "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
479
+ " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
480
+ )
481
+ self.use_fa2 = False
482
+ if not self.use_fa2:
483
+ if not self.use_sdpa_attn_mask:
484
+ logger.warn_once(
485
+ "SDPA attention is being used without an attention mask. Including padding in the "
486
+ " attention calculation may cause differences from the Flash Attention implementation."
487
+ )
488
+ else:
489
+ logger.warn_once(
490
+ "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
491
+ " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
492
+ " with sequence length."
493
+ )
494
+ if self.sliding_window[0] > 0:
495
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
496
+
497
+ def _init_weights(self, reset_params: bool = False):
498
+ init_weights(
499
+ self.config,
500
+ self.Wo,
501
+ layer_dim=self.config.hidden_size,
502
+ layer_id=self.layer_id,
503
+ type_of_module=ModuleType.out_module,
504
+ )
505
+
506
+ def forward(
507
+ self,
508
+ qkv: torch.Tensor,
509
+ cu_seqlens: torch.Tensor,
510
+ max_seqlen: int,
511
+ indices: torch.Tensor,
512
+ attn_mask: torch.Tensor,
513
+ ) -> torch.Tensor:
514
+ """Perform self-attention.
515
+
516
+ There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
517
+
518
+ The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
519
+ Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
520
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
521
+ sending pad tokens through the feed-forward layers saves compute.
522
+
523
+ Args:
524
+ qkv: (total_nnz, 3 * dim)
525
+ cu_seqlens: (batch + 1,)
526
+ max_seqlen: int
527
+ indices: (total_nnz,)
528
+ attn_mask: (batch, max_seqlen)
529
+
530
+ Returns:
531
+ attention: (total_nnz, dim)
532
+ """
533
+ bs = qkv.shape[0]
534
+ dim = self.hidden_size
535
+ if self.use_fa2:
536
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
537
+
538
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
539
+ if convert_dtype:
540
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
541
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
542
+ orig_dtype = qkv.dtype
543
+ qkv = qkv.to(torch.bfloat16)
544
+
545
+ attn = flash_attn_varlen_qkvpacked_func(
546
+ qkv,
547
+ cu_seqlens=cu_seqlens,
548
+ max_seqlen=max_seqlen,
549
+ dropout_p=self.p_dropout,
550
+ deterministic=self.deterministic_fa2,
551
+ window_size=self.sliding_window,
552
+ )
553
+ attn = attn.to(orig_dtype) # type: ignore
554
+ else:
555
+ attn = flash_attn_varlen_qkvpacked_func(
556
+ qkv,
557
+ cu_seqlens=cu_seqlens,
558
+ max_seqlen=max_seqlen,
559
+ dropout_p=self.p_dropout,
560
+ deterministic=self.deterministic_fa2,
561
+ window_size=self.sliding_window,
562
+ )
563
+ attn = attn.view(bs, dim)
564
+ else:
565
+ qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
566
+ unpad_bs, seqlen, _ = qkv.shape
567
+
568
+ qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)
569
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
570
+ attn = F.scaled_dot_product_attention(
571
+ q,
572
+ k,
573
+ v,
574
+ dropout_p=self.p_dropout,
575
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
576
+ if self.use_sdpa_attn_mask
577
+ else None,
578
+ )
579
+ attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
580
+ attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
581
+
582
+ return self.out_drop(self.Wo(attn.view(bs, dim)))
583
+
584
+
585
+ class FlexBertPaddedAttention(FlexBertAttentionBase):
586
+ """Performs multi-headed self attention on a batch of padded sequences.
587
+
588
+ This module supports two attention implementations:
589
+ 1. Flash Attention 2 (if installed), which improves throughput.
590
+ 2. PyTorch's scaled_dot_product_attention.
591
+
592
+ See `forward` method for additional detail.
593
+ """
594
+
595
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
596
+ super().__init__(config=config, layer_id=layer_id)
597
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
598
+ raise ValueError(
599
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
600
+ f"heads ({config.num_attention_heads})"
601
+ )
602
+
603
+ self.num_attention_heads = config.num_attention_heads
604
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
605
+ self.all_head_size = self.num_attention_heads * self.attn_head_size
606
+ self.p_dropout = config.attention_probs_dropout_prob
607
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
608
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
609
+ self.out_drop = (
610
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
611
+ )
612
+ self.use_fa2 = config.use_fa2
613
+ self.deterministic_fa2 = config.deterministic_fa2
614
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
615
+
616
+ if config.global_attn_every_n_layers > 0:
617
+ if config.sliding_window == -1:
618
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
619
+ if layer_id % config.global_attn_every_n_layers != 0:
620
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
621
+ else:
622
+ self.sliding_window = (-1, -1)
623
+ else:
624
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
625
+
626
+ if not IMPL_USE_FLASH2 and self.use_fa2:
627
+ self.use_fa2 = False
628
+ if self.use_fa2 and self.use_sdpa_attn_mask:
629
+ logger.warn_once(
630
+ "Flash Attention 2 does not support attention masks. Use unpadded attention "
631
+ "for the equivalent functionality of masking out padding tokens."
632
+ )
633
+ if not self.use_fa2 and self.sliding_window[0] > 0:
634
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
635
+
636
+ def _init_weights(self, reset_params: bool = False):
637
+ init_weights(
638
+ self.config,
639
+ self.Wqkv,
640
+ layer_dim=self.config.hidden_size,
641
+ layer_id=None,
642
+ type_of_module=ModuleType.in_module,
643
+ )
644
+ init_weights(
645
+ self.config,
646
+ self.Wo,
647
+ layer_dim=self.config.hidden_size,
648
+ layer_id=self.layer_id,
649
+ type_of_module=ModuleType.out_module,
650
+ )
651
+
652
+ def forward(
653
+ self,
654
+ hidden_states: torch.Tensor,
655
+ attn_mask: Optional[torch.Tensor] = None,
656
+ ) -> torch.Tensor:
657
+ """Perform self-attention.
658
+
659
+ There are two attention implementations supported:
660
+ Flash Attention 2 and PyTorch's scaled_dot_product_attention.
661
+
662
+ Args:
663
+ hidden_states: (batch, seqlen, dim)
664
+ attn_mask: (batch, seqlen)
665
+
666
+ Returns:
667
+ attention: (batch, seqlen, dim)
668
+ """
669
+ bs, seqlen, dim = hidden_states.shape
670
+ qkv = self.Wqkv(hidden_states)
671
+
672
+ if self.use_fa2:
673
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
674
+
675
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
676
+ if convert_dtype:
677
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
678
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
679
+ orig_dtype = qkv.dtype
680
+ qkv = qkv.to(torch.bfloat16)
681
+
682
+ attn = flash_attn_qkvpacked_func(
683
+ qkv,
684
+ dropout_p=self.p_dropout,
685
+ deterministic=self.deterministic_fa2,
686
+ window_size=self.sliding_window,
687
+ )
688
+ attn = attn.to(orig_dtype) # type: ignore
689
+ else:
690
+ attn = flash_attn_qkvpacked_func(
691
+ qkv,
692
+ dropout_p=self.p_dropout,
693
+ deterministic=self.deterministic_fa2,
694
+ window_size=self.sliding_window,
695
+ )
696
+ else:
697
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
698
+
699
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2)
700
+ attn = F.scaled_dot_product_attention(
701
+ q,
702
+ k,
703
+ v,
704
+ dropout_p=self.p_dropout,
705
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
706
+ if self.use_sdpa_attn_mask
707
+ else None,
708
+ ).transpose(1, 2)
709
+
710
+ attn = attn.view(bs, seqlen, dim)
711
+ return self.out_drop(self.Wo(attn))
712
+
713
+
714
+ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
715
+ """Performs multi-headed self attention on a batch of unpadded sequences.
716
+
717
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
718
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
719
+ which requires padding and unpadding inputs, adding some overhead.
720
+
721
+ See `forward` method for additional details.
722
+ """
723
+
724
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
725
+ super().__init__(config=config, layer_id=layer_id)
726
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
727
+ raise ValueError(
728
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
729
+ f"heads ({config.num_attention_heads})"
730
+ )
731
+
732
+ self.num_attention_heads = config.num_attention_heads
733
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
734
+ self.all_head_size = self.num_attention_heads * self.attn_head_size
735
+ self.p_dropout = config.attention_probs_dropout_prob
736
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
737
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
738
+ self.out_drop = (
739
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
740
+ )
741
+
742
+ if config.global_attn_every_n_layers > 0:
743
+ if config.sliding_window == -1:
744
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
745
+ if layer_id % config.global_attn_every_n_layers != 0:
746
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
747
+ else:
748
+ self.sliding_window = (-1, -1)
749
+ else:
750
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
751
+
752
+ if config.rotary_emb_dim is None:
753
+ config.rotary_emb_dim = self.attn_head_size
754
+
755
+ rotary_base = config.rotary_emb_base
756
+ rotary_dim = config.rotary_emb_dim
757
+ if self.sliding_window != (-1, -1):
758
+ if config.local_attn_rotary_emb_base != -1:
759
+ rotary_base = config.local_attn_rotary_emb_base
760
+ if config.local_attn_rotary_emb_dim is not None:
761
+ rotary_dim = config.local_attn_rotary_emb_dim
762
+
763
+ assert UnpaddedRotaryEmbedding is not None, "rotary_emb is not installed"
764
+ self.rotary_emb = UnpaddedRotaryEmbedding(
765
+ dim=rotary_dim,
766
+ base=rotary_base,
767
+ scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
768
+ interleaved=config.rotary_emb_interleaved,
769
+ )
770
+
771
+ self.use_fa2 = config.use_fa2
772
+ # flash attention 3 only supports global attention
773
+ self.use_fa3 = config.use_fa2 and self.sliding_window == (-1, -1) and IMPL_USE_FLASH3
774
+ self.deterministic_fa2 = config.deterministic_fa2
775
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
776
+
777
+ # Warn if defaulting to pytorch because of import issues
778
+ if not IMPL_USE_FLASH2 and self.use_fa2:
779
+ logger.warn_once(
780
+ "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
781
+ " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
782
+ )
783
+ self.use_fa2 = False
784
+ if not self.use_fa2:
785
+ if not self.use_sdpa_attn_mask:
786
+ logger.warn_once(
787
+ "SDPA attention is being used without an attention mask. Including padding in the "
788
+ " attention calculation may cause differences from the Flash Attention implementation."
789
+ )
790
+ else:
791
+ logger.warn_once(
792
+ "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
793
+ " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
794
+ " with sequence length."
795
+ )
796
+ if self.sliding_window[0] > 0:
797
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
798
+
799
+ def _init_weights(self, reset_params: bool = False):
800
+ init_weights(
801
+ self.config,
802
+ self.Wqkv,
803
+ layer_dim=self.config.hidden_size,
804
+ layer_id=None,
805
+ type_of_module=ModuleType.in_module,
806
+ )
807
+ init_weights(
808
+ self.config,
809
+ self.Wo,
810
+ layer_dim=self.config.hidden_size,
811
+ layer_id=self.layer_id,
812
+ type_of_module=ModuleType.out_module,
813
+ )
814
+
815
+ def forward(
816
+ self,
817
+ hidden_states: torch.Tensor,
818
+ cu_seqlens: torch.Tensor,
819
+ max_seqlen: int,
820
+ indices: torch.Tensor,
821
+ attn_mask: torch.Tensor,
822
+ ) -> torch.Tensor:
823
+ """Perform self-attention.
824
+
825
+ There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
826
+
827
+ The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
828
+ Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
829
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
830
+ sending pad tokens through the feed-forward layers saves compute.
831
+
832
+ Args:
833
+ hidden_states: (total_nnz, dim)
834
+ cu_seqlens: (batch + 1,)
835
+ max_seqlen: int
836
+ indices: (total_nnz,)
837
+ attn_mask: (batch, max_seqlen)
838
+
839
+ Returns:
840
+ attention: (total_nnz, dim)
841
+ """
842
+ bs, dim = hidden_states.shape
843
+ qkv = self.Wqkv(hidden_states)
844
+
845
+ # only needed for inference when we have KV cache
846
+ seqlen_offset = 0
847
+
848
+ # (total_seqlen, 3, nheads, headdim)
849
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
850
+ qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)
851
+
852
+ if self.use_fa3:
853
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
854
+ if convert_dtype:
855
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
856
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
857
+ orig_dtype = qkv.dtype
858
+ qkv = qkv.to(torch.bfloat16)
859
+ q, k, v = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size).unbind(dim=1)
860
+
861
+ attn, _ = flash_attn_varlen_func(
862
+ q=q,
863
+ k=k,
864
+ v=v,
865
+ cu_seqlens_q=cu_seqlens,
866
+ cu_seqlens_k=cu_seqlens,
867
+ max_seqlen_q=max_seqlen,
868
+ max_seqlen_k=max_seqlen,
869
+ deterministic=self.deterministic_fa2,
870
+ )
871
+ attn = attn.to(orig_dtype) # type: ignore
872
+ else:
873
+ q, k, v = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size).unbind(dim=1)
874
+ attn, _ = flash_attn_varlen_func(
875
+ q=q,
876
+ k=k,
877
+ v=v,
878
+ cu_seqlens_q=cu_seqlens,
879
+ cu_seqlens_k=cu_seqlens,
880
+ max_seqlen_q=max_seqlen,
881
+ max_seqlen_k=max_seqlen,
882
+ deterministic=self.deterministic_fa2,
883
+ )
884
+ attn = attn.view(bs, dim)
885
+ elif self.use_fa2:
886
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
887
+ if convert_dtype:
888
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
889
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
890
+ orig_dtype = qkv.dtype
891
+ qkv = qkv.to(torch.bfloat16)
892
+
893
+ attn = flash_attn_varlen_qkvpacked_func(
894
+ qkv,
895
+ cu_seqlens=cu_seqlens,
896
+ max_seqlen=max_seqlen,
897
+ dropout_p=self.p_dropout,
898
+ deterministic=self.deterministic_fa2,
899
+ window_size=self.sliding_window,
900
+ )
901
+ attn = attn.to(orig_dtype) # type: ignore
902
+ else:
903
+ attn = flash_attn_varlen_qkvpacked_func(
904
+ qkv,
905
+ cu_seqlens=cu_seqlens,
906
+ max_seqlen=max_seqlen,
907
+ dropout_p=self.p_dropout,
908
+ deterministic=self.deterministic_fa2,
909
+ window_size=self.sliding_window,
910
+ )
911
+ attn = attn.view(bs, dim)
912
+ else:
913
+ qkv = bert_padding.pad_input(
914
+ qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
915
+ ) # batch, max_seqlen, thd
916
+ unpad_bs, seqlen, *_ = qkv.shape
917
+
918
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
919
+ attn = F.scaled_dot_product_attention(
920
+ q,
921
+ k,
922
+ v,
923
+ dropout_p=self.p_dropout,
924
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
925
+ if self.use_sdpa_attn_mask
926
+ else None,
927
+ )
928
+ attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
929
+ attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
930
+
931
+ return self.out_drop(self.Wo(attn))
932
+
933
+
934
+ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
935
+ """Performs multi-headed self attention on a batch of padded sequences.
936
+
937
+ This module supports two attention implementations:
938
+ 1. Flash Attention 2 (if installed), which improves throughput.
939
+ 2. PyTorch's scaled_dot_product_attention.
940
+
941
+ See `forward` method for additional details.
942
+ """
943
+
944
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
945
+ super().__init__(config=config, layer_id=layer_id)
946
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
947
+ raise ValueError(
948
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
949
+ f"heads ({config.num_attention_heads})"
950
+ )
951
+
952
+ self.num_attention_heads = config.num_attention_heads
953
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
954
+ self.all_head_size = self.num_attention_heads * self.attn_head_size
955
+ self.p_dropout = config.attention_probs_dropout_prob
956
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
957
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
958
+ self.out_drop = (
959
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
960
+ )
961
+
962
+ self.use_fa2 = config.use_fa2
963
+ self.deterministic_fa2 = config.deterministic_fa2
964
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
965
+
966
+ if config.global_attn_every_n_layers > 0:
967
+ if config.sliding_window == -1:
968
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
969
+ if layer_id % config.global_attn_every_n_layers != 0:
970
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
971
+ else:
972
+ self.sliding_window = (-1, -1)
973
+ else:
974
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
975
+
976
+ if config.rotary_emb_dim is None:
977
+ config.rotary_emb_dim = self.attn_head_size
978
+
979
+ rotary_base = config.rotary_emb_base
980
+ rotary_dim = config.rotary_emb_dim
981
+ if self.sliding_window != (-1, -1):
982
+ if config.local_attn_rotary_emb_base != -1:
983
+ rotary_base = config.local_attn_rotary_emb_base
984
+ if config.local_attn_rotary_emb_dim is not None:
985
+ rotary_dim = config.local_attn_rotary_emb_dim
986
+
987
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
988
+ self.rotary_emb = RotaryEmbedding(
989
+ dim=rotary_dim,
990
+ base=rotary_base,
991
+ scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
992
+ interleaved=config.rotary_emb_interleaved,
993
+ )
994
+
995
+ if not IMPL_USE_FLASH2 and self.use_fa2:
996
+ self.use_fa2 = False
997
+ if self.use_fa2 and self.use_sdpa_attn_mask:
998
+ logger.warn_once(
999
+ "Flash Attention 2 does not support attention masks. Use unpadded attention "
1000
+ "for the equivalent functionality of masking out padding tokens."
1001
+ )
1002
+ if not self.use_fa2 and self.sliding_window[0] > 0:
1003
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
1004
+
1005
+ def _init_weights(self, reset_params: bool = False):
1006
+ init_weights(
1007
+ self.config,
1008
+ self.Wqkv,
1009
+ layer_dim=self.config.hidden_size,
1010
+ layer_id=None,
1011
+ type_of_module=ModuleType.in_module,
1012
+ )
1013
+ init_weights(
1014
+ self.config,
1015
+ self.Wo,
1016
+ layer_dim=self.config.hidden_size,
1017
+ layer_id=self.layer_id,
1018
+ type_of_module=ModuleType.out_module,
1019
+ )
1020
+
1021
+ def forward(
1022
+ self,
1023
+ hidden_states: torch.Tensor,
1024
+ attn_mask: Optional[torch.Tensor] = None,
1025
+ ) -> torch.Tensor:
1026
+ """Perform self-attention.
1027
+
1028
+ There are two attention implementations supported:
1029
+ Flash Attention 2 and PyTorch's scaled_dot_product_attention.
1030
+
1031
+ Args:
1032
+ hidden_states: (batch, seqlen, dim)
1033
+ attn_mask: (batch, seqlen)
1034
+
1035
+ Returns:
1036
+ attention: (batch, seqlen, dim)
1037
+ """
1038
+ bs, seqlen, dim = hidden_states.shape
1039
+ qkv = self.Wqkv(hidden_states)
1040
+
1041
+ seqlen_offset = 0
1042
+
1043
+ # Reshape to (batch, seqlen, 3, nheads, headdim)
1044
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1045
+
1046
+ if IMPL_USE_FLASH2:
1047
+ # Apply RoPE
1048
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1049
+
1050
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
1051
+ if convert_dtype:
1052
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
1053
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
1054
+ orig_dtype = qkv.dtype
1055
+ qkv = qkv.to(torch.bfloat16)
1056
+
1057
+ attn = flash_attn_qkvpacked_func(
1058
+ qkv,
1059
+ dropout_p=self.p_dropout,
1060
+ deterministic=self.deterministic_fa2,
1061
+ window_size=self.sliding_window,
1062
+ )
1063
+ attn = attn.to(orig_dtype) # type: ignore
1064
+ else:
1065
+ attn = flash_attn_qkvpacked_func(
1066
+ qkv,
1067
+ dropout_p=self.p_dropout,
1068
+ deterministic=self.deterministic_fa2,
1069
+ window_size=self.sliding_window,
1070
+ )
1071
+ else:
1072
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1073
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2)
1074
+ attn = F.scaled_dot_product_attention(
1075
+ q,
1076
+ k,
1077
+ v,
1078
+ dropout_p=self.p_dropout,
1079
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
1080
+ if self.use_sdpa_attn_mask
1081
+ else None,
1082
+ ).transpose(1, 2)
1083
+
1084
+ attn = attn.view(bs, seqlen, dim)
1085
+ return self.out_drop(self.Wo(attn))
1086
+
1087
+
1088
+ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
1089
+ """Performs multi-headed self attention on a batch of unpadded sequences.
1090
+
1091
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
1092
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
1093
+ which requires padding and unpadding inputs, adding some overhead.
1094
+
1095
+ See `forward` method for additional details.
1096
+ """
1097
+
1098
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
1099
+ super().__init__(config=config, layer_id=layer_id)
1100
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
1101
+ raise ValueError(
1102
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
1103
+ f"heads ({config.num_attention_heads})"
1104
+ )
1105
+
1106
+ self.num_attention_heads = config.num_attention_heads
1107
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1108
+ self.hidden_size = config.hidden_size
1109
+ self.p_dropout = config.attention_probs_dropout_prob
1110
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
1111
+ self.out_drop = (
1112
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
1113
+ )
1114
+
1115
+ if config.global_attn_every_n_layers > 0:
1116
+ if config.sliding_window == -1:
1117
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
1118
+ if layer_id % config.global_attn_every_n_layers != 0:
1119
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1120
+ else:
1121
+ self.sliding_window = (-1, -1)
1122
+ else:
1123
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1124
+
1125
+ if config.rotary_emb_dim is None:
1126
+ config.rotary_emb_dim = self.attn_head_size
1127
+
1128
+ rotary_base = config.rotary_emb_base
1129
+ rotary_dim = config.rotary_emb_dim
1130
+ if self.sliding_window != (-1, -1):
1131
+ if config.local_attn_rotary_emb_base != -1:
1132
+ rotary_base = config.local_attn_rotary_emb_base
1133
+ if config.local_attn_rotary_emb_dim is not None:
1134
+ rotary_dim = config.local_attn_rotary_emb_dim
1135
+
1136
+ assert UnpaddedRotaryEmbedding is not None, "rotary_emb is not installed"
1137
+ self.rotary_emb = UnpaddedRotaryEmbedding(
1138
+ dim=rotary_dim,
1139
+ base=rotary_base,
1140
+ scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
1141
+ interleaved=config.rotary_emb_interleaved,
1142
+ )
1143
+
1144
+ self.use_fa2 = config.use_fa2
1145
+ self.deterministic_fa2 = config.deterministic_fa2
1146
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
1147
+
1148
+ # Warn if defaulting to pytorch because of import issues
1149
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1150
+ logger.warn_once(
1151
+ "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
1152
+ " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
1153
+ )
1154
+ self.use_fa2 = False
1155
+ if not self.use_fa2:
1156
+ if not self.use_sdpa_attn_mask:
1157
+ logger.warn_once(
1158
+ "SDPA attention is being used without an attention mask. Including padding in the "
1159
+ " attention calculation may cause differences from the Flash Attention implementation."
1160
+ )
1161
+ else:
1162
+ logger.warn_once(
1163
+ "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
1164
+ " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
1165
+ " with sequence length."
1166
+ )
1167
+ if self.sliding_window[0] > 0:
1168
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
1169
+
1170
+ def _init_weights(self, reset_params: bool = False):
1171
+ init_weights(
1172
+ self.config,
1173
+ self.Wo,
1174
+ layer_dim=self.config.hidden_size,
1175
+ layer_id=self.layer_id,
1176
+ type_of_module=ModuleType.out_module,
1177
+ )
1178
+
1179
+ def forward(
1180
+ self,
1181
+ qkv: torch.Tensor,
1182
+ cu_seqlens: torch.Tensor,
1183
+ max_seqlen: int,
1184
+ indices: torch.Tensor,
1185
+ attn_mask: torch.Tensor,
1186
+ ) -> torch.Tensor:
1187
+ """Perform self-attention.
1188
+
1189
+ There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
1190
+
1191
+ The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
1192
+ Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
1193
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
1194
+ sending pad tokens through the feed-forward layers saves compute.
1195
+
1196
+ Args:
1197
+ qkv: (total_nnz, 3 * dim)
1198
+ cu_seqlens: (batch + 1,)
1199
+ max_seqlen: int
1200
+ indices: (total_nnz,)
1201
+ attn_mask: (batch, max_seqlen)
1202
+
1203
+ Returns:
1204
+ attention: (total_nnz, dim)
1205
+ """
1206
+ bs = qkv.shape[0]
1207
+ dim = self.hidden_size
1208
+
1209
+ # only needed for inference when we have KV cache
1210
+ seqlen_offset = 0
1211
+
1212
+ # (total_seqlen, 3, nheads, headdim)
1213
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
1214
+ qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)
1215
+
1216
+ if self.use_fa2:
1217
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
1218
+ if convert_dtype:
1219
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
1220
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
1221
+ orig_dtype = qkv.dtype
1222
+ qkv = qkv.to(torch.bfloat16)
1223
+
1224
+ attn = flash_attn_varlen_qkvpacked_func(
1225
+ qkv,
1226
+ cu_seqlens=cu_seqlens,
1227
+ max_seqlen=max_seqlen,
1228
+ dropout_p=self.p_dropout,
1229
+ deterministic=self.deterministic_fa2,
1230
+ window_size=self.sliding_window,
1231
+ )
1232
+ attn = attn.to(orig_dtype) # type: ignore
1233
+ else:
1234
+ attn = flash_attn_varlen_qkvpacked_func(
1235
+ qkv,
1236
+ cu_seqlens=cu_seqlens,
1237
+ max_seqlen=max_seqlen,
1238
+ dropout_p=self.p_dropout,
1239
+ deterministic=self.deterministic_fa2,
1240
+ window_size=self.sliding_window,
1241
+ )
1242
+ attn = attn.view(bs, dim)
1243
+ else:
1244
+ qkv = bert_padding.pad_input(
1245
+ qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
1246
+ ) # batch, max_seqlen, thd
1247
+ unpad_bs, seqlen, *_ = qkv.shape
1248
+
1249
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
1250
+ attn = F.scaled_dot_product_attention(
1251
+ q,
1252
+ k,
1253
+ v,
1254
+ dropout_p=self.p_dropout,
1255
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
1256
+ if self.use_sdpa_attn_mask
1257
+ else None,
1258
+ )
1259
+ attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
1260
+ attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
1261
+
1262
+ return self.out_drop(self.Wo(attn))
1263
+
1264
+
1265
+ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
1266
+ """Performs multi-headed self attention on a batch of padded sequences.
1267
+
1268
+ This module supports two attention implementations:
1269
+ 1. Flash Attention 2 (if installed), which improves throughput.
1270
+ 2. PyTorch's scaled_dot_product_attention.
1271
+
1272
+ See `forward` method for additional details.
1273
+ """
1274
+
1275
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
1276
+ super().__init__(config=config, layer_id=layer_id)
1277
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
1278
+ raise ValueError(
1279
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
1280
+ f"heads ({config.num_attention_heads})"
1281
+ )
1282
+
1283
+ self.num_attention_heads = config.num_attention_heads
1284
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1285
+ self.hidden_size = config.hidden_size
1286
+ self.p_dropout = config.attention_probs_dropout_prob
1287
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
1288
+ self.out_drop = (
1289
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
1290
+ )
1291
+
1292
+ self.use_fa2 = config.use_fa2
1293
+ self.deterministic_fa2 = config.deterministic_fa2
1294
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
1295
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1296
+ self.use_fa2 = False
1297
+
1298
+ if config.global_attn_every_n_layers > 0:
1299
+ if config.sliding_window == -1:
1300
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
1301
+ if layer_id % config.global_attn_every_n_layers != 0:
1302
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1303
+ else:
1304
+ self.sliding_window = (-1, -1)
1305
+ else:
1306
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1307
+
1308
+ if config.rotary_emb_dim is None:
1309
+ config.rotary_emb_dim = self.attn_head_size
1310
+
1311
+ rotary_base = config.rotary_emb_base
1312
+ rotary_dim = config.rotary_emb_dim
1313
+ if self.sliding_window != (-1, -1):
1314
+ if config.local_attn_rotary_emb_base != -1:
1315
+ rotary_base = config.local_attn_rotary_emb_base
1316
+ if config.local_attn_rotary_emb_dim is not None:
1317
+ rotary_dim = config.local_attn_rotary_emb_dim
1318
+
1319
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
1320
+ self.rotary_emb = RotaryEmbedding(
1321
+ dim=rotary_dim,
1322
+ base=rotary_base,
1323
+ scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
1324
+ interleaved=config.rotary_emb_interleaved,
1325
+ )
1326
+
1327
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1328
+ self.use_fa2 = False
1329
+ if self.use_fa2 and self.use_sdpa_attn_mask:
1330
+ logger.warn_once(
1331
+ "Flash Attention 2 does not support attention masks. Use unpadded attention "
1332
+ "for the equivalent functionality of masking out padding tokens."
1333
+ )
1334
+ if not self.use_fa2 and self.sliding_window[0] > 0:
1335
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
1336
+
1337
+ def _init_weights(self, reset_params: bool = False):
1338
+ init_weights(
1339
+ self.config,
1340
+ self.Wo,
1341
+ layer_dim=self.config.hidden_size,
1342
+ layer_id=self.layer_id,
1343
+ type_of_module=ModuleType.out_module,
1344
+ )
1345
+
1346
+ def forward(
1347
+ self,
1348
+ qkv: torch.Tensor,
1349
+ attn_mask: Optional[torch.Tensor] = None,
1350
+ ) -> torch.Tensor:
1351
+ """Perform self-attention.
1352
+
1353
+ There are two attention implementations supported:
1354
+ Flash Attention 2 and PyTorch's scaled_dot_product_attention.
1355
+
1356
+ Args:
1357
+ qkv: (batch, seqlen, 3 * dim)
1358
+ attn_mask: (batch, seqlen)
1359
+
1360
+ Returns:
1361
+ attention: (batch, seqlen, dim)
1362
+ """
1363
+ bs, seqlen, _ = qkv.shape
1364
+ dim = self.hidden_size
1365
+
1366
+ seqlen_offset = 0
1367
+
1368
+ # Reshape to (batch, seqlen, 3, nheads, headdim)
1369
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1370
+
1371
+ if self.use_fa2:
1372
+ # Apply RoPE
1373
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1374
+
1375
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
1376
+ if convert_dtype:
1377
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
1378
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
1379
+ orig_dtype = qkv.dtype
1380
+ qkv = qkv.to(torch.bfloat16)
1381
+
1382
+ attn = flash_attn_qkvpacked_func(
1383
+ qkv,
1384
+ dropout_p=self.p_dropout,
1385
+ deterministic=self.deterministic_fa2,
1386
+ window_size=self.sliding_window,
1387
+ )
1388
+ attn = attn.to(orig_dtype) # type: ignore
1389
+ else:
1390
+ attn = flash_attn_qkvpacked_func(
1391
+ qkv,
1392
+ dropout_p=self.p_dropout,
1393
+ deterministic=self.deterministic_fa2,
1394
+ window_size=self.sliding_window,
1395
+ )
1396
+ else:
1397
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1398
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2)
1399
+ attn = F.scaled_dot_product_attention(
1400
+ q,
1401
+ k,
1402
+ v,
1403
+ dropout_p=self.p_dropout,
1404
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
1405
+ if self.use_sdpa_attn_mask
1406
+ else None,
1407
+ ).transpose(1, 2)
1408
+
1409
+ attn = attn.view(bs, seqlen, dim)
1410
+ return self.out_drop(self.Wo(attn))
1411
+
1412
+
1413
+ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
1414
+ """Performs multi-headed self attention on a batch of padded sequences.
1415
+
1416
+ This module supports two attention implementations:
1417
+ 1. Flash Attention 2 (if installed), which improves throughput.
1418
+ 2. PyTorch's scaled_dot_product_attention.
1419
+
1420
+ See `forward` method for additional detail.
1421
+ """
1422
+
1423
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
1424
+ super().__init__(config=config, layer_id=layer_id)
1425
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
1426
+ raise ValueError(
1427
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
1428
+ f"heads ({config.num_attention_heads})"
1429
+ )
1430
+
1431
+ self.num_attention_heads = config.num_attention_heads
1432
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1433
+ self.hidden_size = config.hidden_size
1434
+ self.p_dropout = config.attention_probs_dropout_prob
1435
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
1436
+ self.out_drop = (
1437
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
1438
+ )
1439
+ self.use_fa2 = config.use_fa2
1440
+ self.deterministic_fa2 = config.deterministic_fa2
1441
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
1442
+
1443
+ if config.global_attn_every_n_layers > 0:
1444
+ if config.sliding_window == -1:
1445
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
1446
+ if layer_id % config.global_attn_every_n_layers != 0:
1447
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1448
+ else:
1449
+ self.sliding_window = (-1, -1)
1450
+ else:
1451
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1452
+
1453
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1454
+ self.use_fa2 = False
1455
+ if self.use_fa2 and self.use_sdpa_attn_mask:
1456
+ logger.warn_once(
1457
+ "for the equivalent functionality of masking out padding tokens."
1458
+ "the equivalent functionality of masking out padding tokens."
1459
+ )
1460
+ if not self.use_fa2 and self.sliding_window[0] > 0:
1461
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
1462
+
1463
+ def _init_weights(self, reset_params: bool = False):
1464
+ init_weights(
1465
+ self.config,
1466
+ self.Wo,
1467
+ layer_dim=self.config.hidden_size,
1468
+ layer_id=self.layer_id,
1469
+ type_of_module=ModuleType.out_module,
1470
+ )
1471
+
1472
+ def forward(
1473
+ self,
1474
+ qkv: torch.Tensor,
1475
+ attn_mask: Optional[torch.Tensor] = None,
1476
+ ) -> torch.Tensor:
1477
+ """Perform self-attention.
1478
+
1479
+ There are two attention implementations supported:
1480
+ Flash Attention 2 and PyTorch's scaled_dot_product_attention.
1481
+
1482
+ Args:
1483
+ qkv: (batch, seqlen, 3 * dim)
1484
+ attn_mask: (batch, seqlen)
1485
+
1486
+ Returns:
1487
+ attention: (batch, seqlen, dim)
1488
+ """
1489
+ bs, seqlen, _ = qkv.shape
1490
+ dim = self.hidden_size
1491
+
1492
+ if self.use_fa2:
1493
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1494
+
1495
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
1496
+ if convert_dtype:
1497
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
1498
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
1499
+ orig_dtype = qkv.dtype
1500
+ qkv = qkv.to(torch.bfloat16)
1501
+
1502
+ attn = flash_attn_qkvpacked_func(
1503
+ qkv,
1504
+ dropout_p=self.p_dropout,
1505
+ deterministic=self.deterministic_fa2,
1506
+ window_size=self.sliding_window,
1507
+ )
1508
+ attn = attn.to(orig_dtype) # type: ignore
1509
+ else:
1510
+ attn = flash_attn_qkvpacked_func(
1511
+ qkv,
1512
+ dropout_p=self.p_dropout,
1513
+ deterministic=self.deterministic_fa2,
1514
+ window_size=self.sliding_window,
1515
+ )
1516
+ else:
1517
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1518
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
1519
+ attn = F.scaled_dot_product_attention(
1520
+ q,
1521
+ k,
1522
+ v,
1523
+ dropout_p=self.p_dropout,
1524
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
1525
+ if self.use_sdpa_attn_mask
1526
+ else None,
1527
+ ).transpose(1, 2)
1528
+
1529
+ attn = attn.view(bs, seqlen, dim)
1530
+ return self.out_drop(self.Wo(attn))
1531
+
1532
+
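The padded SDPA paths above only build a mask when `use_sdpa_attn_mask=True`. As a minimal sketch (not part of the commit), this is how the (batch, seqlen) padding mask is expanded into the boolean mask handed to `F.scaled_dot_product_attention`, using toy values:

import torch

attn_mask = torch.tensor([[1, 1, 1, 0]])  # one sequence, last position is padding
bs, seqlen = attn_mask.shape
sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
# sdpa_mask[0, 0] is a 4x4 boolean matrix whose last column is False,
# so no query position attends to the padded key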
1533
+ ATTN2CLS = {
1534
+ "unpadded_base": FlexBertUnpadAttention,
1535
+ "padded_base": FlexBertPaddedAttention,
1536
+ "unpadded_parallel": FlexBertUnpadParallelAttention,
1537
+ "padded_parallel": FlexBertPaddedParallelAttention,
1538
+ "unpadded_rope": FlexBertUnpadRopeAttention,
1539
+ "padded_rope": FlexBertPaddedRopeAttention,
1540
+ "unpadded_rope_parallel": FlexBertUnpadRopeParallelAttention,
1541
+ "padded_rope_parallel": FlexBertPaddedRopeParallelAttention,
1542
+ }
1543
+
1544
+
1545
+ def get_attention_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertAttentionBase:
1546
+ try:
1547
+ attention_layer = (
1548
+ config.initial_attention_layer
1549
+ if layer_id < config.num_initial_layers and getattr(config, "initial_attention_layer", None) is not None
1550
+ else config.attention_layer
1551
+ )
1552
+ return ATTN2CLS[maybe_add_padding(config, attention_layer)](config, layer_id=layer_id)
1553
+ except KeyError:
1554
+ if layer_id < config.num_initial_layers and getattr(config, "initial_attention_layer", None) is not None:
1555
+ raise ValueError(
1556
+ f"Invalid attention layer type: {config.initial_attention_layer=}, must be one of {ATTN2CLS.keys()}. "
1557
+ f"{config.padding=} will be automatically prepended to `config.initial_attention_layer` if unspecified."
1558
+ )
1559
+ else:
1560
+ raise ValueError(
1561
+ f"Invalid attention layer type: {config.attention_layer=}, must be one of {ATTN2CLS.keys()}. "
1562
+ f"{config.padding=} will be automatically prepended to `config.attention_layer` if unspecified."
1563
+ )
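The mapping above is keyed by `{padding}_{attention_layer}` names. A minimal sketch (not part of the commit), assuming the flat module imports this repo relies on, of how a layer name is resolved before the ATTN2CLS lookup:

# illustrative config values only
from configuration_bert import FlexBertConfig, maybe_add_padding
from attention import ATTN2CLS, get_attention_layer

config = FlexBertConfig(attention_layer="rope", padding="unpadded")
name = maybe_add_padding(config, config.attention_layer)
print(name)                     # "unpadded_rope"
print(ATTN2CLS[name].__name__)  # "FlexBertUnpadRopeAttention"
# get_attention_layer(config, layer_id=0) would instantiate this class
# (instantiation additionally requires the flash_attn rotary module).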
bert_padding.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
5
+ # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
6
+
7
+ """Helper functions for padding and unpadding batches.
8
+
9
+ These functions are used extensively throughout the Mosaic BERT implementation
10
+ in `bert_layers.py`.
11
+ """
12
+
13
+ from typing import Tuple, cast
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from einops import rearrange, repeat
18
+
19
+
20
+ class IndexFirstAxis(torch.autograd.Function):
21
+ @staticmethod
22
+ def forward(ctx, input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
23
+ """Get just the values of `input` which are at `indices`.
24
+
25
+ Arguments:
26
+ ctx: the autograd context object
27
+ input: (b, ...) 2+ dimensional tensor
28
+ indices: (num_idx) 1D tensor
29
+ """
30
+ ctx.save_for_backward(indices)
31
+ assert input.ndim >= 2
32
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] # type: ignore
33
+ second_dim = other_shape.numel() # product of sizes of all but first dimension
34
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
35
+ return torch.gather(
36
+ rearrange(input, "b ... -> b (...)"), # (b, ...) -> (b, second_dim)
37
+ 0,
38
+ repeat(indices, "z -> z d", d=second_dim), # (indices,) -> (indices, second_dim)
39
+ ).reshape(-1, *other_shape) # (num_idx, ...)
40
+
41
+ @staticmethod
42
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
43
+ (indices,) = ctx.saved_tensors
44
+ assert grad_output.ndim >= 2
45
+ other_shape = grad_output.shape[1:]
46
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
47
+ grad_input = torch.zeros(
48
+ [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
49
+ )
50
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
51
+ # grad_input[indices] = grad_output
52
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
53
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
54
+
55
+
56
+ index_first_axis = IndexFirstAxis.apply
57
+
58
+
59
+ class IndexPutFirstAxis(torch.autograd.Function):
60
+ @staticmethod
61
+ def forward(ctx, values: torch.Tensor, indices: torch.Tensor, first_axis_dim) -> torch.Tensor:
62
+ ctx.save_for_backward(indices)
63
+ assert indices.ndim == 1
64
+ assert values.ndim >= 2
65
+ output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
66
+ output[indices] = values
67
+ return output
68
+
69
+ @staticmethod
70
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
71
+ (indices,) = ctx.saved_tensors
72
+ grad_values = grad_output[indices]
73
+ return grad_values, None, None
74
+
75
+
76
+ index_put_first_axis = IndexPutFirstAxis.apply
77
+
78
+
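A minimal sketch (not part of the commit) of what the two custom autograd helpers do, with toy shapes:

import torch
from bert_padding import index_first_axis, index_put_first_axis

x = torch.arange(12.).reshape(4, 3)        # (first_axis_dim=4, hidden=3)
idx = torch.tensor([0, 2])
rows = index_first_axis(x, idx)            # gathers rows 0 and 2 -> shape (2, 3)
back = index_put_first_axis(rows, idx, 4)  # scatters them into zeros -> shape (4, 3)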
79
+ def unpad_input(
80
+ hidden_states: torch.Tensor,
81
+ attention_mask: torch.Tensor,
82
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
83
+ """Remove padding from input sequences.
84
+
85
+ Arguments:
86
+ hidden_states: (batch, seqlen, ...)
87
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
88
+
89
+ Returns:
90
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
91
+ indices: (total_nnz)
92
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
93
+ max_seqlen_in_batch: int, the maximum sequence length in the batch.
94
+ """
95
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
96
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
97
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
98
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
99
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
100
+ # bool mask, then call nonzero to get the indices, then index with those. The resulting indices are
101
+ # `dim` times larger than they need to be, wasting memory. It's faster and more memory-efficient to
102
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
103
+ # so we write custom forward and backward to make it a bit faster.
104
+ hidden_states = cast(torch.Tensor, index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices))
105
+ return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
106
+
107
+
108
+ def unpad_input_only(
109
+ hidden_states: torch.Tensor,
110
+ attention_mask: torch.Tensor,
111
+ ) -> torch.Tensor:
112
+ """Like unpad_input, but only return the unpadded first tensor.
113
+
114
+ Save a small amount of overhead.
115
+
116
+ Arguments:
117
+ hidden_states: (batch, seqlen, ...)
118
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
119
+
120
+ Returns:
121
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
122
+ """
123
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
124
+ rearranged = rearrange(hidden_states, "b s ... -> (b s) ...")
125
+ return index_first_axis(rearranged, indices) # type: ignore
126
+
127
+
128
+ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
129
+ """Add padding to sequences.
130
+
131
+ Arguments:
132
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
133
+ indices: (total_nnz)
134
+ batch: int, the batch size
135
+ seqlen: int, the maximum sequence length
136
+
137
+ Returns:
138
+ hidden_states: (batch, seqlen, ...)
139
+ """
140
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
141
+ return rearrange(output, "(b s) ... -> b s ...", b=batch) # type: ignore
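A minimal sketch (not part of the commit) of the unpad/pad roundtrip these helpers provide to the attention layers:

import torch
from bert_padding import unpad_input, pad_input

batch, seqlen, dim = 2, 4, 8
hidden = torch.randn(batch, seqlen, dim)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

unpadded, indices, cu_seqlens, max_len = unpad_input(hidden, mask)
# unpadded: (5, 8), only the five real tokens; cu_seqlens == tensor([0, 3, 5]); max_len == 3
repadded = pad_input(unpadded, indices, batch, seqlen)  # back to (2, 4, 8), zeros at pad positions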
configuration_bert.py ADDED
@@ -0,0 +1,269 @@
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import warnings
5
+
6
+ from transformers import BertConfig as TransformersBertConfig
7
+
8
+
9
+ class BertConfig(TransformersBertConfig):
10
+ def __init__(
11
+ self,
12
+ alibi_starting_size: int = 512,
13
+ normalization: str = "layernorm",
14
+ attention_probs_dropout_prob: float = 0.0,
15
+ head_pred_act: str = "gelu",
16
+ deterministic_fa2: bool = False,
17
+ allow_embedding_resizing: bool = False,
18
+ **kwargs,
19
+ ):
20
+ """Configuration class for MosaicBert.
21
+
22
+ Args:
23
+ alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
24
+ create when initializing the model. You should be able to ignore this parameter in most cases.
25
+ Defaults to 512.
26
+ attention_probs_dropout_prob (float): By default, turn off attention dropout in MosaicBERT.
27
+ Note that the custom Triton Flash Attention with ALiBi implementation does not support dropout.
28
+ However, Flash Attention 2 supports ALiBi and dropout https://github.com/Dao-AILab/flash-attention
29
+ embed_dropout_prob (float): Dropout probability for the embedding layer.
30
+ attn_out_dropout_prob (float): Dropout probability for the attention output layer.
31
+ mlp_dropout_prob (float): Dropout probability for the MLP layer.
32
+ allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size.
33
+ """
34
+ super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
35
+ self.alibi_starting_size = alibi_starting_size
36
+ self.normalization = normalization
37
+ self.head_pred_act = head_pred_act
38
+ self.deterministic_fa2 = deterministic_fa2
39
+ self.allow_embedding_resizing = allow_embedding_resizing
40
+
41
+
42
+ class FlexBertConfig(TransformersBertConfig):
43
+ def __init__(
44
+ self,
45
+ attention_layer: str = "base",
46
+ attention_probs_dropout_prob: float = 0.0,
47
+ attn_out_bias: bool = False,
48
+ attn_out_dropout_prob: float = 0.0,
49
+ attn_qkv_bias: bool = False,
50
+ bert_layer: str = "prenorm",
51
+ decoder_bias: bool = True,
52
+ embed_dropout_prob: float = 0.0,
53
+ embed_norm: bool = True,
54
+ final_norm: bool = False,
55
+ embedding_layer: str = "absolute_pos",
56
+ encoder_layer: str = "base",
57
+ loss_function: str = "cross_entropy",
58
+ loss_kwargs: dict = {},
59
+ mlp_dropout_prob: float = 0.0,
60
+ mlp_in_bias: bool = False,
61
+ mlp_layer: str = "mlp",
62
+ mlp_out_bias: bool = False,
63
+ norm_kwargs: dict = {},
64
+ normalization: str = "rmsnorm",
65
+ padding: str = "unpadded",
66
+ head_class_act: str = "silu",
67
+ head_class_bias: bool = False,
68
+ head_class_dropout: float = 0.0,
69
+ head_class_norm: str = False,
70
+ head_pred_act: str = "silu",
71
+ head_pred_bias: bool = False,
72
+ head_pred_dropout: float = 0.0,
73
+ head_pred_norm: bool = True,
74
+ pooling_type: str = "cls",
75
+ rotary_emb_dim: int | None = None,
76
+ rotary_emb_base: float = 10000.0,
77
+ rotary_emb_scale_base=None,
78
+ rotary_emb_interleaved: bool = False,
79
+ use_fa2: bool = True,
80
+ use_sdpa_attn_mask: bool = False,
81
+ allow_embedding_resizing: bool = False,
82
+ init_method: str = "default",
83
+ init_std: float = 0.02,
84
+ init_cutoff_factor: float = 2.0,
85
+ init_small_embedding: bool = False,
86
+ initial_attention_layer: str | None = None,
87
+ initial_bert_layer: str | None = None,
88
+ initial_mlp_layer: str | None = None,
89
+ num_initial_layers: int = 1,
90
+ skip_first_prenorm: bool = False,
91
+ deterministic_fa2: bool = False,
92
+ sliding_window: int = -1,
93
+ global_attn_every_n_layers: int = -1,
94
+ local_attn_rotary_emb_base: float = -1,
95
+ local_attn_rotary_emb_dim: int | None = None,
96
+ unpad_embeddings: bool = False,
97
+ pad_logits: bool = False,
98
+ compile_model: bool = False,
99
+ masked_prediction: bool = False,
100
+ **kwargs,
101
+ ):
102
+ """
103
+ Args:
104
+ attention_layer (str): Attention layer type.
105
+ attention_probs_dropout_prob (float): Dropout probability for attention probabilities.
106
+ attn_out_bias (bool): use bias in attention output projection.
107
+ attn_out_dropout_prob (float): Dropout probability for attention output.
108
+ attn_qkv_bias (bool): use bias for query, key, value linear layer(s).
109
+ bert_layer (str): BERT layer type.
110
+ decoder_bias (bool): use bias in decoder linear layer.
111
+ embed_dropout_prob (float): Dropout probability for embeddings.
112
+ embed_norm (bool): Normalize embedding output.
113
+ final_norm (bool): Add normalization after the final encoder layer and before head.
114
+ embedding_layer (str): Embedding layer type.
115
+ encoder_layer (str): Encoder layer type.
116
+ loss_function (str): Loss function to use.
117
+ loss_kwargs (dict): Keyword arguments for loss function.
118
+ mlp_dropout_prob (float): Dropout probability for MLP layers.
119
+ mlp_in_bias (bool): Use bias in MLP input linear layer.
120
+ mlp_layer (str): MLP layer type.
121
+ mlp_out_bias (bool): Use bias in MLP output linear layer.
122
+ norm_kwargs (dict): Keyword arguments for normalization layers.
123
+ normalization (str): Normalization type.
124
+ padding (str): Padding mode, one of "unpadded" or "padded". Unpadded works best with `use_fa2=True`.
125
+ head_class_act (str): Activation function for classification head.
126
+ head_class_bias (bool): Use bias in classification head linear layer(s).
127
+ head_class_dropout (float): Dropout probability for classification head.
128
+ head_class_norm (str): Normalization type for classification head.
129
+ head_pred_act (str): Activation function for prediction head.
130
+ head_pred_bias (bool): Use bias in prediction head linear layer(s).
131
+ head_pred_dropout (float): Dropout probability for prediction head.
132
+ head_pred_norm (bool): Normalize prediction head output.
133
+ pooling_type (str): Pooling type.
134
+ rotary_emb_dim (int | None): Rotary embedding dimension.
135
+ rotary_emb_base (float): Rotary embedding base.
136
+ rotary_emb_scale_base (float): Rotary embedding scale base.
137
+ rotary_emb_interleaved (bool): Use interleaved rotary embeddings.
138
+ use_fa2 (bool): Use FlashAttention2. Requires flash_attn package.
139
+ use_sdpa_attn_mask (bool): Pass a mask to SDPA. This will prevent SDPA from using the PyTorch FA2 kernel.
140
+ allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size.
141
+ init_method (str): Model layers initialization method.
142
+ init_std (float): Standard deviation for initialization. Used for normal and full_megatron init.
143
+ init_cutoff_factor (float): Cutoff factor for initialization. Used for normal and full_megatron init.
144
+ init_small_embedding (bool): Initialize embeddings with RWKV small init.
145
+ initial_attention_layer (str | None): Replace first `num_initial_layers` attention_layer instance with this layer.
146
+ initial_bert_layer (str | None): Replace first `num_initial_layers` bert_layer instance with this layer.
147
+ initial_mlp_layer (str | None): Replace first `num_initial_layers` mlp_layer instance with this layer.
148
+ num_initial_layers (int): Number of initial layers to set via `initial_attention_layer`, `initial_bert_layer`, and `initial_mlp_layer`.
149
+ skip_first_prenorm (bool): Skip pre-normalization for the first bert layer. Requires `embed_norm=True`.
150
+ deterministic_fa2 (bool): Use Flash Attention 2 deterministic mode. This is slower than the default non-deterministic mode.
152
+ sliding_window (int): Use sliding window attention with window size `n`. -1 to disable. The window size is split between the left and right context. Only supported with FA2.
152
+ global_attn_every_n_layers (int): Use global attention every `n` layers and sliding window for the rest. -1 to disable.
153
+ local_attn_rotary_emb_base (float): Rotary embedding base for local attention. -1 to disable and use `rotary_emb_base` for all layers.
154
+ local_attn_rotary_emb_dim (int | None): Rotary embedding dimension for local attention. None to disable and use `rotary_emb_dim` for all layers.
155
+ unpad_embeddings (bool): Unpad inputs before the embedding layer.
156
+ pad_logits (bool): Pad logits after calculating the loss.
157
+ compile_model (bool): Compile the subset of the model which can be compiled.
158
+ masked_prediction (bool): Only pass the masked tokens through the final MLM layers.
159
+ **kwargs: Additional keyword arguments.
160
+ """
161
+ super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
162
+ self.attention_layer = attention_layer
163
+ self.attn_out_bias = attn_out_bias
164
+ self.attn_out_dropout_prob = attn_out_dropout_prob
165
+ self.attn_qkv_bias = attn_qkv_bias
166
+ self.bert_layer = bert_layer
167
+ self.decoder_bias = decoder_bias
168
+ self.embed_dropout_prob = embed_dropout_prob
169
+ self.embed_norm = embed_norm
170
+ self.final_norm = final_norm
171
+ self.embedding_layer = embedding_layer
172
+ self.encoder_layer = encoder_layer
173
+ self.loss_function = loss_function
174
+ self.loss_kwargs = loss_kwargs
175
+ self.mlp_dropout_prob = mlp_dropout_prob
176
+ self.mlp_in_bias = mlp_in_bias
177
+ self.mlp_layer = mlp_layer
178
+ self.mlp_out_bias = mlp_out_bias
179
+ self.norm_kwargs = norm_kwargs
180
+ self.normalization = normalization
181
+ self.padding = padding
182
+ self.head_class_act = head_class_act
183
+ self.head_class_bias = head_class_bias
184
+ self.head_class_dropout = head_class_dropout
185
+ self.head_class_norm = head_class_norm
186
+ self.head_pred_act = head_pred_act
187
+ self.head_pred_bias = head_pred_bias
188
+ self.head_pred_dropout = head_pred_dropout
189
+ self.head_pred_norm = head_pred_norm
190
+ self.pooling_type = pooling_type
191
+ self.rotary_emb_dim = rotary_emb_dim
192
+ self.rotary_emb_base = rotary_emb_base
193
+ self.rotary_emb_scale_base = rotary_emb_scale_base
194
+ self.rotary_emb_interleaved = rotary_emb_interleaved
195
+ self.use_fa2 = use_fa2
196
+ self.use_sdpa_attn_mask = use_sdpa_attn_mask
197
+ self.allow_embedding_resizing = allow_embedding_resizing
198
+ self.init_method = init_method
199
+ self.init_std = init_std
200
+ self.init_cutoff_factor = init_cutoff_factor
201
+ self.init_small_embedding = init_small_embedding
202
+ self.initial_attention_layer = initial_attention_layer
203
+ self.initial_bert_layer = initial_bert_layer
204
+ self.initial_mlp_layer = initial_mlp_layer
205
+ self.num_initial_layers = num_initial_layers
206
+ self.skip_first_prenorm = skip_first_prenorm
207
+ self.deterministic_fa2 = deterministic_fa2
208
+ self.sliding_window = sliding_window
209
+ self.global_attn_every_n_layers = global_attn_every_n_layers
210
+ self.local_attn_rotary_emb_base = local_attn_rotary_emb_base
211
+ self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim
212
+ self.unpad_embeddings = unpad_embeddings
213
+ self.pad_logits = pad_logits
214
+ self.compile_model = compile_model
215
+ self.masked_prediction = masked_prediction
216
+
217
+ if loss_kwargs.get("return_z_loss", False):
218
+ if loss_function != "fa_cross_entropy":
219
+ raise ValueError("loss_function must be 'fa_cross_entropy' when return_z_loss is True")
220
+ if loss_kwargs.get("lse_square_scale", 0) <= 0:
221
+ raise ValueError(
222
+ "lse_square_scale must be passed to `loss_kwargs` and must be greater than 0 for z_loss"
223
+ )
224
+ if loss_kwargs.get("inplace_backward", False):
225
+ self.loss_kwargs["inplace_backward"] = False
226
+ warnings.warn("`inplace_backward=True` will cause incorrect metrics. Automatically setting to False.")
227
+
228
+ if global_attn_every_n_layers > 0 and (self.num_hidden_layers - 1) % global_attn_every_n_layers != 0:
229
+ raise ValueError(
230
+ f"{global_attn_every_n_layers=} must be a divisor of one less than {self.num_hidden_layers=}"
231
+ )
232
+
233
+ if self.sliding_window != -1:
234
+ if not self.use_fa2:
235
+ raise ValueError("Sliding window attention is only supported with FlashAttention2")
236
+ if self.sliding_window % 2 != 0 and self.sliding_window % 64 != 0:
237
+ raise ValueError(
238
+ f"Sliding window must be an even number or divisible by 64: {self.sliding_window=} {self.sliding_window % 64} {self.sliding_window % 2}"
239
+ )
240
+ else:
241
+ if self.global_attn_every_n_layers != -1:
242
+ raise ValueError("global_attn_every_n_layers must be -1 when sliding_window is disabled")
243
+ if self.local_attn_rotary_emb_base != -1:
244
+ raise ValueError("local_attn_rotary_emb_base must be -1 when sliding_window is disabled")
245
+ if self.local_attn_rotary_emb_dim is not None:
246
+ raise ValueError("local_attn_rotary_emb_dim must be None when sliding_window is disabled")
247
+
248
+ if self.unpad_embeddings and self.padding != "unpadded":
249
+ warnings.warn(
250
+ "`unpad_embeddings=True` requires `padding='unpadded'`. Automatically setting `padding='unpadded'`."
251
+ )
252
+ self.padding = "unpadded"
253
+ if self.pad_logits and not self.unpad_embeddings:
254
+ raise ValueError("`pad_logits=True` requires `unpad_embeddings=True`")
255
+ if self.unpad_embeddings and self.embedding_layer == "absolute_pos":
256
+ raise ValueError(f"{self.unpad_embeddings=} is incompatible with {self.embedding_layer=}")
257
+
258
+
259
+ PADDING = ["unpadded", "padded"]
260
+
261
+
262
+ def maybe_add_padding(config: FlexBertConfig, config_option: str) -> str:
263
+ if config.padding not in PADDING:
264
+ raise ValueError(f"Invalid padding type: {config.padding}, must be one of {PADDING}")
265
+
266
+ if not any(config_option.startswith(pad + "_") for pad in PADDING):
267
+ config_option = f"{config.padding}_{config_option}"
268
+
269
+ return config_option
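A minimal sketch (not part of the commit) of a config that exercises the sliding-window checks above; the values are illustrative:

from configuration_bert import FlexBertConfig

config = FlexBertConfig(
    num_hidden_layers=22,
    attention_layer="rope",
    padding="unpadded",
    use_fa2=True,                        # required for sliding window attention
    sliding_window=128,                  # must be even; split into 64 tokens of left and right context
    global_attn_every_n_layers=3,        # must divide num_hidden_layers - 1 (21 here)
    local_attn_rotary_emb_base=10000.0,  # optional: a separate RoPE base for the sliding-window layers
)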
embeddings.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2022 MosaicML Examples authors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # Copyright 2023 MosaicML Examples authors
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
11
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from typing import Optional
18
+
19
+ from configuration_bert import FlexBertConfig
20
+ from normalization import get_norm_layer
21
+ from initialization import ModuleType, init_weights
22
+
23
+
24
+ class BertAlibiEmbeddings(nn.Module):
25
+ """Construct the embeddings for words, ignoring position.
26
+
27
+ There are no positional embeddings since we use ALiBi and token_type
28
+ embeddings.
29
+
30
+ This module is modeled after the Hugging Face BERT's
31
+ :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is
32
+ modified as part of Mosaic BERT's ALiBi implementation. The key change is
33
+ that position embeddings are removed. Position information instead comes
34
+ from attention biases that scale linearly with the position distance
35
+ between query and key tokens.
36
+
37
+ This module ignores the `position_ids` input to the `forward` method.
38
+ """
39
+
40
+ def __init__(self, config):
41
+ super().__init__()
42
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
43
+ # ALiBi doesn't use position embeddings
44
+ if getattr(config, "token_type_embeddings", True):
45
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
46
+ self.use_token_type_embeddings = True
47
+ else:
48
+ self.use_token_type_embeddings = False
49
+
50
+ self.LayerNorm = get_norm_layer(config)
51
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
52
+ if self.use_token_type_embeddings:
53
+ self.register_buffer(
54
+ "token_type_ids", torch.zeros(config.max_position_embeddings, dtype=torch.long), persistent=False
55
+ )
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: Optional[torch.LongTensor] = None,
60
+ token_type_ids: Optional[torch.LongTensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ past_key_values_length: int = 0,
64
+ ) -> torch.Tensor:
65
+ if (input_ids is not None) == (inputs_embeds is not None):
66
+ raise ValueError("Must specify either input_ids or input_embeds!")
67
+ if input_ids is not None:
68
+ input_shape = input_ids.size()
69
+ else:
70
+ assert inputs_embeds is not None # just for type checking
71
+ input_shape = inputs_embeds.size()[:-1]
72
+
73
+ seq_length = input_shape[1]
74
+
75
+ if position_ids is None:
76
+ # great! ALiBi
77
+ pass
78
+
79
+ # Setting the token_type_ids to the registered buffer in constructor
80
+ # where it is all zeros, which usually occurs when it's auto-generated;
81
+ # registered buffer helps users when tracing the model without passing
82
+ # token_type_ids, solves issue #5664
83
+ if self.use_token_type_embeddings and token_type_ids is None:
84
+ if hasattr(self, "token_type_ids"):
85
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
86
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
87
+ token_type_ids = buffered_token_type_ids_expanded
88
+ else:
89
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
90
+
91
+ if inputs_embeds is None:
92
+ inputs_embeds = self.word_embeddings(input_ids)
93
+
94
+ if self.use_token_type_embeddings:
95
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
96
+ embeddings = inputs_embeds + token_type_embeddings
97
+ else:
98
+ embeddings = inputs_embeds
99
+
100
+ # no position embeddings! ALiBi
101
+ embeddings = self.LayerNorm(embeddings)
102
+ embeddings = self.dropout(embeddings)
103
+ return embeddings
104
+
105
+
106
+ class FlexBertEmbeddingsBase(nn.Module):
107
+ """A FlexBERT embeddings base class for type hints."""
108
+
109
+ def __init__(self, config: FlexBertConfig):
110
+ super().__init__()
111
+ self.config = config
112
+
113
+ def _init_weights(self, reset_params: bool = False):
114
+ raise NotImplementedError("This is a base class and should not be used directly.")
115
+
116
+ def reset_parameters(self):
117
+ self._init_weights(reset_params=True)
118
+
119
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
120
+ raise NotImplementedError("This is a base class and should not be used directly.")
121
+
122
+
123
+ class FlexBertAbsoluteEmbeddings(FlexBertEmbeddingsBase):
124
+ """Construct the embeddings with absolute positional embeddings."""
125
+
126
+ def __init__(self, config: FlexBertConfig):
127
+ super().__init__(config)
128
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
129
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
130
+
131
+ self.norm = get_norm_layer(config) if config.embed_norm else nn.Identity()
132
+ self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()
133
+
134
+ self.register_buffer(
135
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
136
+ )
137
+
138
+ def _init_weights(self, reset_params: bool = False):
139
+ init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)
140
+ init_weights(self.config, self.position_embeddings, type_of_module=ModuleType.emb)
141
+
142
+ if reset_params:
143
+ if self.config.embed_norm:
144
+ self.norm.reset_parameters() # type: ignore
145
+
146
+ def forward(
147
+ self,
148
+ input_ids: torch.LongTensor,
149
+ position_ids: Optional[torch.LongTensor] = None,
150
+ ) -> torch.Tensor:
151
+ if position_ids is None:
152
+ position_ids = self.position_ids[:, 0 : input_ids.shape[1]]
153
+
154
+ embeddings = self.tok_embeddings(input_ids)
155
+ position_embeddings = self.position_embeddings(position_ids)
156
+
157
+ embeddings = self.norm(embeddings + position_embeddings)
158
+ return self.drop(embeddings)
159
+
160
+
161
+ class FlexBertCompiledSansPositionEmbeddings(FlexBertEmbeddingsBase):
162
+ """Construct the embeddings from token embeddings without any positional embeddings."""
163
+
164
+ def __init__(self, config: FlexBertConfig):
165
+ super().__init__(config)
166
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
167
+
168
+ self.norm = get_norm_layer(config, compiled_norm=config.compile_model) if config.embed_norm else nn.Identity()
169
+ self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()
170
+
171
+ def _init_weights(self, reset_params: bool = False):
172
+ init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)
173
+
174
+ if reset_params:
175
+ if self.config.embed_norm:
176
+ self.norm.reset_parameters() # type: ignore
177
+
178
+ @torch.compile(dynamic=True)
179
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
180
+ return self.drop(self.norm(self.tok_embeddings(input_ids)))
181
+
182
+
183
+ class FlexBertSansPositionEmbeddings(FlexBertEmbeddingsBase):
184
+ """Construct the embeddings from token embeddings without any positional embeddings."""
185
+
186
+ def __init__(self, config: FlexBertConfig):
187
+ super().__init__(config)
188
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
189
+
190
+ self.norm = get_norm_layer(config) if config.embed_norm else nn.Identity()
191
+ self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()
192
+
193
+ def _init_weights(self, reset_params: bool = False):
194
+ init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)
195
+
196
+ if reset_params:
197
+ if self.config.embed_norm:
198
+ self.norm.reset_parameters() # type: ignore
199
+
200
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
201
+ return self.drop(self.norm(self.tok_embeddings(input_ids)))
202
+
203
+
204
+ EBB2CLS = {
205
+ "absolute_pos": FlexBertAbsoluteEmbeddings,
206
+ "sans_pos": FlexBertSansPositionEmbeddings,
207
+ }
208
+
209
+
210
+ def get_embedding_layer(config: FlexBertConfig) -> FlexBertEmbeddingsBase:
211
+ try:
212
+ if config.compile_model and config.embedding_layer == "sans_pos":
213
+ return FlexBertCompiledSansPositionEmbeddings(config)
214
+ elif config.compile_model:
215
+ raise ValueError(f"{config.compile_model=} only supports sans_pos embeddings.")
216
+ return EBB2CLS[config.embedding_layer](config)
217
+ except KeyError:
218
+ raise ValueError(f"Invalid embeddings layer type: {config.embedding_layer=}, must be one of {EBB2CLS.keys()}.")
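A minimal sketch (not part of the commit) of selecting and running an embedding layer from the config; the shapes are illustrative:

import torch
from configuration_bert import FlexBertConfig
from embeddings import get_embedding_layer

config = FlexBertConfig(embedding_layer="sans_pos", padding="unpadded", vocab_size=30522, hidden_size=768)
emb = get_embedding_layer(config)   # -> FlexBertSansPositionEmbeddings
input_ids = torch.randint(0, config.vocab_size, (2, 16))
hidden = emb(input_ids)             # (2, 16, 768): token embeddings -> norm -> dropout, no position embeddings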
initialization.py ADDED
@@ -0,0 +1,551 @@
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright 2023 OLMo Authors
5
+ # License: Apache-2.0
6
+
7
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
8
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
9
+ # License: Apache-2.0
10
+
11
+ import math
12
+ from typing import Optional, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from utils import StrEnum
18
+
19
+ from configuration_bert import FlexBertConfig
20
+ from normalization import RMSNorm
21
+
22
+ __all__ = ["init_weights", "ModuleType", "InitFnType"]
23
+
24
+
25
+ class InitFnType(StrEnum):
26
+ mitchell = "mitchell"
27
+ """
28
+ The strategy suggested to us by Mitchell Wortsman from UW.
29
+ This uses a truncated normal distribution with an adaptive standard deviation that depends
30
+ on the size of the weights as well as the depth of the layer.
31
+ """
32
+
33
+ normal = "normal"
34
+ """
35
+ All weights are initialized from the same normal distribution.
36
+ """
37
+
38
+ default = "default"
39
+ """
40
+ All weights are initialized with the default HuggingFace Bert method. Set init_std=0.02 to match.
41
+ """
42
+
43
+ kaiming_normal = "kaiming_normal"
44
+ """
45
+ All weights are initialized with the Kaiming method from a normal distribution.
46
+ Note this currently won't work with FSDP.
47
+ """
48
+
49
+ fan_in = "fan_in"
50
+ """
51
+ "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
52
+ is the input dimensionality of the kernel.
53
+ """
54
+
55
+ full_megatron = "full_megatron"
56
+ """
57
+ This is what metaseq calls "full megatron init". It is the init used for Llama 2.
58
+ """
59
+
60
+
61
+ class ModuleType(StrEnum):
62
+ in_module = "in"
63
+ out_module = "out"
64
+ emb = "emb"
65
+ final_out = "final_out"
66
+
67
+
68
+ def init_weights(
69
+ config: FlexBertConfig,
70
+ module: Union[nn.Linear, nn.Embedding],
71
+ layer_dim: Optional[int] = None,
72
+ layer_id: Optional[int] = None,
73
+ std_factor: float = 1.0,
74
+ type_of_module: Optional[ModuleType] = None,
75
+ ) -> None:
76
+ """
77
+ Initialize weights of a linear or embedding module.
78
+
79
+ :param config: The model config.
80
+ :param module: The linear or embedding submodule to initialize.
81
+ :param layer_dim: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
82
+ for fused layers.
83
+ :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
84
+ ``1 / sqrt(2 * (layer_id + 1))``.
85
+ """
86
+ if config.init_method == InitFnType.full_megatron and config.init_small_embedding:
87
+ raise ValueError("Cannot use 'small_embedding_init' with 'full_megatron' init.")
88
+
89
+ layer_dim = layer_dim if layer_dim is not None else config.hidden_size
90
+ if config.init_method == InitFnType.normal:
91
+ std = config.init_std * std_factor
92
+ if config.init_cutoff_factor is not None:
93
+ cutoff_value = config.init_cutoff_factor * std
94
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
95
+ else:
96
+ nn.init.normal_(module.weight, mean=0.0, std=std)
97
+ elif config.init_method == InitFnType.mitchell:
98
+ std = std_factor / math.sqrt(layer_dim)
99
+ if layer_id is not None:
100
+ std = std / math.sqrt(2 * (layer_id + 1))
101
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
102
+ elif config.init_method == InitFnType.kaiming_normal:
103
+ nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
104
+ elif config.init_method == InitFnType.fan_in:
105
+ std = std_factor / math.sqrt(layer_dim)
106
+ nn.init.normal_(module.weight, mean=0.0, std=std)
107
+ elif config.init_method == InitFnType.full_megatron:
108
+ if type_of_module is None:
109
+ raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
110
+
111
+ cutoff_factor = config.init_cutoff_factor
112
+ if cutoff_factor is None:
113
+ cutoff_factor = 3
114
+
115
+ if type_of_module == ModuleType.in_module:
116
+ # for att_proj (same as QKV), ff_proj
117
+ std = config.init_std
118
+ elif type_of_module == ModuleType.out_module:
119
+ # for attn_out, ff_out
120
+ std = config.init_std / math.sqrt(2.0 * config.num_hidden_layers)
121
+ elif type_of_module == ModuleType.emb:
122
+ # positional embeddings (wpe)
123
+ # token embeddings (wte)
124
+ std = config.init_std
125
+ elif type_of_module == ModuleType.final_out:
126
+ # final output (ff_out)
127
+ std = config.hidden_size**-0.5
128
+ else:
129
+ raise RuntimeError(f"Unknown module type '{type_of_module}'")
130
+
131
+ nn.init.trunc_normal_(
132
+ module.weight,
133
+ mean=0.0,
134
+ std=std,
135
+ a=-cutoff_factor * std,
136
+ b=cutoff_factor * std,
137
+ )
138
+ elif config.init_method == InitFnType.default:
139
+ # default hugging face bert initialization
140
+ # normalization layers already init to ones and zeros
141
+ if isinstance(module, nn.Linear):
142
+ # Slightly different from the TF version which uses truncated_normal for initialization
143
+ # cf https://github.com/pytorch/pytorch/pull/5617
144
+ module.weight.data.normal_(mean=0.0, std=config.init_std)
145
+ if module.bias is not None:
146
+ module.bias.data.zero_()
147
+ elif isinstance(module, nn.Embedding):
148
+ module.weight.data.normal_(mean=0.0, std=config.init_std)
149
+ if module.padding_idx is not None:
150
+ module.weight.data[module.padding_idx].zero_()
151
+ else:
152
+ raise NotImplementedError(config.init_method)
153
+
154
+ if isinstance(module, nn.Linear):
155
+ if module.bias is not None:
156
+ nn.init.zeros_(module.bias)
157
+
158
+ if config.init_method == InitFnType.normal and getattr(module, "_is_residual", False):
159
+ with torch.no_grad():
160
+ module.weight.div_(math.sqrt(2 * config.num_hidden_layers))
161
+
162
+ if isinstance(module, nn.Embedding) and config.init_small_embedding:
163
+ nn.init.uniform_(module.weight, a=-1e-4, b=1e-4)
164
+
165
+
166
+ class TileMode(StrEnum):
167
+ center_weights = "center_weights"
168
+ tile_weights_from_edge = "tile_weights_from_edge"
169
+ tile_weights_from_middle = "tile_weights_from_middle"
170
+
171
+
172
+ def tile_weight(
173
+ pretrained_weights: torch.Tensor,
174
+ new_weights: torch.Tensor,
175
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
176
+ ) -> torch.Tensor:
177
+ """
178
+ Tile or center an input tensor to a larger desired size. Works for both 2D and 1D tensors.
179
+
180
+ Args:
181
+ pretrained_weights (torch.Tensor): The input tensor to be tiled or centered (1D or 2D).
182
+ new_weights (torch.Tensor): The tensor with the desired size.
183
+ mode (Union[str, TileMode]): 'center_weights', 'tile_weights_from_edge', or 'tile_weights_from_middle'
184
+
185
+ Returns:
186
+ torch.Tensor: The resulting tensor of the desired size.
187
+ """
188
+ assert pretrained_weights.dim() in (1, 2), "Input tensor must be 1-dimensional or 2-dimensional"
189
+ if isinstance(mode, str):
190
+ mode = TileMode(mode)
191
+
192
+ pretrained_weights = pretrained_weights.clone()
193
+
194
+ if pretrained_weights.dim() == 1:
195
+ return _tile_1d(pretrained_weights, new_weights, mode)
196
+ else:
197
+ return _tile_2d(pretrained_weights, new_weights, mode)
198
+
199
+
200
+ def _tile_1d(pretrained_weights: torch.Tensor, new_weights: torch.Tensor, mode: TileMode) -> torch.Tensor:
201
+ assert pretrained_weights.dim() == 1, "Input tensor must be 1-dimensional"
202
+ input_size = pretrained_weights.shape[0]
203
+ new_size = new_weights.shape[0]
204
+ assert new_size >= input_size, "Desired size must be greater than or equal to input size"
205
+
206
+ if mode == TileMode.center_weights:
207
+ offset = (new_size - input_size) // 2
208
+ new_weights[offset : offset + input_size] = pretrained_weights
209
+ return new_weights.clone()
210
+ elif mode == TileMode.tile_weights_from_edge:
211
+ repeat_count = (new_size + input_size - 1) // input_size
212
+ tiled_tensor = pretrained_weights.repeat(repeat_count)
213
+ return tiled_tensor[:new_size].clone()
214
+ elif mode == TileMode.tile_weights_from_middle:
215
+ # Calculate offsets to center the original tensor
216
+ offset = (new_size - input_size) // 2
217
+
218
+ # Create a new tensor with the desired size
219
+ result = torch.zeros(new_size, dtype=pretrained_weights.dtype, device=pretrained_weights.device)
220
+
221
+ # Place the original tensor in the center
222
+ result[offset : offset + input_size] = pretrained_weights
223
+
224
+ # Tile the left and right sides
225
+ for i in range(offset):
226
+ result[offset - 1 - i] = pretrained_weights[input_size - 1 - (i % input_size)]
227
+ for i in range(offset + input_size, new_size):
228
+ result[i] = pretrained_weights[(i - offset) % input_size]
229
+ return result.clone()
230
+
231
+
232
+ def _tile_2d(pretrained_weights: torch.Tensor, new_weights: torch.Tensor, mode: TileMode) -> torch.Tensor:
233
+ assert pretrained_weights.dim() == 2, "Input tensor must be 2-dimensional"
234
+ input_height, input_width = pretrained_weights.shape
235
+ new_height, new_width = new_weights.shape
236
+ assert new_height >= input_height, "Desired height must be greater than or equal to input height"
237
+ assert new_width >= input_width, "Desired width must be greater than or equal to input width"
238
+
239
+ if mode == TileMode.center_weights:
240
+ height_offset = (new_height - input_height) // 2
241
+ width_offset = (new_width - input_width) // 2
242
+ new_weights[height_offset : height_offset + input_height, width_offset : width_offset + input_width] = pretrained_weights # fmt: skip
243
+ return new_weights.clone()
244
+ elif mode == TileMode.tile_weights_from_edge:
245
+ repeat_height = (new_height + input_height - 1) // input_height
246
+ repeat_width = (new_width + input_width - 1) // input_width
247
+ tiled_tensor = pretrained_weights.repeat(repeat_height, repeat_width)
248
+ return tiled_tensor[:new_height, :new_width].clone()
249
+ elif mode == TileMode.tile_weights_from_middle:
250
+ # Calculate offsets to center the original tensor
251
+ height_offset = (new_height - input_height) // 2
252
+ width_offset = (new_width - input_width) // 2
253
+
254
+ # Create a new tensor with the desired width and input height
255
+ horizontal_tiled = torch.zeros(
256
+ input_height, new_width, dtype=pretrained_weights.dtype, device=pretrained_weights.device
257
+ )
258
+
259
+ # Place the original tensor in the center horizontally
260
+ horizontal_tiled[:, width_offset : width_offset + input_width] = pretrained_weights
261
+
262
+ # Tile the left and right sides
263
+ for i in range(width_offset):
264
+ horizontal_tiled[:, i] = horizontal_tiled[
265
+ :, width_offset + input_width - 1 - (width_offset - i - 1) % input_width
266
+ ]
267
+ for i in range(width_offset + input_width, new_width):
268
+ horizontal_tiled[:, i] = horizontal_tiled[:, width_offset + (i - width_offset) % input_width]
269
+
270
+ # Now tile vertically
271
+ result = torch.zeros(new_height, new_width, dtype=pretrained_weights.dtype, device=pretrained_weights.device)
272
+ result[height_offset : height_offset + input_height, :] = horizontal_tiled
273
+
274
+ # Tile top
275
+ for i in range(height_offset):
276
+ row_to_copy = (input_height - 1) - (i % input_height)
277
+ result[height_offset - 1 - i, :] = horizontal_tiled[row_to_copy, :]
278
+
279
+ # Tile bottom
280
+ for i in range(height_offset + input_height, new_height):
281
+ row_to_copy = (i - height_offset) % input_height
282
+ result[i, :] = horizontal_tiled[row_to_copy, :]
283
+ return result.clone()
284
+
285
+
286
+ def tile_fused_qkv(
287
+ pretrained_qkv_weight: torch.Tensor,
288
+ new_qkv_weight: torch.Tensor,
289
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
290
+ ):
291
+ """
292
+ Tile the weights of a fused pretrained QKV layer to a new, larger QKV dimension.
293
+
294
+ Args:
295
+ pretrained_qkv_weight (torch.Tensor): The original fused QKV layer
296
+ new_qkv_weight (torch.Tensor): The new fused QKV layer with larger linear_dim
297
+ mode (Union[str, TileMode]): The tiling mode to use
298
+ Returns:
299
+ torch.Tensor: The new fused QKV layer with tiled weights
300
+ """
301
+ # Split QKV, assume new_q, new_k, new_v are the same shape
302
+ pretrained_q, pretrained_k, pretrained_v = pretrained_qkv_weight.chunk(3, dim=0)
303
+ new_q, new_k, new_v = new_qkv_weight.chunk(3, dim=0)
304
+
305
+ # Tile Q, K, V separately
306
+ new_q = tile_weight(pretrained_q, new_q, mode=mode)
307
+ new_k = tile_weight(pretrained_k, new_k, mode=mode)
308
+ new_v = tile_weight(pretrained_v, new_v, mode=mode)
309
+
310
+ # Concatenate tiled Q, K, V
311
+ return torch.cat([new_q, new_k, new_v], dim=0)
312
+
313
+
314
+ def tile_fused_glu(
315
+ pretrained_glu_weight: torch.Tensor,
316
+ new_glu_weight: torch.Tensor,
317
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
318
+ ):
319
+ """
320
+ Tile the weights of a fused pretrained GLU layer to a new, larger GLU dimension.
321
+
322
+ Args:
323
+ pretrained_glu_weight (torch.Tensor): The original fused GLU layer
324
+ new_glu_weight (torch.Tensor): The new fused GLU layer with larger linear_dim
325
+ mode (Union[str, TileMode]): The tiling mode to use
326
+ Returns:
327
+ torch.Tensor: The new fused GLU layer with tiled weights
328
+ """
329
+ # Split GLU, assume new_glu_wi, new_glu_wg are the same shape
330
+ pretrained_glu_wi, pretrained_glu_wg = pretrained_glu_weight.chunk(2, dim=0)
331
+ new_glu_wi, new_glu_wg = new_glu_weight.chunk(2, dim=0)
332
+
333
+ # Tile GLU separately
334
+ new_glu_wi = tile_weight(pretrained_glu_wi, new_glu_wi, mode=mode)
335
+ new_glu_wg = tile_weight(pretrained_glu_wg, new_glu_wg, mode=mode)
336
+
337
+ # Concatenate tiled GLU
338
+ return torch.cat([new_glu_wi, new_glu_wg], dim=0)
339
+
340
+
341
+ def tile_fused_qkvff(
342
+ pretrained_qkvff_weight: torch.Tensor,
343
+ new_qkvff_weight: torch.Tensor,
344
+ pretrained_attn_size: int,
345
+ pretrained_mlp_size: int,
346
+ new_attn_size: int,
347
+ new_mlp_size: int,
348
+ is_glu: bool = False,
349
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
350
+ ):
351
+ """
352
+ Tile the weights of a fused pretrained QKVFF layer to a new, larger QKVFF dimension.
353
+
354
+ Args:
355
+ pretrained_qkvff_weight (torch.Tensor): The original fused QKVFF layer
356
+ new_qkvff_weight (torch.Tensor): The new fused QKVFF layer with larger linear_dim
357
+ pretrained_attn_size (int): The attention size of the pretrained fused QKVFF layer
358
+ pretrained_mlp_size (int): The mlp size of the pretrained fused QKVFF layer
359
+ new_attn_size (int): The attention size of the new fused QKVFF layer
360
+ new_mlp_size (int): The mlp size of the new fused QKVFF layer
361
+ is_glu (bool): Whether the QKVFF layer is a GLU layer
362
+ mode (Union[str, TileMode]): The tiling mode to use
363
+ Returns:
364
+ torch.Tensor: The new fused QKVFF layer with tiled weights
365
+ """
366
+ # Split QKVFF
367
+ pretrained_qkv, pretrained_ff = pretrained_qkvff_weight.split([pretrained_attn_size, pretrained_mlp_size], dim=0)
368
+ new_qkv, new_ff = new_qkvff_weight.split([new_attn_size, new_mlp_size], dim=0)
369
+
370
+ # Tile QKVFF separately
371
+ new_qkv = tile_fused_qkv(pretrained_qkv, new_qkv, mode=mode)
372
+ if is_glu:
373
+ new_ff = tile_fused_glu(pretrained_ff, new_ff, mode=mode)
374
+ else:
375
+ new_ff = tile_weight(pretrained_ff, new_ff, mode=mode)
376
+
377
+ # Concatenate tiled QKVFF
378
+ return torch.cat([new_qkv, new_ff], dim=0)
379
+
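The attention/MLP split sizes passed in here mirror the fused `Wqkvff` layer defined later in `layers.py`: three hidden-sized blocks for QKV plus two intermediate-sized blocks when the feed-forward is a GLU. A hypothetical shape sketch (module assumed importable as `initialization`):

```python
# Hypothetical shape sketch for tiling a fused QKV + GLU feed-forward weight.
import torch
from initialization import TileMode, tile_fused_qkvff

old_hidden, old_intermediate = 768, 1152
new_hidden, new_intermediate = 1024, 2624

old_attn, old_mlp = 3 * old_hidden, 2 * old_intermediate  # fused QKV rows, fused GLU (input + gate) rows
new_attn, new_mlp = 3 * new_hidden, 2 * new_intermediate

pretrained = torch.randn(old_attn + old_mlp, old_hidden)
new = torch.zeros(new_attn + new_mlp, new_hidden)

tiled = tile_fused_qkvff(
    pretrained,
    new,
    pretrained_attn_size=old_attn,
    pretrained_mlp_size=old_mlp,
    new_attn_size=new_attn,
    new_mlp_size=new_mlp,
    is_glu=True,
    mode=TileMode.tile_weights_from_middle,
)
assert tiled.shape == new.shape
```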
380
+
381
+ class TileLinear(StrEnum):
382
+ wqkv = "wqkv"
383
+ glu = "glu"
384
+ wqkvff = "wqkvff"
385
+ default = "default"
386
+
387
+
388
+ def tile_linear(
389
+ pretrained_linear: nn.Linear,
390
+ new_linear: nn.Linear,
391
+ linear_type: Union[str, TileLinear] = TileLinear.default,
392
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
393
+ pretrained_attn_size: Optional[int] = None,
394
+ pretrained_mlp_size: Optional[int] = None,
395
+ new_attn_size: Optional[int] = None,
396
+ new_mlp_size: Optional[int] = None,
397
+ wqkvff_is_glu: Optional[bool] = None,
398
+ bias_only: Optional[bool] = False,
399
+ ):
400
+ """
401
+ Tile the weights of a linear layer to a new, larger linear dimension.
402
+
403
+ Args:
404
+ pretrained_linear (nn.Linear): The original linear layer
405
+ new_linear (nn.Linear): The new linear layer with larger linear_dim
406
+ linear_type (Union[str, TileLinear]): The type of linear layer to tile
407
+ mode (Union[str, TileMode]): The tiling mode to use
408
+ pretrained_attn_size (int): The attention size of the pretrained linear layer. Only used if linear_type is wqkvff.
409
+ pretrained_mlp_size (int): The mlp size of the pretrained linear layer. Only used if linear_type is wqkvff.
410
+ new_attn_size (int): The attention size of the new linear layer. Only used if linear_type is wqkvff.
411
+ new_mlp_size (int): The mlp size of the new linear layer. Only used if linear_type is wqkvff.
412
+ wqkvff_is_glu (bool): Whether the wqkvff layer is a GLU layer. Only used if linear_type is wqkvff.
413
+ bias_only (bool): Whether to only tile the bias. Only used if tiling weight tied decoder.
414
+ """
415
+ if isinstance(linear_type, str):
416
+ linear_type = TileLinear(linear_type)
417
+ if isinstance(mode, str):
418
+ mode = TileMode(mode)
419
+
420
+ with torch.no_grad():
421
+ if linear_type == TileLinear.wqkv:
422
+ if not bias_only:
423
+ new_linear.weight = nn.Parameter(
424
+ tile_fused_qkv(pretrained_linear.weight, new_linear.weight, mode=mode),
425
+ requires_grad=new_linear.weight.requires_grad,
426
+ )
427
+ if pretrained_linear.bias is not None:
428
+ new_linear.bias = nn.Parameter(
429
+ tile_fused_qkv(pretrained_linear.bias, new_linear.bias, mode=mode),
430
+ requires_grad=new_linear.bias.requires_grad,
431
+ )
432
+ elif linear_type == TileLinear.glu:
433
+ if not bias_only:
434
+ new_linear.weight = nn.Parameter(
435
+ tile_fused_glu(pretrained_linear.weight, new_linear.weight, mode=mode),
436
+ requires_grad=new_linear.weight.requires_grad,
437
+ )
438
+ if pretrained_linear.bias is not None:
439
+ new_linear.bias = nn.Parameter(
440
+ tile_fused_glu(pretrained_linear.bias, new_linear.bias, mode=mode),
441
+ requires_grad=new_linear.bias.requires_grad,
442
+ )
443
+ elif linear_type == TileLinear.wqkvff:
444
+ if not bias_only:
445
+ new_linear.weight = nn.Parameter(
446
+ tile_fused_qkvff(
447
+ pretrained_linear.weight,
448
+ new_linear.weight,
449
+ pretrained_attn_size,
450
+ pretrained_mlp_size,
451
+ new_attn_size,
452
+ new_mlp_size,
453
+ wqkvff_is_glu,
454
+ mode=mode,
455
+ ),
456
+ requires_grad=new_linear.weight.requires_grad,
457
+ )
458
+ if pretrained_linear.bias is not None:
459
+ new_linear.bias = nn.Parameter(
460
+ tile_fused_qkvff(
461
+ pretrained_linear.bias,
462
+ new_linear.bias,
463
+ pretrained_attn_size,
464
+ pretrained_mlp_size,
465
+ new_attn_size,
466
+ new_mlp_size,
467
+ wqkvff_is_glu,
468
+ mode=mode,
469
+ ),
470
+ requires_grad=new_linear.bias.requires_grad,
471
+ )
472
+ else:
473
+ if not bias_only:
474
+ new_linear.weight = nn.Parameter(
475
+ tile_weight(pretrained_linear.weight, new_linear.weight, mode=mode),
476
+ requires_grad=new_linear.weight.requires_grad,
477
+ )
478
+ if pretrained_linear.bias is not None:
479
+ new_linear.bias = nn.Parameter(
480
+ tile_weight(pretrained_linear.bias, new_linear.bias, mode=mode),
481
+ requires_grad=new_linear.bias.requires_grad,
482
+ )
483
+
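In practice this is the entry point used when upcycling a smaller checkpoint: the pretrained and freshly initialized `nn.Linear` modules are passed in, and the fused layer types are routed to the chunk-aware tilers above. A hypothetical sketch for a fused QKV projection (bias omitted so only the code paths shown above are exercised):

```python
# Hypothetical sketch: growing a fused Wqkv projection from hidden=768 to hidden=1024.
import torch.nn as nn
from initialization import TileLinear, TileMode, tile_linear

pretrained_wqkv = nn.Linear(768, 3 * 768, bias=False)
new_wqkv = nn.Linear(1024, 3 * 1024, bias=False)

tile_linear(
    pretrained_wqkv,
    new_wqkv,
    linear_type=TileLinear.wqkv,
    mode=TileMode.tile_weights_from_middle,
)
# new_wqkv.weight now contains the pretrained Q, K, and V blocks tiled into the larger projection.
```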
484
+
485
+ def tile_norm(
486
+ pretrained_norm: Union[nn.LayerNorm, RMSNorm, nn.Identity],
487
+ new_norm: Union[nn.LayerNorm, RMSNorm, nn.Identity],
488
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
489
+ ):
490
+ """
491
+ Tile the weights of a pretrained norm layer to a new, larger layer norm dimension.
492
+
493
+ Args:
494
+ pretrained_norm (Union[nn.LayerNorm, RMSNorm, nn.Identity]): The original norm layer
495
+ new_norm (Union[nn.LayerNorm, RMSNorm, nn.Identity]): The new norm layer with larger layer norm dimension
496
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
497
+ """
498
+ if isinstance(pretrained_norm, nn.Identity):
499
+ return
500
+ if isinstance(mode, str):
501
+ mode = TileMode(mode)
502
+
503
+ with torch.no_grad():
504
+ new_norm.weight.data = nn.Parameter(
505
+ tile_weight(pretrained_norm.weight, new_norm.weight, mode=mode),
506
+ requires_grad=new_norm.weight.requires_grad,
507
+ )
508
+ if hasattr(pretrained_norm, "bias") and pretrained_norm.bias is not None:
509
+ new_norm.bias.data = nn.Parameter(
510
+ tile_weight(pretrained_norm.bias, new_norm.bias, mode=mode),
511
+ requires_grad=new_norm.bias.requires_grad,
512
+ )
513
+
514
+
515
+ def tile_embedding(
516
+ pretrained_embedding: nn.Embedding,
517
+ new_embedding: nn.Embedding,
518
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
519
+ ) -> nn.Embedding:
520
+ """
521
+ Tile the weights of an embedding layer to a new, larger embedding dimension.
522
+
523
+ Args:
524
+ pretrained_embedding (nn.Embedding): The original embedding layer
525
+ new_embedding (nn.Embedding): The new embedding layer with larger embedding_dim
526
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
527
+
528
+ Returns:
529
+ nn.Embedding: The new embedding layer with tiled weights
530
+ """
531
+ with torch.no_grad():
532
+ # Ensure vocabulary size remains the same
533
+ if pretrained_embedding.num_embeddings != new_embedding.num_embeddings:
534
+ raise ValueError("Vocabulary size (num_embeddings) must remain constant")
535
+
536
+ # Ensure new embedding dimension is larger
537
+ if new_embedding.embedding_dim <= pretrained_embedding.embedding_dim:
538
+ raise ValueError("New embedding_dim must be larger than the old embedding_dim")
539
+
540
+ # Tile the weights
541
+ new_embedding.weight.data = nn.Parameter(
542
+ tile_weight(pretrained_embedding.weight, new_embedding.weight, mode=mode),
543
+ requires_grad=new_embedding.weight.requires_grad,
544
+ )
545
+
546
+ # Handle padding_idx if it exists
547
+ if pretrained_embedding.padding_idx is not None:
548
+ if new_embedding.padding_idx is None:
549
+ new_embedding.padding_idx = pretrained_embedding.padding_idx
550
+ else:
551
+ assert new_embedding.padding_idx == pretrained_embedding.padding_idx, "padding_idx must remain the same"
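A hypothetical sketch of widening the token embedding with the helper above; the vocabulary stays fixed while the embedding dimension grows, and `padding_idx` is carried over:

```python
# Hypothetical sketch: widen the embedding dim from 768 to 1024 while keeping the vocabulary fixed.
import torch.nn as nn
from initialization import TileMode, tile_embedding

pretrained_emb = nn.Embedding(30522, 768, padding_idx=0)
new_emb = nn.Embedding(30522, 1024, padding_idx=0)

tile_embedding(pretrained_emb, new_emb, mode=TileMode.tile_weights_from_middle)
assert new_emb.weight.shape == (30522, 1024)
assert new_emb.padding_idx == pretrained_emb.padding_idx
```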
layers.py ADDED
@@ -0,0 +1,700 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2022 MosaicML Examples authors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # Copyright 2023 MosaicML Examples authors
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
11
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+
15
+ import copy
16
+ import math
17
+ import warnings
18
+ from typing import Optional, Union, List
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ import bert_padding
24
+
25
+ from activation import get_act_fn
26
+ from attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer
27
+ from mlp import FlexBertMLPBase, BertResidualGLU, get_mlp_layer
28
+ from configuration_bert import FlexBertConfig, maybe_add_padding
29
+ from normalization import get_norm_layer
30
+ from initialization import ModuleType, init_weights
31
+
32
+
33
+ class BertAlibiLayer(nn.Module):
34
+ """Composes the Mosaic BERT attention and FFN blocks into a single layer."""
35
+
36
+ def __init__(self, config):
37
+ super().__init__()
38
+ self.attention = BertAlibiUnpadAttention(config)
39
+ self.mlp = BertResidualGLU(config)
40
+
41
+ def forward(
42
+ self,
43
+ hidden_states: torch.Tensor,
44
+ cu_seqlens: torch.Tensor,
45
+ seqlen: int,
46
+ subset_idx: Optional[torch.Tensor] = None,
47
+ indices: Optional[torch.Tensor] = None,
48
+ attn_mask: Optional[torch.Tensor] = None,
49
+ bias: Optional[torch.Tensor] = None,
50
+ slopes: Optional[torch.Tensor] = None,
51
+ ) -> torch.Tensor:
52
+ """Forward pass for a BERT layer, including both attention and MLP.
53
+
54
+ Args:
55
+ hidden_states: (total_nnz, dim)
56
+ cu_seqlens: (batch + 1,)
57
+ seqlen: int
58
+ subset_idx: () set of indices whose values we care about at the end of the layer
59
+ (e.g., the masked tokens, if this is the final layer).
60
+ indices: None or (total_nnz,)
61
+ attn_mask: None or (batch, max_seqlen_in_batch)
62
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
63
+ slopes: None or (batch, heads) or (heads,)
64
+ """
65
+ assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
66
+ attention_output = self.attention(
67
+ hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias, slopes
68
+ )
69
+ layer_output = self.mlp(attention_output)
70
+ return layer_output
71
+
72
+
73
+ class BertAlibiEncoder(nn.Module):
74
+ """A stack of BERT layers providing the backbone of Mosaic BERT.
75
+
76
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertAlibiEncoder`,
77
+ but with substantial modifications to implement unpadding and ALiBi.
78
+
79
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
80
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
81
+ """
82
+
83
+ def __init__(self, config):
84
+ super().__init__()
85
+ layer = BertAlibiLayer(config)
86
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
87
+
88
+ self.num_attention_heads = config.num_attention_heads
89
+
90
+ # The alibi mask will be dynamically expanded if it is too small for
91
+ # the input the model receives. But it generally helps to initialize it
92
+ # to a reasonably large size to help pre-allocate CUDA memory.
93
+ # The default `alibi_starting_size` is 512.
94
+ self._current_alibi_size = int(config.alibi_starting_size)
95
+ self.alibi = torch.zeros((1, self.num_attention_heads, self._current_alibi_size, self._current_alibi_size))
96
+ self.rebuild_alibi_tensor(size=config.alibi_starting_size)
97
+
98
+ def rebuild_alibi_tensor(self, size: int, device: Optional[Union[torch.device, str]] = None):
99
+ # Alibi
100
+ # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
101
+ # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
102
+ # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
103
+ # will be applied, it is necessary to construct the diagonal mask.
104
+ n_heads = self.num_attention_heads
105
+
106
+ def _get_alibi_head_slopes(n_heads: int) -> List[float]:
107
+ def get_slopes_power_of_2(n_heads: int) -> List[float]:
108
+ start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
109
+ ratio = start
110
+ return [start * ratio**i for i in range(n_heads)]
111
+
112
+ # In the paper, they only train models that have 2^a heads for some a. This function
113
+ # has some good properties that only occur when the input is a power of 2. To
114
+ # maintain that even when the number of heads is not a power of 2, we use a
115
+ # workaround.
116
+ if math.log2(n_heads).is_integer():
117
+ return get_slopes_power_of_2(n_heads)
118
+
119
+ closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
120
+ slopes_a = get_slopes_power_of_2(closest_power_of_2)
121
+ slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
122
+ slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2]
123
+ return slopes_a + slopes_b
124
+
125
+ context_position = torch.arange(size, device=device)[:, None]
126
+ memory_position = torch.arange(size, device=device)[None, :]
127
+ relative_position = torch.abs(memory_position - context_position)
128
+ # [n_heads, max_token_length, max_token_length]
129
+ relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1)
130
+ slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
131
+ self.slopes = slopes
132
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
133
+ # [1, n_heads, max_token_length, max_token_length]
134
+ alibi = alibi.unsqueeze(0)
135
+ assert alibi.shape == torch.Size([1, n_heads, size, size])
136
+
137
+ self._current_alibi_size = size
138
+ self.alibi = alibi
139
+
140
+ def forward(
141
+ self,
142
+ hidden_states: torch.Tensor,
143
+ attention_mask: torch.Tensor,
144
+ output_all_encoded_layers: Optional[bool] = True,
145
+ subset_mask: Optional[torch.Tensor] = None,
146
+ ) -> List[torch.Tensor]:
147
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
148
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
149
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
150
+
151
+ attention_mask_bool = attention_mask.bool()
152
+ batch, seqlen = hidden_states.shape[:2]
153
+ # Unpad inputs and mask. It will remove tokens that are padded.
154
+ # Assume ntokens is total number of tokens (padded and non-padded)
155
+ # and ntokens_unpad is total number of non-padded tokens.
156
+ # Then unpadding performs the following compression of the inputs:
157
+ # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
158
+ hidden_states, indices, cu_seqlens, _ = bert_padding.unpad_input(hidden_states, attention_mask_bool)
159
+
160
+ # Add alibi matrix to extended_attention_mask
161
+ if self._current_alibi_size < seqlen:
162
+ # Rebuild the alibi tensor when needed
163
+ warnings.warn(f"Increasing alibi size from {self._current_alibi_size} to {seqlen}")
164
+ self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
165
+ elif self.alibi.device != hidden_states.device:
166
+ # Device catch-up
167
+ self.alibi = self.alibi.to(hidden_states.device)
168
+ self.slopes = self.slopes.to(hidden_states.device) # type: ignore
169
+ alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
170
+ attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
171
+ alibi_attn_mask = attn_bias + alibi_bias
172
+
173
+ all_encoder_layers = []
174
+ if subset_mask is None:
175
+ for layer_module in self.layer:
176
+ hidden_states = layer_module(
177
+ hidden_states,
178
+ cu_seqlens,
179
+ seqlen,
180
+ None,
181
+ indices,
182
+ attn_mask=attention_mask,
183
+ bias=alibi_attn_mask,
184
+ slopes=self.slopes,
185
+ )
186
+ if output_all_encoded_layers:
187
+ all_encoder_layers.append(hidden_states)
188
+ # Pad inputs and mask. It will insert back zero-padded tokens.
189
+ # Assume ntokens is total number of tokens (padded and non-padded)
190
+ # and ntokens_unpad is total number of non-padded tokens.
191
+ # Then padding performs the following de-compression:
192
+ # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
193
+ hidden_states = bert_padding.pad_input(hidden_states, indices, batch, seqlen)
194
+ else:
195
+ for i in range(len(self.layer) - 1):
196
+ layer_module = self.layer[i]
197
+ hidden_states = layer_module(
198
+ hidden_states,
199
+ cu_seqlens,
200
+ seqlen,
201
+ None,
202
+ indices,
203
+ attn_mask=attention_mask,
204
+ bias=alibi_attn_mask,
205
+ slopes=self.slopes,
206
+ )
207
+ if output_all_encoded_layers:
208
+ all_encoder_layers.append(hidden_states)
209
+ subset_idx = torch.nonzero(subset_mask[attention_mask_bool], as_tuple=False).flatten()
210
+ hidden_states = self.layer[-1](
211
+ hidden_states,
212
+ cu_seqlens,
213
+ seqlen,
214
+ subset_idx=subset_idx,
215
+ indices=indices,
216
+ attn_mask=attention_mask,
217
+ bias=alibi_attn_mask,
218
+ slopes=self.slopes,
219
+ )
220
+
221
+ if not output_all_encoded_layers:
222
+ all_encoder_layers.append(hidden_states)
223
+ return all_encoder_layers
224
+
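The slope recursion in `rebuild_alibi_tensor` reproduces the geometric sequence from the ALiBi paper; for a power-of-two head count the slopes are simply successive powers of the starting value. A quick standalone check of that formula:

```python
# Standalone check of the ALiBi slope formula used in rebuild_alibi_tensor above.
import math
from typing import List


def alibi_slopes_power_of_2(n_heads: int) -> List[float]:
    start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
    return [start * start**i for i in range(n_heads)]


print(alibi_slopes_power_of_2(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```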
225
+
226
+ class BertPooler(nn.Module):
227
+ def __init__(self, config):
228
+ super(BertPooler, self).__init__()
229
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
230
+ self.activation = nn.Tanh()
231
+
232
+ def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
233
+ # We "pool" the model by simply taking the hidden state corresponding
234
+ # to the first token.
235
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
236
+ pooled_output = self.dense(first_token_tensor)
237
+ pooled_output = self.activation(pooled_output)
238
+ return pooled_output
239
+
240
+
241
+ class BertPredictionHeadTransform(nn.Module):
242
+ def __init__(self, config):
243
+ super().__init__()
244
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
245
+ if isinstance(config.hidden_act, str):
246
+ self.transform_act_fn = get_act_fn(config.head_pred_act)
247
+ else:
248
+ self.transform_act_fn = config.hidden_act
249
+ self.LayerNorm = get_norm_layer(config)
250
+
251
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
252
+ hidden_states = self.dense(hidden_states)
253
+ hidden_states = self.transform_act_fn(hidden_states)
254
+ hidden_states = self.LayerNorm(hidden_states)
255
+ return hidden_states
256
+
257
+
258
+ class FlexBertLayerBase(nn.Module):
259
+ """A FlexBERT Layer base class for type hints."""
260
+
261
+ attn: FlexBertAttentionBase
262
+ mlp: FlexBertMLPBase
263
+
264
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
265
+ super().__init__()
266
+ self.config = config
267
+ self.layer_id = layer_id
268
+
269
+ def _init_weights(self, reset_params: bool = False):
270
+ if hasattr(self, "attn"):
271
+ self.attn._init_weights(reset_params)
272
+ if hasattr(self, "mlp"):
273
+ self.mlp._init_weights(reset_params)
274
+
275
+ def reset_parameters(self):
276
+ self._init_weights(reset_params=True)
277
+
278
+ def forward(self, hidden_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
279
+ raise NotImplementedError("This is a base class and should not be used directly.")
280
+
281
+
282
+ class FlexBertCompileUnpadPreNormLayer(FlexBertLayerBase):
283
+ """Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization."""
284
+
285
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
286
+ super().__init__(config=config, layer_id=layer_id)
287
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
288
+ self.attn_norm = nn.Identity()
289
+ else:
290
+ self.attn_norm = get_norm_layer(config)
291
+ self.attn = get_attention_layer(config, layer_id=layer_id)
292
+ self.mlp_norm = get_norm_layer(config, compiled_norm=config.compile_model)
293
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
294
+ self.compile_model = config.compile_model
295
+
296
+ def _init_weights(self, reset_params: bool = False):
297
+ super()._init_weights(reset_params)
298
+ if reset_params:
299
+ self.attn_norm.reset_parameters()
300
+ self.mlp_norm.reset_parameters()
301
+
302
+ @torch.compile(dynamic=True)
303
+ def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
304
+ return self.mlp(self.mlp_norm(hidden_states))
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ cu_seqlens: torch.Tensor,
310
+ max_seqlen: int,
311
+ indices: Optional[torch.Tensor] = None,
312
+ attn_mask: Optional[torch.Tensor] = None,
313
+ ) -> torch.Tensor:
314
+ """Forward pass for a BERT layer, including both attention and MLP.
315
+
316
+ Args:
317
+ hidden_states: (total_nnz, dim)
318
+ cu_seqlens: (batch + 1,)
319
+ max_seqlen: int
320
+ indices: None or (total_nnz,)
321
+ attn_mask: None or (batch, max_seqlen)
322
+ """
323
+ attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), cu_seqlens, max_seqlen, indices, attn_mask)
324
+ return attn_out + self.compiled_mlp(attn_out)
325
+
326
+
327
+ class FlexBertUnpadPreNormLayer(FlexBertLayerBase):
328
+ """Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization."""
329
+
330
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
331
+ super().__init__(config=config, layer_id=layer_id)
332
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
333
+ self.attn_norm = nn.Identity()
334
+ else:
335
+ self.attn_norm = get_norm_layer(config)
336
+ self.attn = get_attention_layer(config, layer_id=layer_id)
337
+ self.mlp_norm = get_norm_layer(config)
338
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
339
+
340
+ def _init_weights(self, reset_params: bool = False):
341
+ super()._init_weights(reset_params)
342
+ if reset_params:
343
+ self.attn_norm.reset_parameters()
344
+ self.mlp_norm.reset_parameters()
345
+
346
+ def forward(
347
+ self,
348
+ hidden_states: torch.Tensor,
349
+ cu_seqlens: torch.Tensor,
350
+ max_seqlen: int,
351
+ indices: Optional[torch.Tensor] = None,
352
+ attn_mask: Optional[torch.Tensor] = None,
353
+ ) -> torch.Tensor:
354
+ """Forward pass for a BERT layer, including both attention and MLP.
355
+
356
+ Args:
357
+ hidden_states: (total_nnz, dim)
358
+ cu_seqlens: (batch + 1,)
359
+ max_seqlen: int
360
+ indices: None or (total_nnz,)
361
+ attn_mask: None or (batch, max_seqlen)
362
+ """
363
+ attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), cu_seqlens, max_seqlen, indices, attn_mask)
364
+ return attn_out + self.mlp(self.mlp_norm(attn_out))
365
+
366
+
367
+ class FlexBertUnpadParallelPreNormLayer(FlexBertLayerBase):
368
+ """Composes the FlexBERT parallel attention and MLP blocks into a single layer using pre-normalization."""
369
+
370
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
371
+ super().__init__(config=config, layer_id=layer_id)
372
+ self.attn_size = config.hidden_size * 3
373
+ self.mlp_size = config.intermediate_size * 2
374
+ # Compute QKV and FF outputs at once
375
+ self.Wqkvff = nn.Linear(config.hidden_size, self.attn_size + self.mlp_size, bias=config.attn_qkv_bias)
376
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
377
+ self.norm = nn.Identity()
378
+ else:
379
+ self.norm = get_norm_layer(config)
380
+ self.attn = get_attention_layer(config, layer_id=layer_id)
381
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
382
+
383
+ def _init_weights(self, reset_params: bool = False):
384
+ super()._init_weights(reset_params)
385
+ if reset_params and hasattr(self.norm, "reset_parameters"):
386
+ self.norm.reset_parameters()
387
+
388
+ init_weights(
389
+ self.config,
390
+ self.Wqkvff,
391
+ layer_dim=self.config.hidden_size,
392
+ layer_id=None,
393
+ type_of_module=ModuleType.in_module,
394
+ )
395
+
396
+ def forward(
397
+ self,
398
+ hidden_states: torch.Tensor,
399
+ cu_seqlens: torch.Tensor,
400
+ max_seqlen: int,
401
+ indices: Optional[torch.Tensor] = None,
402
+ attn_mask: Optional[torch.Tensor] = None,
403
+ ) -> torch.Tensor:
404
+ """Forward pass for a BERT layer, including both attention and MLP.
405
+
406
+ Args:
407
+ hidden_states: (total_nnz, dim)
408
+ attn_mask: None or (batch, max_seqlen)
409
+ """
410
+ # Compute QKV and FF outputs at once and split them
411
+ qkv, intermediate_ff = self.Wqkvff(self.norm(hidden_states)).split([self.attn_size, self.mlp_size], dim=1)
412
+ return hidden_states + self.attn(qkv, cu_seqlens, max_seqlen, indices, attn_mask) + self.mlp(intermediate_ff)
413
+
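The parallel layer folds the attention and feed-forward input projections into one matrix multiply, then splits the result by column count. A shape sketch with illustrative sizes (unpadded inputs are `[total_nnz, hidden]`):

```python
# Shape sketch for the fused Wqkvff projection above; sizes are illustrative.
import torch
import torch.nn as nn

hidden_size, intermediate_size, total_nnz = 768, 1344, 4096
attn_size = hidden_size * 3          # fused Q, K, V
mlp_size = intermediate_size * 2     # GLU input and gate

Wqkvff = nn.Linear(hidden_size, attn_size + mlp_size, bias=False)
x = torch.randn(total_nnz, hidden_size)

qkv, intermediate_ff = Wqkvff(x).split([attn_size, mlp_size], dim=1)
print(qkv.shape, intermediate_ff.shape)  # torch.Size([4096, 2304]) torch.Size([4096, 2688])
```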
414
+
415
+ class FlexBertPaddedPreNormLayer(FlexBertLayerBase):
416
+ """Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization."""
417
+
418
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
419
+ super().__init__(config=config, layer_id=layer_id)
420
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
421
+ self.attn_norm = nn.Identity()
422
+ else:
423
+ self.attn_norm = get_norm_layer(config)
424
+ self.attn = get_attention_layer(config, layer_id=layer_id)
425
+ self.mlp_norm = get_norm_layer(config)
426
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
427
+
428
+ def _init_weights(self, reset_params: bool = False):
429
+ super()._init_weights(reset_params)
430
+ if reset_params:
431
+ self.attn_norm.reset_parameters()
432
+ self.mlp_norm.reset_parameters()
433
+
434
+ def forward(
435
+ self,
436
+ hidden_states: torch.Tensor,
437
+ attn_mask: Optional[torch.Tensor] = None,
438
+ ) -> torch.Tensor:
439
+ """Forward pass for a BERT layer, including both attention and MLP.
440
+
441
+ Args:
442
+ hidden_states: (batch, max_seqlen, dim)
443
+ attn_mask: None or (batch, max_seqlen)
444
+ """
445
+ attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), attn_mask)
446
+ return attn_out + self.mlp(self.mlp_norm(attn_out))
447
+
448
+
449
+ class FlexBertPaddedParallelPreNormLayer(FlexBertLayerBase):
450
+ """Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization."""
451
+
452
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
453
+ super().__init__(config=config, layer_id=layer_id)
454
+ self.attn_size = config.hidden_size * 3
455
+ self.mlp_size = config.intermediate_size * 2
456
+ # Compute QKV and FF outputs at once
457
+ self.Wqkvff = nn.Linear(config.hidden_size, self.attn_size + self.mlp_size, bias=config.attn_qkv_bias)
458
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
459
+ self.norm = nn.Identity()
460
+ else:
461
+ self.norm = get_norm_layer(config)
462
+ self.attn = get_attention_layer(config, layer_id=layer_id)
463
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
464
+
465
+ def _init_weights(self, reset_params: bool = False):
466
+ super()._init_weights(reset_params)
467
+ if reset_params:
468
+ self.norm.reset_parameters()
469
+
470
+ init_weights(
471
+ self.config,
472
+ self.Wqkvff,
473
+ layer_dim=self.config.hidden_size,
474
+ layer_id=None,
475
+ type_of_module=ModuleType.in_module,
476
+ )
477
+
478
+ def forward(
479
+ self,
480
+ hidden_states: torch.Tensor,
481
+ attn_mask: Optional[torch.Tensor] = None,
482
+ ) -> torch.Tensor:
483
+ """Forward pass for a BERT layer, including both attention and MLP.
484
+
485
+ Args:
486
+ hidden_states: (batch, max_seqlen, dim)
487
+ attn_mask: None or (batch, max_seqlen)
488
+ """
489
+ # Compute QKV and FF outputs at once and split them
490
+ qkv, intermediate_ff = self.Wqkvff(self.norm(hidden_states)).split([self.attn_size, self.mlp_size], dim=2)
491
+ return hidden_states + self.attn(qkv, attn_mask) + self.mlp(intermediate_ff)
492
+
493
+
494
+ class FlexBertUnpadPostNormLayer(FlexBertLayerBase):
495
+ """Composes the FlexBERT attention and MLP blocks into a single layer using post-normalization."""
496
+
497
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
498
+ super().__init__(config=config, layer_id=layer_id)
499
+ self.attn = get_attention_layer(config, layer_id=layer_id)
500
+ self.attn_norm = get_norm_layer(config)
501
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
502
+ self.mlp_norm = get_norm_layer(config)
503
+
504
+ def _init_weights(self, reset_params: bool = False):
505
+ super()._init_weights(reset_params)
506
+ if reset_params:
507
+ self.attn_norm.reset_parameters()
508
+ self.mlp_norm.reset_parameters()
509
+
510
+ def forward(
511
+ self,
512
+ hidden_states: torch.Tensor,
513
+ cu_seqlens: torch.Tensor,
514
+ max_seqlen: int,
515
+ indices: Optional[torch.Tensor] = None,
516
+ attn_mask: Optional[torch.Tensor] = None,
517
+ ) -> torch.Tensor:
518
+ """Forward pass for a BERT layer, including both attention and MLP.
519
+
520
+ Args:
521
+ hidden_states: (total_nnz, dim)
522
+ cu_seqlens: (batch + 1,)
523
+ max_seqlen: int
524
+ indices: None or (total_nnz,)
525
+ attn_mask: None or (batch, max_seqlen)
526
+ """
527
+ attn_out = self.attn_norm(hidden_states + self.attn(hidden_states, cu_seqlens, max_seqlen, indices, attn_mask))
528
+ return self.mlp_norm(attn_out + self.mlp(attn_out))
529
+
530
+
531
+ class FlexBertPaddedPostNormLayer(FlexBertLayerBase):
532
+ """Composes the FlexBERT attention and MLP blocks into a single layer using post-normalization."""
533
+
534
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
535
+ super().__init__(config=config, layer_id=layer_id)
536
+ self.attn = get_attention_layer(config, layer_id=layer_id)
537
+ self.attn_norm = get_norm_layer(config)
538
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
539
+ self.mlp_norm = get_norm_layer(config)
540
+
541
+ def _init_weights(self, reset_params: bool = False):
542
+ super()._init_weights(reset_params)
543
+ if reset_params:
544
+ self.mlp_norm.reset_parameters()
545
+
546
+ def forward(
547
+ self,
548
+ hidden_states: torch.Tensor,
549
+ attn_mask: Optional[torch.Tensor] = None,
550
+ ) -> torch.Tensor:
551
+ """Forward pass for a BERT layer, including both attention and MLP.
552
+
553
+ Args:
554
+ hidden_states: (batch, max_seqlen, dim)
555
+ attn_mask: None or (batch, max_seqlen)
556
+ """
557
+ attn_out = self.attn_norm(hidden_states + self.attn(hidden_states, attn_mask))
558
+ return self.mlp_norm(attn_out + self.mlp(attn_out))
559
+
560
+
561
+ LAYER2CLS = {
562
+ "unpadded_prenorm": FlexBertUnpadPreNormLayer,
563
+ "unpadded_compile_prenorm": FlexBertCompileUnpadPreNormLayer,
564
+ "unpadded_parallel_prenorm": FlexBertUnpadParallelPreNormLayer,
565
+ "unpadded_postnorm": FlexBertUnpadPostNormLayer,
566
+ "padded_prenorm": FlexBertPaddedPreNormLayer,
567
+ "padded_parallel_prenorm": FlexBertPaddedParallelPreNormLayer,
568
+ "padded_postnorm": FlexBertPaddedPostNormLayer,
569
+ }
570
+
571
+
572
+ def get_bert_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertLayerBase:
573
+ try:
574
+ bert_layer = (
575
+ config.initial_bert_layer
576
+ if layer_id < config.num_initial_layers and getattr(config, "initial_bert_layer", None) is not None
577
+ else config.bert_layer
578
+ )
579
+ bert_layer = maybe_add_padding(config, bert_layer)
580
+ if config.compile_model and bert_layer == "unpadded_prenorm":
581
+ bert_layer = "unpadded_compile_prenorm"
582
+ return LAYER2CLS[bert_layer](config, layer_id=layer_id)
583
+ except KeyError:
584
+ if layer_id < config.num_initial_layers and getattr(config, "initial_bert_layer", None) is not None:
585
+ raise ValueError(
586
+ f"Invalid BERT layer type: {config.initial_bert_layer=}, must be one of {LAYER2CLS.keys()}."
587
+ f"{config.padding=} will be automatically prepended to `config.bert_layer` if unspecified."
588
+ )
589
+ else:
590
+ raise ValueError(
591
+ f"Invalid BERT layer type: {config.bert_layer=}, must be one of {LAYER2CLS.keys()}. "
592
+ f"{config.padding=} will be automatically prepended to `config.bert_layer` if unspecified."
593
+ )
594
+
595
+
596
+ class FlexBertEncoderBase(nn.Module):
597
+ """A FlexBERT base class for type hints."""
598
+
599
+ layers: nn.ModuleList
600
+
601
+ def _init_weights(self, reset_params: bool = False):
602
+ if hasattr(self, "layers"):
603
+ for layer in self.layers:
604
+ layer._init_weights(reset_params=reset_params)
605
+
606
+ def reset_parameters(self):
607
+ self._init_weights(reset_params=True)
608
+
609
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
610
+ raise NotImplementedError("This is a base class and should not be used directly.")
611
+
612
+
613
+ class FlexBertUnpadEncoder(FlexBertEncoderBase):
614
+ """A stack of BERT layers providing the backbone of FlexBERT.
615
+
616
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertAlibiEncoder`,
617
+ but with substantial modifications to implement unpadding and ALiBi.
618
+
619
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
620
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
621
+ """
622
+
623
+ def __init__(self, config: FlexBertConfig):
624
+ super().__init__()
625
+ self.layers = nn.ModuleList([get_bert_layer(config, layer_id=i) for i in range(config.num_hidden_layers)])
626
+ self.num_attention_heads = config.num_attention_heads
627
+
628
+ def forward(
629
+ self,
630
+ hidden_states: torch.Tensor,
631
+ attention_mask: torch.Tensor,
632
+ indices: Optional[torch.Tensor] = None,
633
+ cu_seqlens: Optional[torch.Tensor] = None,
634
+ max_seqlen: Optional[int] = None,
635
+ ) -> torch.Tensor:
636
+ if indices is None and cu_seqlens is None and max_seqlen is None:
637
+ attention_mask_bool = attention_mask.bool()
638
+ batch, seqlen = hidden_states.shape[:2]
639
+ hidden_states, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(
640
+ hidden_states, attention_mask_bool
641
+ )
642
+
643
+ for layer_module in self.layers:
644
+ hidden_states = layer_module(
645
+ hidden_states,
646
+ cu_seqlens,
647
+ max_seqlen,
648
+ indices,
649
+ attn_mask=attention_mask,
650
+ )
651
+
652
+ return bert_padding.pad_input(hidden_states, indices, batch, seqlen)
653
+ else:
654
+ for layer_module in self.layers:
655
+ hidden_states = layer_module(
656
+ hidden_states,
657
+ cu_seqlens,
658
+ max_seqlen,
659
+ indices,
660
+ attn_mask=attention_mask,
661
+ )
662
+ return hidden_states
663
+
664
+
665
+ class FlexBertPaddedEncoder(FlexBertEncoderBase):
666
+ """A stack of BERT layers providing the backbone of FlexBERT.
667
+
668
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertAlibiEncoder`,
669
+ but with substantial modifications to implement unpadding and ALiBi.
670
+
671
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
672
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
673
+ """
674
+
675
+ def __init__(self, config: FlexBertConfig):
676
+ super().__init__()
677
+ self.layers = nn.ModuleList([get_bert_layer(config, layer_id=i) for i in range(config.num_hidden_layers)])
678
+ self.num_attention_heads = config.num_attention_heads
679
+
680
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> torch.Tensor:
681
+ for layer_module in self.layers:
682
+ hidden_states = layer_module(hidden_states, attn_mask=attention_mask)
683
+
684
+ return hidden_states
685
+
686
+
687
+ ENC2CLS = {
688
+ "unpadded_base": FlexBertUnpadEncoder,
689
+ "padded_base": FlexBertPaddedEncoder,
690
+ }
691
+
692
+
693
+ def get_encoder_layer(config: FlexBertConfig) -> FlexBertEncoderBase:
694
+ try:
695
+ return ENC2CLS[maybe_add_padding(config, config.encoder_layer)](config)
696
+ except KeyError:
697
+ raise ValueError(
698
+ f"Invalid encoder layer type: {config.encoder_layer=}, must be one of {ENC2CLS.keys()}. "
699
+ f"{config.padding=} will be automatically prepended to `config.encoder_layer` if unspecified."
700
+ )
mlp.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2022 MosaicML Examples authors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # Copyright 2023 MosaicML Examples authors
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
11
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+ from typing import Optional
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from configuration_bert import FlexBertConfig
20
+ from activation import get_act_fn
21
+ from normalization import get_norm_layer
22
+ from initialization import ModuleType, init_weights
23
+
24
+
25
+ class BertResidualGLU(nn.Module):
26
+ """Applies the FFN at the end of each Mosaic BERT layer.
27
+
28
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
29
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
30
+ introduces Gated Linear Units.
31
+
32
+ Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
33
+ standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
34
+ `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
35
+ with `config.intermediate_size=3072`.
36
+ However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
37
+ parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ config,
43
+ ):
44
+ super().__init__()
45
+ self.config = config
46
+ self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
47
+ self.act = get_act_fn(config.hidden_act)
48
+ self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
49
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
50
+ self.layernorm = get_norm_layer(config)
51
+
52
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
53
+ """Compute new hidden states from current hidden states.
54
+
55
+ Args:
56
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
57
+ the attention layer [nnz, dim].
58
+ """
59
+ residual_connection = hidden_states
60
+ # compute the activation
61
+ hidden_states = self.gated_layers(hidden_states)
62
+ gated = hidden_states[:, : self.config.intermediate_size]
63
+ non_gated = hidden_states[:, self.config.intermediate_size :]
64
+ hidden_states = self.act(gated) * non_gated
65
+ hidden_states = self.dropout(hidden_states)
66
+ # multiply by the second matrix
67
+ hidden_states = self.wo(hidden_states)
68
+ # add the residual connection and post-LN
69
+ hidden_states = self.layernorm(hidden_states + residual_connection)
70
+ return hidden_states
71
+
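The gating above slices the doubled projection in half along the feature dimension: the first half goes through the activation, the second half multiplies it elementwise. A standalone sketch of that arithmetic with illustrative sizes:

```python
# Standalone sketch of the GLU gating in BertResidualGLU.forward; sizes are illustrative.
import torch
import torch.nn as nn

hidden_size, intermediate_size, total_nnz = 768, 3072, 4096
gated_layers = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
act = nn.GELU()

hidden_states = torch.randn(total_nnz, hidden_size)
projected = gated_layers(hidden_states)
gated = projected[:, :intermediate_size]      # passed through the activation
non_gated = projected[:, intermediate_size:]  # used as the elementwise gate
out = act(gated) * non_gated
print(out.shape)  # torch.Size([4096, 3072])
```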
72
+
73
+ class FlexBertMLPBase(nn.Module):
74
+ """A FlexBERT MLP base class for type hints."""
75
+
76
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
77
+ super().__init__()
78
+ self.config = config
79
+ self.layer_id = layer_id
80
+
81
+ def _init_weights(self, reset_params: bool = False):
82
+ raise NotImplementedError("This is a base class and should not be used directly.")
83
+
84
+ def reset_parameters(self):
85
+ self._init_weights(reset_params=True)
86
+
87
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
88
+ raise NotImplementedError("This is a base class and should not be used directly.")
89
+
90
+
91
+ class FlexBertMLP(FlexBertMLPBase):
92
+ """Applies the MLP at the end of each FlexBERT layer.
93
+
94
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
95
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
96
+ """
97
+
98
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
99
+ super().__init__(config=config, layer_id=layer_id)
100
+ self.Wi = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_in_bias)
101
+ self.act = get_act_fn(config.hidden_act)
102
+ self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
103
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)
104
+
105
+ def _init_weights(self, reset_params: bool = False):
106
+ init_weights(
107
+ self.config,
108
+ self.Wi,
109
+ layer_dim=self.config.hidden_size,
110
+ layer_id=None,
111
+ type_of_module=ModuleType.in_module,
112
+ )
113
+ init_weights(
114
+ self.config,
115
+ self.Wo,
116
+ layer_dim=self.config.intermediate_size,
117
+ layer_id=self.layer_id,
118
+ type_of_module=ModuleType.out_module,
119
+ )
120
+
121
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
122
+ """Compute new hidden states from current hidden states.
123
+
124
+ Args:
125
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
126
+ the attention layer [nnz, dim].
127
+ """
128
+ return self.Wo(self.drop(self.act(self.Wi(hidden_states))))
129
+
130
+
131
+ class FlexBertGLU(FlexBertMLPBase):
132
+ """Applies the GLU at the end of each FlexBERT layer.
133
+
134
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
135
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
136
+ """
137
+
138
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
139
+ super().__init__(config=config, layer_id=layer_id)
140
+ self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_in_bias)
141
+ self.act = get_act_fn(config.hidden_act)
142
+ self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
143
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)
144
+
145
+ def _init_weights(self, reset_params: bool = False):
146
+ init_weights(
147
+ self.config,
148
+ self.Wi,
149
+ layer_dim=self.config.hidden_size,
150
+ layer_id=None,
151
+ type_of_module=ModuleType.in_module,
152
+ )
153
+ init_weights(
154
+ self.config,
155
+ self.Wo,
156
+ layer_dim=self.config.intermediate_size,
157
+ layer_id=self.layer_id,
158
+ type_of_module=ModuleType.out_module,
159
+ )
160
+
161
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
162
+ input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
163
+ return self.Wo(self.drop(self.act(input) * gate))
164
+
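Unlike the explicit slicing in `BertResidualGLU`, the `chunk(2, dim=-1)` split above is layout-agnostic, so the same module works on both unpadded `[total_nnz, dim]` and padded `[batch, seqlen, dim]` activations. A short illustration:

```python
# The chunk-based gating works for both unpadded and padded activation layouts.
import torch
import torch.nn.functional as F

intermediate_size = 1152
unpadded = torch.randn(640, 2 * intermediate_size)   # [total_nnz, 2 * intermediate]
padded = torch.randn(8, 128, 2 * intermediate_size)  # [batch, seqlen, 2 * intermediate]

for t in (unpadded, padded):
    inp, gate = t.chunk(2, dim=-1)
    print((F.gelu(inp) * gate).shape)
# torch.Size([640, 1152])
# torch.Size([8, 128, 1152])
```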
165
+
166
+ class FlexBertParallelGLU(FlexBertMLPBase):
167
+ """Applies the GLU at the end of each FlexBERT layer using intermediate_ff computed in parallel of the attention.
168
+
169
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
170
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
171
+ """
172
+
173
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
174
+ super().__init__(config=config, layer_id=layer_id)
175
+ self.act = get_act_fn(config.hidden_act)
176
+ self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
177
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)
178
+
179
+ def _init_weights(self, reset_params: bool = False):
180
+ init_weights(
181
+ self.config,
182
+ self.Wo,
183
+ layer_dim=self.config.intermediate_size,
184
+ layer_id=self.layer_id,
185
+ type_of_module=ModuleType.out_module,
186
+ )
187
+
188
+ def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor:
189
+ input, gate = intermediate_ff.chunk(2, dim=-1)
190
+ return self.Wo(self.drop(self.act(input) * gate))
191
+
192
+
193
+ MLP2CLS = {
194
+ "mlp": FlexBertMLP,
195
+ "glu": FlexBertGLU,
196
+ "parallel_glu": FlexBertParallelGLU,
197
+ }
198
+
199
+
200
+ def get_mlp_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertMLPBase:
201
+ try:
202
+ mlp_layer = (
203
+ config.initial_mlp_layer
204
+ if layer_id < config.num_initial_layers and getattr(config, "initial_mlp_layer", None) is not None
205
+ else config.mlp_layer
206
+ )
207
+ return MLP2CLS[mlp_layer](config, layer_id=layer_id)
208
+ except KeyError as e:
209
+ if layer_id < config.num_initial_layers and getattr(config, "initial_mlp_layer", None) is not None:
210
+ raise ValueError(
211
+ f"Invalid MLP layer type: {config.initial_mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}"
212
+ )
213
+ else:
214
+ raise ValueError(f"Invalid MLP layer type: {config.mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}")
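`get_mlp_layer` is a thin registry dispatch: the config's `mlp_layer` string (optionally overridden per-layer by `initial_mlp_layer`) selects the class from `MLP2CLS`, which is then instantiated with the config. A hypothetical sketch of the lookup itself (assuming this module is importable as `mlp`):

```python
# Hypothetical sketch of the registry lookup behind get_mlp_layer.
from mlp import MLP2CLS

print(sorted(MLP2CLS.keys()))  # ['glu', 'mlp', 'parallel_glu']
mlp_cls = MLP2CLS["glu"]       # -> FlexBertGLU; get_mlp_layer instantiates it with the config and layer_id
```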
modeling_flexbert.py ADDED
@@ -0,0 +1,1684 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation)
5
+ # License: LLAMA 2 COMMUNITY LICENSE AGREEMENT
6
+
7
+ # Copyright 2022 Jonas Geiping
8
+ # License: MIT
9
+
10
+ # Copyright 2022 MosaicML Examples authors
11
+ # SPDX-License-Identifier: Apache-2.0
12
+
13
+ # Copyright 2023 MosaicML Examples authors
14
+ # SPDX-License-Identifier: Apache-2.0
15
+
16
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
17
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
18
+ # Copyright (c) 2023, Tri Dao.
19
+
20
+ """Implements Mosaic BERT, with an eye towards the Hugging Face API.
21
+
22
+ Mosaic BERT improves performance over Hugging Face BERT through the following:
23
+
24
+ 1. ALiBi. This architectural change removes positional embeddings and instead encodes positional
25
+ information through attention biases based on query-key position distance. It improves the effectiveness
26
+ of training with shorter sequence lengths by enabling extrapolation to longer sequences.
27
+
28
+ 2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer
29
+ to improve overall expressiveness, providing better convergence properties.
30
+
31
+ 3. Flash Attention. MosaicBERT's self-attention layer makes use of Flash Attention, which dramatically
32
+ improves the speed of self-attention. Our implementation relies on a bleeding-edge build of Flash Attention that
33
+ supports attention biases, which allows us to use Flash Attention together with ALiBi.
34
+
35
+ 4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT
36
+ implementations waste computation on padded tokens. MosaicBERT internally unpads to reduce unnecessary computation
37
+ and improve speed. It does this without changing how the user interfaces with the model, thereby
38
+ preserving the simple API of standard implementations.
39
+
40
+
41
+ Currently, MosaicBERT is available for masked language modeling :class:`BertForMaskedLM` and sequence
42
+ classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases.
43
+
44
+ See :file:`./mosaic_bert.py` for utilities to simplify working with MosaicBERT in Composer, and for example usage
45
+ of the core Mosaic BERT classes.
46
+ """
47
+
48
+ import logging
49
+ import os
50
+ import sys
51
+ import warnings
52
+ from dataclasses import dataclass
53
+ from typing import List, Optional, Tuple, Union
54
+
55
+ # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
56
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
57
+
58
+ import torch
59
+ import torch.nn as nn
60
+ from einops import rearrange
61
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
62
+ from transformers.modeling_outputs import (
63
+ MaskedLMOutput,
64
+ ModelOutput,
65
+ MultipleChoiceModelOutput,
66
+ SequenceClassifierOutput,
67
+ )
68
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
69
+
70
+ from bert_padding import index_put_first_axis
71
+
72
+ from activation import get_act_fn
73
+ from attention import (
74
+ FlexBertPaddedAttention,
75
+ FlexBertPaddedParallelAttention,
76
+ FlexBertPaddedRopeAttention,
77
+ FlexBertPaddedRopeParallelAttention,
78
+ FlexBertUnpadAttention,
79
+ FlexBertUnpadParallelAttention,
80
+ FlexBertUnpadRopeAttention,
81
+ FlexBertUnpadRopeParallelAttention,
82
+ )
83
+ from configuration_bert import FlexBertConfig
84
+ from embeddings import (
85
+ BertAlibiEmbeddings,
86
+ FlexBertAbsoluteEmbeddings,
87
+ FlexBertCompiledSansPositionEmbeddings,
88
+ FlexBertSansPositionEmbeddings,
89
+ get_embedding_layer,
90
+ )
91
+ from initialization import (
92
+ ModuleType,
93
+ TileLinear,
94
+ TileMode,
95
+ init_weights,
96
+ tile_embedding,
97
+ tile_linear,
98
+ tile_norm,
99
+ )
100
+ from layers import (
101
+ BertAlibiEncoder,
102
+ BertPooler,
103
+ BertPredictionHeadTransform,
104
+ FlexBertCompileUnpadPreNormLayer,
105
+ FlexBertPaddedEncoder,
106
+ FlexBertPaddedParallelPreNormLayer,
107
+ FlexBertPaddedPostNormLayer,
108
+ FlexBertPaddedPreNormLayer,
109
+ FlexBertUnpadEncoder,
110
+ FlexBertUnpadParallelPreNormLayer,
111
+ FlexBertUnpadPostNormLayer,
112
+ FlexBertUnpadPreNormLayer,
113
+ get_encoder_layer,
114
+ )
115
+ from mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
116
+ from normalization import get_norm_layer
117
+ from padding import pad_input, unpad_input
118
+
119
+ logger = logging.getLogger(__name__)
120
+
121
+
122
+ def _count_parameters(model: nn.Module, trainable: bool = True) -> int:
123
+ if trainable:
124
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
125
+ else:
126
+ return sum(p.numel() for p in model.parameters())
127
+
128
+
129
+ class BertModel(BertPreTrainedModel):
130
+ """Overall BERT model.
131
+
132
+ Args:
133
+ config: a BertConfig class instance with the configuration to build a new model
134
+
135
+ Inputs:
136
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
137
+ with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
138
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
139
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
140
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
141
+ a `sentence B` token (see BERT paper for more details).
142
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
143
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
144
+ input sequence length in the current batch. It's the mask that we typically use for attention when
145
+ a batch has varying length sentences.
146
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `False`.
147
+
148
+ Outputs: Tuple of (encoded_layers, pooled_output)
149
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
150
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
151
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
152
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
153
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
154
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
155
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
156
+ classifier pretrained on top of the hidden state associated with the first token of the
157
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
158
+
159
+ Example usage:
160
+ ```python
161
+ # Already been converted into WordPiece token ids
162
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
163
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
164
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
165
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
166
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
167
+ model = BertModel(config=config)
168
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
169
+ ```
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ config,
175
+ add_pooling_layer: bool = True,
176
+ ):
177
+ super(BertModel, self).__init__(config)
178
+ self.embeddings = BertAlibiEmbeddings(config)
179
+ self.encoder = BertAlibiEncoder(config)
180
+ self.pooler = BertPooler(config) if add_pooling_layer else None
181
+ self.post_init()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.embeddings.word_embeddings
185
+
186
+ def set_input_embeddings(self, value):
187
+ self.embeddings.word_embeddings = value
188
+
189
+ def forward(
190
+ self,
191
+ input_ids: torch.Tensor,
192
+ token_type_ids: Optional[torch.Tensor] = None,
193
+ attention_mask: Optional[torch.Tensor] = None,
194
+ position_ids: Optional[torch.Tensor] = None,
195
+ output_all_encoded_layers: Optional[bool] = False,
196
+ masked_tokens_mask: Optional[torch.Tensor] = None,
197
+ **kwargs,
198
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
199
+ if attention_mask is None:
200
+ attention_mask = torch.ones_like(input_ids)
201
+ if token_type_ids is None:
202
+ token_type_ids = torch.zeros_like(input_ids)
203
+
204
+ embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
205
+
206
+ subset_mask = []
207
+ first_col_mask = []
208
+
209
+ if masked_tokens_mask is None:
210
+ subset_mask = None
211
+ else:
212
+ first_col_mask = torch.zeros_like(masked_tokens_mask)
213
+ first_col_mask[:, 0] = True
214
+ subset_mask = masked_tokens_mask | first_col_mask
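+ # subset_mask selects the masked tokens plus the first ([CLS]) position of every sequence,
+ # so the encoder only needs to return those rows while the pooler still receives the [CLS] hidden state.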
215
+
216
+ encoder_outputs = self.encoder(
217
+ embedding_output,
218
+ attention_mask,
219
+ output_all_encoded_layers=output_all_encoded_layers,
220
+ subset_mask=subset_mask,
221
+ )
222
+
223
+ if masked_tokens_mask is None:
224
+ sequence_output = encoder_outputs[-1]
225
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
226
+ else:
227
+ # TD [2022-03-01]: the indexing here is very tricky.
228
+ attention_mask_bool = attention_mask.bool()
229
+ subset_idx = subset_mask[attention_mask_bool] # type: ignore
230
+ sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]]
231
+ if self.pooler is not None:
232
+ pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]]
233
+ pooled_output = self.pooler(pool_input, pool=False)
234
+ else:
235
+ pooled_output = None
236
+
237
+ if not output_all_encoded_layers:
238
+ encoder_outputs = sequence_output
239
+
240
+ if self.pooler is not None:
241
+ return encoder_outputs, pooled_output
242
+
243
+ return encoder_outputs, None
244
+
245
+
246
+ ###################
247
+ # Bert Heads
248
+ ###################
249
+ class BertLMPredictionHead(nn.Module):
250
+ def __init__(self, config, bert_model_embedding_weights):
251
+ super().__init__()
252
+ self.transform = BertPredictionHeadTransform(config)
253
+ # The output weights are the same as the input embeddings, but there is
254
+ # an output-only bias for each token.
255
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0))
256
+ self.decoder.weight = bert_model_embedding_weights
257
+
258
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
259
+ hidden_states = self.transform(hidden_states)
260
+ hidden_states = self.decoder(hidden_states)
261
+ return hidden_states
262
+
263
+
264
+ class BertOnlyMLMHead(nn.Module):
265
+ def __init__(self, config, bert_model_embedding_weights):
266
+ super().__init__()
267
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
268
+
269
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
270
+ prediction_scores = self.predictions(sequence_output)
271
+ return prediction_scores
272
+
273
+
274
+ class BertOnlyNSPHead(nn.Module):
275
+ def __init__(self, config):
276
+ super().__init__()
277
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
278
+
279
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
280
+ seq_relationship_score = self.seq_relationship(pooled_output)
281
+ return seq_relationship_score
282
+
283
+
284
+ #####################
285
+ # Various Bert models
286
+ #####################
287
+
288
+
289
+ class BertForPreTraining(BertPreTrainedModel):
290
+ # TBD: Coming in Future Commit
291
+ pass
292
+
293
+
294
+ class BertLMHeadModel(BertPreTrainedModel):
295
+ # TBD: Coming in Future Commit
296
+ pass
297
+
298
+
299
+ class BertForMaskedLM(BertPreTrainedModel):
300
+ def __init__(self, config):
301
+ super().__init__(config)
302
+
303
+ if config.is_decoder:
304
+ warnings.warn(
305
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
306
+ "bi-directional self-attention."
307
+ )
308
+
309
+ self.bert = BertModel(config, add_pooling_layer=False)
310
+ self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
311
+
312
+ # Initialize weights and apply final processing
313
+ self.post_init()
314
+
315
+ @classmethod
316
+ def from_composer(
317
+ cls,
318
+ pretrained_checkpoint,
319
+ state_dict=None,
320
+ cache_dir=None,
321
+ from_tf=False,
322
+ config=None,
323
+ *inputs,
324
+ **kwargs,
325
+ ):
326
+ """Load from pre-trained."""
327
+ model = cls(config, *inputs, **kwargs)
328
+ if from_tf:
329
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
330
+
331
+ state_dict = torch.load(pretrained_checkpoint)
332
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
333
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
334
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
335
+
336
+ if len(missing_keys) > 0:
337
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
338
+ if len(unexpected_keys) > 0:
339
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
340
+
341
+ return model
342
+
343
+ def get_output_embeddings(self):
344
+ return self.cls.predictions.decoder
345
+
346
+ def set_output_embeddings(self, new_embeddings):
347
+ self.cls.predictions.decoder = new_embeddings
348
+
349
+ def forward(
350
+ self,
351
+ input_ids: Optional[torch.Tensor] = None,
352
+ attention_mask: Optional[torch.Tensor] = None,
353
+ token_type_ids: Optional[torch.Tensor] = None,
354
+ position_ids: Optional[torch.Tensor] = None,
355
+ head_mask: Optional[torch.Tensor] = None,
356
+ inputs_embeds: Optional[torch.Tensor] = None,
357
+ encoder_hidden_states: Optional[torch.Tensor] = None,
358
+ encoder_attention_mask: Optional[torch.Tensor] = None,
359
+ labels: Optional[torch.Tensor] = None,
360
+ output_attentions: Optional[bool] = None,
361
+ output_hidden_states: Optional[bool] = None,
362
+ return_dict: Optional[bool] = None,
363
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
364
+ # labels should be a `torch.LongTensor` of shape
365
+ # `(batch_size, sequence_length)`. These are used for computing the
366
+ # masked language modeling loss.
367
+ #
368
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
369
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
370
+ # (masked), the loss is only computed for the tokens with labels in `[0,
371
+ # ..., config.vocab_size]`
372
+ #
373
+ # Prediction scores are only computed for masked tokens and the (bs,
374
+ # seqlen) dimensions are flattened
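+ #
+ # Illustrative example (hypothetical values): with labels = [[-100, 5, -100], [-100, -100, 9]]
+ # only the two positions labeled 5 and 9 contribute to the loss; their flattened indices are
+ # gathered below via `masked_token_idx` before the logits are scattered back to [batch, seqlen].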
375
+ if (input_ids is not None) == (inputs_embeds is not None):
376
+ raise ValueError("Must specify either input_ids or input_embeds!")
377
+
378
+ if labels is None:
379
+ masked_tokens_mask = None
380
+ else:
381
+ masked_tokens_mask = labels > 0
382
+
383
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
384
+
385
+ outputs = self.bert(
386
+ input_ids,
387
+ attention_mask=attention_mask,
388
+ token_type_ids=token_type_ids,
389
+ position_ids=position_ids,
390
+ head_mask=head_mask,
391
+ inputs_embeds=inputs_embeds,
392
+ encoder_hidden_states=encoder_hidden_states,
393
+ encoder_attention_mask=encoder_attention_mask,
394
+ output_attentions=output_attentions,
395
+ output_hidden_states=output_hidden_states,
396
+ return_dict=return_dict,
397
+ masked_tokens_mask=masked_tokens_mask,
398
+ )
399
+
400
+ sequence_output = outputs[0]
401
+ prediction_scores = self.cls(sequence_output)
402
+
403
+ loss = None
404
+ if labels is not None:
405
+ # Compute loss
406
+ loss_fct = nn.CrossEntropyLoss()
407
+ masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
408
+ loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx])
409
+
410
+ assert input_ids is not None, "Coding error; please open an issue"
411
+ batch, seqlen = input_ids.shape[:2]
412
+ prediction_scores = rearrange(
413
+ index_put_first_axis(prediction_scores, masked_token_idx, batch * seqlen),
414
+ "(b s) d -> b s d",
415
+ b=batch,
416
+ )
417
+
418
+ if not return_dict:
419
+ output = (prediction_scores,) + outputs[2:]
420
+ return ((loss,) + output) if loss is not None else output
421
+
422
+ return MaskedLMOutput(
423
+ loss=loss,
424
+ logits=prediction_scores,
425
+ hidden_states=None,
426
+ attentions=None,
427
+ )
428
+
429
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
430
+ input_shape = input_ids.shape
431
+ effective_batch_size = input_shape[0]
432
+
433
+ # add a dummy token
434
+ if self.config.pad_token_id is None:
435
+ raise ValueError("The PAD token should be defined for generation")
436
+
437
+ attention_mask = torch.cat(
438
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
439
+ dim=-1,
440
+ )
441
+ dummy_token = torch.full(
442
+ (effective_batch_size, 1),
443
+ self.config.pad_token_id,
444
+ dtype=torch.long,
445
+ device=input_ids.device,
446
+ )
447
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
448
+
449
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
450
+
451
+
452
+ class BertForNextSentencePrediction(BertPreTrainedModel):
453
+ # TBD: Push in future commit
454
+ pass
455
+
456
+
457
+ class BertForSequenceClassification(BertPreTrainedModel):
458
+ """Bert Model transformer with a sequence classification/regression head.
459
+
460
+ This head is just a linear layer on top of the pooled output. Used for,
461
+ e.g., GLUE tasks.
462
+ """
463
+
464
+ def __init__(self, config):
465
+ super().__init__(config)
466
+ self.num_labels = config.num_labels
467
+ self.config = config
468
+
469
+ self.bert = BertModel(config)
470
+ classifier_dropout = (
471
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
472
+ )
473
+ self.dropout = nn.Dropout(classifier_dropout)
474
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
475
+
476
+ # Initialize weights and apply final processing
477
+ self.post_init()
478
+
479
+ @classmethod
480
+ def from_composer(
481
+ cls,
482
+ pretrained_checkpoint,
483
+ state_dict=None,
484
+ cache_dir=None,
485
+ from_tf=False,
486
+ config=None,
487
+ *inputs,
488
+ **kwargs,
489
+ ):
490
+ """Load from pre-trained."""
491
+ model = cls(config, *inputs, **kwargs)
492
+ if from_tf:
493
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
494
+
495
+ state_dict = torch.load(pretrained_checkpoint)
496
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
497
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
498
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
499
+
500
+ if len(missing_keys) > 0:
501
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
502
+ if len(unexpected_keys) > 0:
503
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
504
+
505
+ return model
506
+
507
+ def forward(
508
+ self,
509
+ input_ids: Optional[torch.Tensor] = None,
510
+ attention_mask: Optional[torch.Tensor] = None,
511
+ token_type_ids: Optional[torch.Tensor] = None,
512
+ position_ids: Optional[torch.Tensor] = None,
513
+ head_mask: Optional[torch.Tensor] = None,
514
+ inputs_embeds: Optional[torch.Tensor] = None,
515
+ labels: Optional[torch.Tensor] = None,
516
+ output_attentions: Optional[bool] = None,
517
+ output_hidden_states: Optional[bool] = None,
518
+ return_dict: Optional[bool] = None,
519
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
520
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
521
+ # Labels for computing the sequence classification/regression loss.
522
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
523
+ # If `config.num_labels == 1` a regression loss is computed
524
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
525
+ # is computed (cross-entropy).
526
+
527
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
528
+
529
+ outputs = self.bert(
530
+ input_ids,
531
+ attention_mask=attention_mask,
532
+ token_type_ids=token_type_ids,
533
+ position_ids=position_ids,
534
+ head_mask=head_mask,
535
+ inputs_embeds=inputs_embeds,
536
+ output_attentions=output_attentions,
537
+ output_hidden_states=output_hidden_states,
538
+ return_dict=return_dict,
539
+ )
540
+
541
+ pooled_output = outputs[1]
542
+
543
+ pooled_output = self.dropout(pooled_output)
544
+ logits = self.classifier(pooled_output)
545
+
546
+ loss = None
547
+ if labels is not None:
548
+ # Compute loss
549
+ if self.config.problem_type is None:
550
+ if self.num_labels == 1:
551
+ self.config.problem_type = "regression"
552
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
553
+ self.config.problem_type = "single_label_classification"
554
+ else:
555
+ self.config.problem_type = "multi_label_classification"
556
+
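+ # Summary of the inferred problem types: a single label -> regression (MSE); integer labels
+ # with several classes -> single-label classification (cross-entropy); float (multi-hot)
+ # labels -> multi-label classification (BCE with logits).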
557
+ if self.config.problem_type == "regression":
558
+ loss_fct = nn.MSELoss()
559
+ if self.num_labels == 1:
560
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
561
+ else:
562
+ loss = loss_fct(logits, labels)
563
+ elif self.config.problem_type == "single_label_classification":
564
+ loss_fct = nn.CrossEntropyLoss()
565
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
566
+ elif self.config.problem_type == "multi_label_classification":
567
+ loss_fct = nn.BCEWithLogitsLoss()
568
+ loss = loss_fct(logits, labels)
569
+
570
+ if not return_dict:
571
+ output = (logits,) + outputs[2:]
572
+ return ((loss,) + output) if loss is not None else output
573
+
574
+ return SequenceClassifierOutput(
575
+ loss=loss,
576
+ logits=logits,
577
+ hidden_states=None,
578
+ attentions=None,
579
+ )
580
+
581
+
582
+ class BertForMultipleChoice(BertPreTrainedModel):
583
+ """
584
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
585
+ softmax) e.g. for RocStories/SWAG tasks.
586
+ """
587
+
588
+ def __init__(self, config):
589
+ super().__init__(config)
590
+ self.num_labels = config.num_labels
591
+ self.config = config
592
+
593
+ self.bert = BertModel(config)
594
+ classifier_dropout = (
595
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
596
+ )
597
+ self.dropout = nn.Dropout(classifier_dropout)
598
+
599
+ # In multiple choice tasks, all choices are submitted in a batch, and
600
+ # we compute a logit for each option independently. The logits are then
601
+ # normalized in the forward pass to get a probability distribution over
602
+ # the choices.
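+ # Illustrative shape flow (hypothetical sizes): input_ids of shape [batch=8, num_choices=4, seqlen=128]
+ # is flattened to [32, 128] in `forward`, each choice is scored to a single logit [32, 1],
+ # and the logits are reshaped back to [8, 4] before the cross-entropy over choices.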
603
+ self.classifier = nn.Linear(config.hidden_size, 1)
604
+
605
+ # Initialize weights and apply final processing
606
+ self.post_init()
607
+
608
+ @classmethod
609
+ def from_composer(
610
+ cls,
611
+ pretrained_checkpoint,
612
+ state_dict=None,
613
+ cache_dir=None,
614
+ from_tf=False,
615
+ config=None,
616
+ *inputs,
617
+ **kwargs,
618
+ ):
619
+ """Load from pre-trained."""
620
+ model = cls(config, *inputs, **kwargs)
621
+ if from_tf:
622
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
623
+
624
+ state_dict = torch.load(pretrained_checkpoint)
625
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
626
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
627
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
628
+
629
+ if len(missing_keys) > 0:
630
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
631
+ if len(unexpected_keys) > 0:
632
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
633
+
634
+ return model
635
+
636
+ def forward(
637
+ self,
638
+ input_ids: Optional[torch.Tensor] = None,
639
+ attention_mask: Optional[torch.Tensor] = None,
640
+ token_type_ids: Optional[torch.Tensor] = None,
641
+ position_ids: Optional[torch.Tensor] = None,
642
+ head_mask: Optional[torch.Tensor] = None,
643
+ inputs_embeds: Optional[torch.Tensor] = None,
644
+ labels: Optional[torch.Tensor] = None,
645
+ output_attentions: Optional[bool] = None,
646
+ output_hidden_states: Optional[bool] = None,
647
+ return_dict: Optional[bool] = None,
648
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
649
+ r"""
650
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
651
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
652
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
653
+ `input_ids` above)
654
+ """
655
+
656
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
657
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
658
+
659
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
660
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
661
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
662
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
663
+ inputs_embeds = (
664
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
665
+ if inputs_embeds is not None
666
+ else None
667
+ )
668
+
669
+ outputs = self.bert(
670
+ input_ids,
671
+ attention_mask=attention_mask,
672
+ token_type_ids=token_type_ids,
673
+ position_ids=position_ids,
674
+ head_mask=head_mask,
675
+ inputs_embeds=inputs_embeds,
676
+ output_attentions=output_attentions,
677
+ output_hidden_states=output_hidden_states,
678
+ return_dict=return_dict,
679
+ )
680
+
681
+ pooled_output = outputs[1]
682
+
683
+ pooled_output = self.dropout(pooled_output)
684
+ logits = self.classifier(pooled_output)
685
+ reshaped_logits = logits.view(-1, num_choices)
686
+
687
+ loss = None
688
+ if labels is not None:
689
+ loss_fct = nn.CrossEntropyLoss()
690
+ loss = loss_fct(reshaped_logits, labels)
691
+
692
+ if not return_dict:
693
+ output = (reshaped_logits,) + outputs[2:]
694
+ return ((loss,) + output) if loss is not None else output
695
+
696
+ return MultipleChoiceModelOutput(
697
+ loss=loss,
698
+ logits=reshaped_logits,
699
+ hidden_states=None,
700
+ attentions=None,
701
+ )
702
+
703
+
704
+ class BertForTokenClassification(BertPreTrainedModel):
705
+ # TBD: Push in future commit
706
+ pass
707
+
708
+
709
+ class BertForQuestionAnswering(BertPreTrainedModel):
710
+ """Bert Model with a span classification head.
711
+
712
+ This is used for extractive question-answering tasks like SQuAD (a linear
713
+ layer on top of the hidden states' output to compute `span start logits`
714
+ and `span end logits`).
715
+ """
716
+
717
+ # TBD: Push in future commit
718
+
719
+
720
+ ###################
721
+ # FlexBert Heads
722
+ ###################
723
+
724
+
725
+ class FlexBertPredictionHead(nn.Module):
726
+ def __init__(self, config: FlexBertConfig):
727
+ super().__init__()
728
+ self.config = config
729
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_pred_bias)
730
+ self.act = get_act_fn(config.head_pred_act) if config.head_pred_act else nn.Identity()
731
+ self.norm = (
732
+ get_norm_layer(config, compiled_norm=config.compile_model) if config.head_pred_norm else nn.Identity()
733
+ )
734
+
735
+ def _init_weights(self, reset_params: bool = False):
736
+ if reset_params and hasattr(self.norm, "reset_parameters"):  # norm may be nn.Identity
737
+ self.norm.reset_parameters()
738
+ init_weights(self.config, self.dense, layer_dim=self.config.hidden_size, type_of_module=ModuleType.in_module)
739
+
740
+ def reset_parameters(self):
741
+ self._init_weights(reset_params=True)
742
+
743
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
744
+ return self.norm(self.act(self.dense(hidden_states)))
745
+
746
+
747
+ class FlexBertPoolingHead(nn.Module):
748
+ def __init__(self, config: FlexBertConfig):
749
+ super().__init__()
750
+ self.config = config
751
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_class_bias)
752
+ self.act = get_act_fn(config.head_class_act) if config.head_class_act else nn.Identity()
753
+ self.norm = get_norm_layer(config) if config.head_class_norm else nn.Identity()
754
+ self.drop = torch.nn.Dropout(config.head_class_dropout) if config.head_class_dropout > 0 else nn.Identity()
755
+ self.pooling_type = config.pooling_type
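+ # pooling_type selects how [batch, seqlen, hidden] is reduced to [batch, hidden] in `forward`:
+ # "cls" takes the first token, "mean" averages over the sequence, "max" takes the per-feature maximum.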
756
+
757
+ def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
758
+ if pool:
759
+ if self.pooling_type == "cls":
760
+ output = hidden_states[:, 0]
761
+ elif self.pooling_type == "mean":
762
+ output = hidden_states.mean(dim=1)
763
+ elif self.pooling_type == "max":
764
+ output = hidden_states.max(dim=1)[0]
765
+ else:
766
+ output = hidden_states
767
+
768
+ return self.drop(self.norm(self.act(self.dense(output))))
769
+
770
+ def _init_weights(self, reset_params: bool = False):
771
+ init_weights(self.config, self.dense, self.config.hidden_size, type_of_module=ModuleType.out_module)
772
+ if reset_params and hasattr(self.norm, "reset_parameters"):
773
+ self.norm.reset_parameters()
774
+
775
+ def reset_parameters(self):
776
+ self._init_weights(reset_params=True)
777
+
778
+
779
+ ###################
780
+ # FlexBert Models
781
+ ###################
782
+
783
+
784
+ @dataclass
785
+ class MaskedLMOutput(ModelOutput):
786
+ """
787
+ Base class for masked language models outputs.
788
+
789
+ Args:
790
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
791
+ Masked language modeling (MLM) loss.
792
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
793
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
794
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
795
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
796
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
797
+
798
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
799
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
800
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
801
+ sequence_length)`.
802
+
803
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
804
+ heads.
805
+ """
806
+
807
+ loss: Optional[torch.FloatTensor] = None
808
+ logits: torch.FloatTensor = None
809
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
810
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
811
+ indices: Optional[torch.LongTensor] = None
812
+ cu_seqlens: Optional[torch.LongTensor] = None
813
+ max_seqlen: Optional[int] = None
814
+ batch_size: Optional[int] = None
815
+ seq_len: Optional[int] = None
816
+ labels: Optional[torch.LongTensor] = None
817
+
818
+
819
+ @dataclass
820
+ class MaskedLMOutputZLoss(ModelOutput):
821
+ """
822
+ Base class for masked language models outputs.
823
+
824
+ Args:
825
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
826
+ Masked language modeling (MLM) loss.
827
+ ce_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
828
+ Cross entropy loss.
829
+ z_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
830
+ Z loss.
831
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
832
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
833
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
834
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
835
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
836
+
837
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
838
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
839
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
840
+ sequence_length)`.
841
+
842
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
843
+ heads.
844
+ indices (`torch.LongTensor`, *optional*):
845
+ Indices of the non-padding tokens in the flattened batch, used to re-pad the logits.
846
+ """
847
+
848
+ loss: Optional[torch.FloatTensor] = None
849
+ ce_loss: Optional[torch.FloatTensor] = None
850
+ z_loss: Optional[torch.FloatTensor] = None
851
+ logits: torch.FloatTensor = None
852
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
853
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
854
+ indices: Optional[torch.LongTensor] = None
855
+ cu_seqlens: Optional[torch.LongTensor] = None
856
+ max_seqlen: Optional[int] = None
857
+ batch_size: Optional[int] = None
858
+ seq_len: Optional[int] = None
859
+ labels: Optional[torch.LongTensor] = None
860
+
861
+
862
+ class FlexBertPreTrainedModel(BertPreTrainedModel):
863
+ """
864
+ An abstract class to handle custom weights initialization of modules
865
+ """
866
+
867
+ def _init_module_weights(self, module: nn.Module):
868
+ """
869
+ Custom weight init of modules using initialization.init_weights
870
+ Currently only supports init of embedding modules
871
+ """
872
+ assert isinstance(module, nn.Module)
873
+ if isinstance(module, nn.Embedding):
874
+ init_weights(self.config, module, type_of_module=ModuleType.emb)
875
+ else:
876
+ logger.warning(f"Custom weight init for the given module is not supported, please fix: {module}")
877
878
+ # raise NotImplementedError("Custom weight init for the given module is not supported")
879
+
880
+
881
+ class FlexBertModel(FlexBertPreTrainedModel):
882
+ """Overall BERT model.
883
+
884
+ Args:
885
+ config: a BertConfig class instance with the configuration to build a new model
886
+
887
+ Inputs:
888
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
889
+ with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
890
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
891
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
892
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
893
+ a `sentence B` token (see BERT paper for more details).
894
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
895
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
896
+ input sequence length in the current batch. It's the mask that we typically use for attention when
897
+ a batch has varying length sentences.
898
+ `indices`, `cu_seqlens`, `max_seqlen`: optional unpadding metadata (non-padding token indices,
899
+ cumulative sequence lengths, and maximum sequence length) used when the inputs have already
900
+ been unpadded.
901
+
902
+ Outputs:
903
+ `encoder_outputs`: the final sequence of hidden states, a torch.FloatTensor of size
904
+ [batch_size, sequence_length, hidden_size] (or [total_tokens, hidden_size] when running on
905
+ unpadded inputs), after the optional final norm. Unlike :class:`BertModel`, no pooled output
906
+ is returned; classification models such as :class:`FlexBertForSequenceClassification` apply
907
+ their own pooling head.
910
+
911
+ Example usage:
912
+ ```python
913
+ # Already been converted into WordPiece token ids
914
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
915
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
916
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
917
+ config = FlexBertConfig(vocab_size=32000, hidden_size=768,
918
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
919
+ model = FlexBertModel(config=config)
920
+ encoder_outputs = model(input_ids, attention_mask=input_mask)
921
+ ```
922
+ """
923
+
924
+ def __init__(self, config: FlexBertConfig):
925
+ super().__init__(config)
926
+ self.embeddings = get_embedding_layer(config)
927
+ self.encoder = get_encoder_layer(config)
928
+ if config.final_norm:
929
+ # if we use prenorm attention we need to add a final norm
930
+ self.final_norm = get_norm_layer(config)
931
+ else:
932
+ self.final_norm = None
933
+ self.unpad_embeddings = config.unpad_embeddings
934
+
935
+ def post_init(self):
936
+ self._init_weights(reset_params=False)
937
+ self._backward_compatibility_gradient_checkpointing()
938
+
939
+ def get_input_embeddings(self):
940
+ return self.embeddings.tok_embeddings
941
+
942
+ def set_input_embeddings(self, value):
943
+ self.embeddings.tok_embeddings = value
944
+
945
+ def forward(
946
+ self,
947
+ input_ids: torch.Tensor,
948
+ attention_mask: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.Tensor] = None,
950
+ indices: Optional[torch.Tensor] = None,
951
+ cu_seqlens: Optional[torch.Tensor] = None,
952
+ max_seqlen: Optional[int] = None,
953
+ **kwargs,
954
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
955
+ if attention_mask is None:
956
+ attention_mask = torch.ones_like(input_ids)
957
+
958
+ embedding_output = self.embeddings(input_ids, position_ids)
959
+
960
+ encoder_outputs = self.encoder(
961
+ hidden_states=embedding_output,
962
+ attention_mask=attention_mask,
963
+ indices=indices,
964
+ cu_seqlens=cu_seqlens,
965
+ max_seqlen=max_seqlen,
966
+ )
967
+
968
+ if self.final_norm is not None:
969
+ encoder_outputs = self.final_norm(encoder_outputs)
970
+ return encoder_outputs
971
+
972
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
973
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
974
+ if module:
975
+ self._init_module_weights(module)
976
+ else:
977
+ assert isinstance(reset_params, bool)
978
+ self.embeddings._init_weights(reset_params=reset_params)
979
+ self.encoder._init_weights(reset_params=reset_params)
980
+
981
+ if reset_params and self.config.final_norm:
982
+ self.final_norm.reset_parameters()
983
+
984
+ def reset_parameters(self):
985
+ self._init_weights(reset_params=True)
986
+
987
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
988
+ """Returns the number of parameters in the model.
989
+
990
+ Args:
991
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
992
+ trainable: only count trainable parameters.
993
+ """
994
+ params = sum([_count_parameters(layer, trainable) for layer in self.encoder.layers])
995
+ if count_embeddings:
996
+ params += _count_parameters(self.embeddings, trainable)
997
+ if hasattr(self.embeddings, "position_embeddings"):
998
+ params -= _count_parameters(self.embeddings.position_embeddings, trainable)
999
+ return params
1000
+
1001
+
1002
+ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
1003
+ def __init__(self, config: FlexBertConfig):
1004
+ super().__init__(config)
1005
+ self.bert = FlexBertModel(config)
1006
+ self.head = FlexBertPredictionHead(config)
1007
+
1008
+ if config.tie_word_embeddings:
1009
+ decoder_weights = self.bert.embeddings.tok_embeddings.weight
1010
+ else:
1011
+ decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
1012
+ self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias)
1013
+ self.decoder.weight = decoder_weights
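+ # When `tie_word_embeddings` is set, the decoder shares the token embedding matrix and only the
+ # optional `decoder_bias` is decoder-specific; otherwise the decoder gets its own freshly
+ # initialized [vocab_size, hidden_size] weight.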
1014
+
1015
+ self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy"
1016
+ self.return_z_loss = config.loss_kwargs.get("return_z_loss", False)
1017
+ self.unpad_embeddings = config.unpad_embeddings
1018
+ self.pad_logits = config.pad_logits
1019
+ self.compile_model = config.compile_model
1020
+ self.masked_prediction = config.masked_prediction
1021
+
1022
+ # Initialize weights and apply final processing
1023
+ self._init_weights(reset_params=False)
1024
+
1025
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1026
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1027
+ if module:
1028
+ self._init_module_weights(module)
1029
+ else:
1030
+ assert isinstance(reset_params, bool)
1031
+ self.bert._init_weights(reset_params=reset_params)
1032
+ self.head._init_weights(reset_params=reset_params)
1033
+
1034
+ # Output weights.
1035
+ if not self.config.tie_word_embeddings:
1036
+ init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1037
+
1038
+ @classmethod
1039
+ def from_composer(
1040
+ cls,
1041
+ pretrained_checkpoint,
1042
+ state_dict=None,
1043
+ cache_dir=None,
1044
+ from_tf=False,
1045
+ config=None,
1046
+ *inputs,
1047
+ **kwargs,
1048
+ ):
1049
+ """Load from pre-trained."""
1050
+ model = cls(config, *inputs, **kwargs)
1051
+ if from_tf:
1052
+ raise ValueError("FlexBERT does not support loading TensorFlow weights.")
1053
+
1054
+ state_dict = torch.load(pretrained_checkpoint)
1055
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1056
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1057
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1058
+
1059
+ if len(missing_keys) > 0:
1060
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1061
+ if len(unexpected_keys) > 0:
1062
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1063
+
1064
+ return model
1065
+
1066
+ def get_output_embeddings(self):
1067
+ return self.decoder
1068
+
1069
+ def set_output_embeddings(self, new_embeddings):
1070
+ self.decoder = new_embeddings
1071
+
1072
+ @torch.no_grad()
1073
+ def unpad_inputs(
1074
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, labels: torch.Tensor
1075
+ ):
1076
+ return unpad_input(input_ids, attention_mask, position_ids, labels)
1077
+
1078
+ @torch.no_grad()
1079
+ def pad_inputs(
1080
+ self,
1081
+ inputs: torch.Tensor,
1082
+ indices: torch.Tensor,
1083
+ batch_size: int,
1084
+ seqlen: int,
1085
+ labels: Optional[torch.Tensor] = None,
1086
+ ignore_index: int = -100,
1087
+ ):
1088
+ return pad_input(
1089
+ inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index
1090
+ )
1091
+
1092
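+ # `dynamic=True` presumably avoids recompilation since the number of (unpadded or masked)
+ # tokens fed to the head changes from batch to batch.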
+ @torch.compile(dynamic=True)
1093
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
1094
+ return self.decoder(self.head(output))
1095
+
1096
+ def forward(
1097
+ self,
1098
+ input_ids: Optional[torch.Tensor],
1099
+ attention_mask: Optional[torch.Tensor] = None,
1100
+ position_ids: Optional[torch.Tensor] = None,
1101
+ labels: Optional[torch.Tensor] = None,
1102
+ return_dict: Optional[bool] = None,
1103
+ indices: Optional[torch.Tensor] = None,
1104
+ cu_seqlens: Optional[torch.Tensor] = None,
1105
+ max_seqlen: Optional[int] = None,
1106
+ batch_size: Optional[int] = None,
1107
+ seq_len: Optional[int] = None,
1108
+ **kwargs,
1109
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1110
+ # labels should be a `torch.LongTensor` of shape
1111
+ # `(batch_size, sequence_length)`. These are used for computing the
1112
+ # masked language modeling loss.
1113
+ #
1114
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
1115
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
1116
+ # (masked), the loss is only computed for the tokens with labels in `[0,
1117
+ # ..., config.vocab_size]`
1118
+ #
1119
+ # Prediction scores are only computed for masked tokens and the (bs,
1120
+ # seqlen) dimensions are flattened
1121
+
1122
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1123
+
1124
+ if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1125
+ batch_size, seq_len = input_ids.shape[:2]
1126
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1127
+ input_ids, attention_mask, position_ids, labels
1128
+ )
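+ # Illustrative example (hypothetical batch of 2 with lengths 3 and 2, seq_len = 4):
+ # attention_mask [[1, 1, 1, 0], [1, 1, 0, 0]] yields 5 unpadded tokens with
+ # indices [0, 1, 2, 4, 5] into the flattened batch, cu_seqlens [0, 3, 5] and max_seqlen 3.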
1129
+
1130
+ output = self.bert(
1131
+ input_ids,
1132
+ attention_mask=attention_mask,
1133
+ position_ids=position_ids,
1134
+ indices=indices,
1135
+ cu_seqlens=cu_seqlens,
1136
+ max_seqlen=max_seqlen,
1137
+ )
1138
+
1139
+ if self.masked_prediction and labels is not None:
1140
+ # flatten labels and output first
1141
+ labels = labels.view(-1)
1142
+ output = output.view(labels.shape[0], -1)
1143
+
1144
+ # then filter out the non-masked tokens
1145
+ mask_tokens = labels != self.loss_fn.ignore_index
1146
+ output = output[mask_tokens]
1147
+ labels = labels[mask_tokens]
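+ # With `masked_prediction` enabled, only the masked positions are passed through the
+ # (comparatively large) decoder projection, since the MLM loss ignores every other token.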
1148
+
1149
+ if self.compile_model:
1150
+ logits = self.compiled_head(output)
1151
+ else:
1152
+ logits = self.decoder(self.head(output))
1153
+
1154
+ loss = None
1155
+ if labels is not None:
1156
+ if not self.masked_prediction:
1157
+ labels = labels.view(-1)
1158
+ logits = logits.view(labels.shape[0], -1)
1159
+
1160
+ if self.return_z_loss:
1161
+ loss, z_loss = self.loss_fn(logits, labels)
1162
+ if self.pad_logits:
1163
+ return MaskedLMOutputZLoss(
1164
+ loss=loss,
1165
+ ce_loss=loss.detach().clone() - z_loss,
1166
+ z_loss=z_loss,
1167
+ logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1168
+ hidden_states=None,
1169
+ attentions=None,
1170
+ )
1171
+ else:
1172
+ return MaskedLMOutputZLoss(
1173
+ loss=loss,
1174
+ ce_loss=loss.detach().clone() - z_loss,
1175
+ z_loss=z_loss,
1176
+ logits=logits,
1177
+ hidden_states=None,
1178
+ attentions=None,
1179
+ indices=indices,
1180
+ cu_seqlens=cu_seqlens,
1181
+ max_seqlen=max_seqlen,
1182
+ batch_size=batch_size,
1183
+ seq_len=seq_len,
1184
+ labels=labels,
1185
+ )
1186
+ else:
1187
+ loss = self.loss_fn(logits, labels)
1188
+
1189
+ if self.pad_logits:
1190
+ return MaskedLMOutput(
1191
+ loss=loss,
1192
+ logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1193
+ hidden_states=None,
1194
+ attentions=None,
1195
+ )
1196
+ else:
1197
+ return MaskedLMOutput(
1198
+ loss=loss,
1199
+ logits=logits,
1200
+ hidden_states=None,
1201
+ attentions=None,
1202
+ indices=indices,
1203
+ cu_seqlens=cu_seqlens,
1204
+ max_seqlen=max_seqlen,
1205
+ batch_size=batch_size,
1206
+ seq_len=seq_len,
1207
+ labels=labels,
1208
+ )
1209
+
1210
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
1211
+ input_shape = input_ids.shape
1212
+ effective_batch_size = input_shape[0]
1213
+
1214
+ # add a dummy token
1215
+ if self.config.pad_token_id is None:
1216
+ raise ValueError("The PAD token should be defined for generation")
1217
+
1218
+ attention_mask = torch.cat(
1219
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
1220
+ dim=-1,
1221
+ )
1222
+ dummy_token = torch.full(
1223
+ (effective_batch_size, 1),
1224
+ self.config.pad_token_id,
1225
+ dtype=torch.long,
1226
+ device=input_ids.device,
1227
+ )
1228
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1229
+
1230
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1231
+
1232
+ def get_number_parameters(
1233
+ self, count_embeddings: bool = True, count_decoder: bool = False, trainable: bool = True
1234
+ ) -> int:
1235
+ """Returns the number of parameters in the model.
1236
+
1237
+ Args:
1238
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1239
+ count_decoder: count the parameters in the decoder layer if weights are not tied.
1240
+ trainable: only count trainable parameters.
1241
+ """
1242
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1243
+ params += _count_parameters(self.head, trainable)
1244
+ if count_decoder and not self.config.tie_word_embeddings:
1245
+ params += _count_parameters(self.decoder, trainable)
1246
+ return params
1247
+
1248
+
1249
+ class FlexBertForSequenceClassification(FlexBertPreTrainedModel):
1250
+ """Bert Model transformer with a sequence classification/regression head.
1251
+
1252
+ This head is just a linear layer on top of the pooled output. Used for,
1253
+ e.g., GLUE tasks.
1254
+ """
1255
+
1256
+ def __init__(self, config: FlexBertConfig):
1257
+ super().__init__(config)
1258
+ self.num_labels = config.num_labels
1259
+ self.config = config
1260
+
1261
+ self.bert = FlexBertModel(config)
1262
+ self.head = FlexBertPoolingHead(config)
1263
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1264
+
1265
+ # Initialize weights and apply final processing
1266
+ self._init_weights(reset_params=False)
1267
+
1268
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1269
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1270
+ if module:
1271
+ self._init_module_weights(module)
1272
+ else:
1273
+ assert isinstance(reset_params, bool)
1274
+ self.bert._init_weights(reset_params=reset_params)
1275
+ self.head._init_weights(reset_params=reset_params)
1276
+ init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out)
1277
+
1278
+ @classmethod
1279
+ def from_composer(
1280
+ cls,
1281
+ pretrained_checkpoint,
1282
+ state_dict=None,
1283
+ cache_dir=None,
1284
+ from_tf=False,
1285
+ config=None,
1286
+ *inputs,
1287
+ **kwargs,
1288
+ ):
1289
+ """Load from pre-trained."""
1290
+ model = cls(config, *inputs, **kwargs)
1291
+ if from_tf:
1292
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
1293
+
1294
+ state_dict = torch.load(pretrained_checkpoint)
1295
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1296
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1297
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1298
+
1299
+ if len(missing_keys) > 0:
1300
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1301
+ if len(unexpected_keys) > 0:
1302
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1303
+
1304
+ return model
1305
+
1306
+ def forward(
1307
+ self,
1308
+ input_ids: Optional[torch.Tensor] = None,
1309
+ attention_mask: Optional[torch.Tensor] = None,
1310
+ position_ids: Optional[torch.Tensor] = None,
1311
+ labels: Optional[torch.Tensor] = None,
1312
+ return_dict: Optional[bool] = None,
1313
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1314
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1315
+ # Labels for computing the sequence classification/regression loss.
1316
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
1317
+ # If `config.num_labels == 1` a regression loss is computed
1318
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
1319
+ # is computed (cross-entropy).
1320
+
1321
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1322
+
1323
+ output = self.bert(
1324
+ input_ids,
1325
+ attention_mask=attention_mask,
1326
+ position_ids=position_ids,
1327
+ )
1328
+
1329
+ pooled_output = self.head(output)
1330
+ logits = self.classifier(pooled_output)
1331
+
1332
+ loss = None
1333
+ if labels is not None:
1334
+ # Compute loss
1335
+ if self.config.problem_type is None:
1336
+ if self.num_labels == 1:
1337
+ self.config.problem_type = "regression"
1338
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1339
+ self.config.problem_type = "single_label_classification"
1340
+ else:
1341
+ self.config.problem_type = "multi_label_classification"
1342
+
1343
+ if self.config.problem_type == "regression":
1344
+ loss_fct = nn.MSELoss()
1345
+ if self.num_labels == 1:
1346
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1347
+ else:
1348
+ loss = loss_fct(logits, labels)
1349
+ elif self.config.problem_type == "single_label_classification":
1350
+ loss_fct = nn.CrossEntropyLoss()
1351
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1352
+ elif self.config.problem_type == "multi_label_classification":
1353
+ loss_fct = nn.BCEWithLogitsLoss()
1354
+ loss = loss_fct(logits, labels)
1355
+
1356
+ if not return_dict:
1357
+ output = (logits, output)  # self.bert returns a tensor, so wrap it rather than concatenate
1358
+ return ((loss,) + output) if loss is not None else output
1359
+
1360
+ return SequenceClassifierOutput(
1361
+ loss=loss,
1362
+ logits=logits,
1363
+ hidden_states=None,
1364
+ attentions=None,
1365
+ )
1366
+
1367
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1368
+ """Returns the number of parameters in the model.
1369
+
1370
+ Args:
1371
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1372
+ trainable: only count trainable parameters.
1373
+ """
1374
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1375
+ params += _count_parameters(self.head, trainable)
1376
+ params += _count_parameters(self.classifier, trainable)
1377
+ return params
1378
+
1379
+
1380
+ class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
1381
+ """
1382
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1383
+ softmax) e.g. for RocStories/SWAG tasks.
1384
+ """
1385
+
1386
+ def __init__(self, config: FlexBertConfig):
1387
+ super().__init__(config)
1388
+ self.num_labels = config.num_labels
1389
+ self.config = config
1390
+
1391
+ self.bert = FlexBertModel(config)
1392
+ self.head = FlexBertPoolingHead(config)
1393
+
1394
+ # In multiple choice tasks, all choices are submitted in a batch, and
1395
+ # we compute a logit for each option independently. The logits are then
1396
+ # normalized in the forward pass to get a probability distribution over
1397
+ # the choices.
1398
+ self.classifier = nn.Linear(config.hidden_size, 1)
1399
+
1400
+ # Initialize weights and apply final processing
1401
+ self._init_weights(reset_params=False)
1402
+
1403
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1404
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1405
+ if module:
1406
+ self._init_module_weights(module)
1407
+ else:
1408
+ assert isinstance(reset_params, bool)
1409
+ self.bert._init_weights(reset_params=reset_params)
1410
+ self.head._init_weights(reset_params=reset_params)
1411
+ init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out)
1412
+
1413
+ @classmethod
1414
+ def from_composer(
1415
+ cls,
1416
+ pretrained_checkpoint,
1417
+ state_dict=None,
1418
+ cache_dir=None,
1419
+ from_tf=False,
1420
+ config=None,
1421
+ *inputs,
1422
+ **kwargs,
1423
+ ):
1424
+ """Load from pre-trained."""
1425
+ model = cls(config, *inputs, **kwargs)
1426
+ if from_tf:
1427
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
1428
+
1429
+ state_dict = torch.load(pretrained_checkpoint)
1430
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1431
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1432
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1433
+
1434
+ if len(missing_keys) > 0:
1435
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1436
+ if len(unexpected_keys) > 0:
1437
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1438
+
1439
+ return model
1440
+
1441
+ def forward(
1442
+ self,
1443
+ input_ids: Optional[torch.Tensor] = None,
1444
+ attention_mask: Optional[torch.Tensor] = None,
1445
+ position_ids: Optional[torch.Tensor] = None,
1446
+ labels: Optional[torch.Tensor] = None,
1447
+ return_dict: Optional[bool] = None,
1448
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1449
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1450
+ # Labels for computing the multiple choice classification loss.
1451
+ # Indices should be in `[0, ..., num_choices - 1]` where `num_choices` is
1452
+ # the size of the second dimension of the input tensors (see `input_ids`
1453
+ # above). The loss is a cross-entropy over the `num_choices` logits
1454
+ # computed for each example.
1455
+
1456
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1457
+ num_choices = input_ids.shape[1]
1458
+
1459
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1460
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1461
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1462
+
1463
+ output = self.bert(
1464
+ input_ids,
1465
+ attention_mask=attention_mask,
1466
+ position_ids=position_ids,
1467
+ )
1468
+
1469
+ pooled_output = self.head(output)
1470
+ logits = self.classifier(pooled_output)
1471
+ reshaped_logits = logits.view(-1, num_choices)
1472
+
1473
+ loss = None
1474
+ if labels is not None:
1475
+ loss_fct = nn.CrossEntropyLoss()
1476
+ loss = loss_fct(reshaped_logits, labels)
1477
+
1478
+ if not return_dict:
1479
+ output = (reshaped_logits, output)  # self.bert returns a tensor, so wrap it rather than concatenate
1480
+ return ((loss,) + output) if loss is not None else output
1481
+
1482
+ return MultipleChoiceModelOutput(
1483
+ loss=loss,
1484
+ logits=reshaped_logits,
1485
+ hidden_states=None,
1486
+ attentions=None,
1487
+ )
1488
+
1489
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1490
+ """Returns the number of parameters in the model.
1491
+
1492
+ Args:
1493
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1494
+ trainable: only count trainable parameters.
1495
+ """
1496
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1497
+ params += _count_parameters(self.head, trainable)
1498
+ params += _count_parameters(self.classifier, trainable)
1499
+ return params
1500
+
1501
+
1502
+ def init_model_from_pretrained(
1503
+ pretrained_model: FlexBertModel,
1504
+ new_model: FlexBertModel,
1505
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
1506
+ ):
1507
+ """
1508
+ Initialize the new model from the pretrained model.
1509
+
1510
+ This method uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`.
1511
+ The new model must have the same or more layers and the same or larger dimensions than the pretrained model.
1512
+
1513
+ Args:
1514
+ pretrained_model (FlexBertModel): The smaller, pre-trained model
1515
+ new_model (FlexBertModel): The larger model to be initialized
1516
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
1517
+
1518
+ This function assumes that the new_model has the same or more layers and the same or larger
1519
+ hidden size than the pretrained_model, but the same vocabulary size.
1520
+ """
1521
+
1522
+ # Tile embeddings
1523
+ assert isinstance(
1524
+ new_model.embeddings, type(pretrained_model.embeddings)
1525
+ ), f"Pretrained and new_model layers must be the same type, got {type(new_model.embeddings)} and {type(pretrained_model.embeddings)}"
1526
+ assert isinstance(
1527
+ new_model.embeddings,
1528
+ (FlexBertAbsoluteEmbeddings, FlexBertSansPositionEmbeddings, FlexBertCompiledSansPositionEmbeddings),
1529
+ ), f"Unsupported embedding layer type: {type(new_model.embeddings)}"
1530
+
1531
+ tile_embedding(pretrained_model.embeddings.tok_embeddings, new_model.embeddings.tok_embeddings, mode=mode)
1532
+ if isinstance(pretrained_model.embeddings, FlexBertAbsoluteEmbeddings):
1533
+ tile_embedding(pretrained_model.embeddings.pos_embeddings, new_model.embeddings.pos_embeddings, mode=mode)
1534
+
1535
+ if hasattr(pretrained_model.embeddings, "norm"):
1536
+ tile_norm(pretrained_model.embeddings.norm, new_model.embeddings.norm, mode=mode)
1537
+
1538
+ # Tile encoder layers
1539
+ assert isinstance(
1540
+ pretrained_model.encoder, (FlexBertUnpadEncoder, FlexBertPaddedEncoder)
1541
+ ), f"Unsupported encoder layer type: {type(pretrained_model.encoder)}"
1542
+ assert isinstance(
1543
+ new_model.encoder, type(pretrained_model.encoder)
1544
+ ), f"Pretrained and new_model encoder layers must be the same type, got {type(new_model.encoder)} and {type(pretrained_model.encoder)}"
1545
+
1546
+ # Calculate the layer mapping
1547
+ pretrained_layers = len(pretrained_model.encoder.layers)
1548
+ new_layers = len(new_model.encoder.layers)
1549
+ layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)]
1550
+
1551
+ # Initialize layers
1552
+ for new_model_idx, pretrained_idx in enumerate(layer_mapping):
1553
+ new_model_layer = new_model.encoder.layers[new_model_idx]
1554
+ pretrained_layer = pretrained_model.encoder.layers[pretrained_idx]
1555
+
1556
+ # first tile the PreNorm/PostNorm layers
1557
+ assert isinstance(
1558
+ new_model_layer, type(pretrained_layer)
1559
+ ), f"Pretrained and new_model prenorm/postnorm layers must be the same type, got {type(new_model_layer)} and {type(pretrained_layer)}"
1560
+ assert isinstance(
1561
+ new_model_layer,
1562
+ (
1563
+ FlexBertUnpadPreNormLayer,
1564
+ FlexBertCompileUnpadPreNormLayer,
1565
+ FlexBertUnpadParallelPreNormLayer,
1566
+ FlexBertUnpadPostNormLayer,
1567
+ FlexBertPaddedPreNormLayer,
1568
+ FlexBertPaddedParallelPreNormLayer,
1569
+ FlexBertPaddedPostNormLayer,
1570
+ ),
1571
+ ), f"Unsupported prenorm/postnorm layer type: {type(new_model_layer)}"
1572
+
1573
+ # First tile the normalization layers
1574
+ if hasattr(pretrained_layer, "attn_norm"):
1575
+ tile_norm(pretrained_layer.attn_norm, new_model_layer.attn_norm, mode=mode)
1576
+ if hasattr(pretrained_layer, "norm"):
1577
+ tile_norm(pretrained_layer.norm, new_model_layer.norm, mode=mode)
1578
+ if hasattr(pretrained_layer, "mlp_norm"):
1579
+ tile_norm(pretrained_layer.mlp_norm, new_model_layer.mlp_norm, mode=mode)
1580
+
1581
+ # Then tile the attention & mlp layers
1582
+ assert isinstance(
1583
+ new_model_layer.attn, type(pretrained_layer.attn)
1584
+ ), f"Pretrained and new_model attention layers must be the same type, got {type(new_model_layer.attn)} and {type(pretrained_layer.attn)}"
1585
+
1586
+ # first try the parallel attention layers
1587
+ if isinstance(pretrained_layer, (FlexBertUnpadParallelPreNormLayer, FlexBertPaddedParallelPreNormLayer)):
1588
+ assert isinstance(
1589
+ pretrained_layer.attn,
1590
+ (
1591
+ FlexBertUnpadParallelAttention,
1592
+ FlexBertPaddedParallelAttention,
1593
+ FlexBertUnpadRopeParallelAttention,
1594
+ FlexBertPaddedRopeParallelAttention,
1595
+ ),
1596
+ ), f"Parallel prenorm layer must have parallel attention layer: {type(pretrained_layer.attn)}"
1597
+ if not isinstance(pretrained_layer.mlp, (FlexBertParallelGLU)):
1598
+ raise ValueError(f"Parallel prenorm layer must have parallel MLP layer: {type(pretrained_layer.mlp)}")
1599
+ tile_linear(
1600
+ pretrained_layer.Wqkvff,
1601
+ new_model_layer.Wqkvff,
1602
+ linear_type=TileLinear.wqkvff,
1603
+ mode=mode,
1604
+ pretrained_attn_size=pretrained_layer.attn_size,
1605
+ pretrained_mlp_size=pretrained_layer.mlp_size,
1606
+ new_attn_size=new_model_layer.attn_size,
1607
+ new_mlp_size=new_model_layer.mlp_size,
1608
+ wqkvff_is_glu=True,
1609
+ )
1610
+
1611
+ # then try the fused attention layers
1612
+ elif isinstance(
1613
+ pretrained_layer.attn,
1614
+ (
1615
+ FlexBertUnpadAttention,
1616
+ FlexBertPaddedAttention,
1617
+ FlexBertUnpadRopeAttention,
1618
+ FlexBertPaddedRopeAttention,
1619
+ ),
1620
+ ):
1621
+ tile_linear(pretrained_layer.attn.Wqkv, new_model_layer.attn.Wqkv, linear_type=TileLinear.wqkv, mode=mode)
1622
+ else:
1623
+ raise ValueError(f"Unsupported attention layer type: {type(pretrained_layer.attn)}")
1624
+
1625
+ # finally, tile the attention output layer
1626
+ tile_linear(pretrained_layer.attn.Wo, new_model_layer.attn.Wo, linear_type=TileLinear.default, mode=mode)
1627
+
1628
+ # tile the mlp layer if the model is not using parallel attention layers
1629
+ if not isinstance(pretrained_layer.mlp, (FlexBertMLP, FlexBertGLU, FlexBertParallelGLU)):
1630
+ raise ValueError(f"Unsupported MLP layer type: {type(pretrained_layer.mlp)}")
1631
+ assert isinstance(
1632
+ new_model_layer.mlp, type(pretrained_layer.mlp)
1633
+ ), f"Pretrained and new_model mlp layers must be the same type, got {type(new_model_layer.mlp)} and {type(pretrained_layer.mlp)}"
1634
+
1635
+ # already tiled the parallel glu layer if it exists, so only need to handle mlp & glu Wi
1636
+ if isinstance(pretrained_layer.mlp, FlexBertGLU):
1637
+ tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.glu, mode=mode)
1638
+ elif isinstance(pretrained_layer.mlp, FlexBertMLP):
1639
+ tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.default, mode=mode)
1640
+ # tile the output for both ParallelGLU and MLP/GLU
1641
+ tile_linear(pretrained_layer.mlp.Wo, new_model_layer.mlp.Wo, linear_type=TileLinear.default, mode=mode)
1642
+
1643
+
1644
+ def init_mlm_model_from_pretrained(
1645
+ config: FlexBertConfig,
1646
+ pretrained_model: FlexBertForMaskedLM,
1647
+ new_model: FlexBertForMaskedLM,
1648
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
1649
+ ):
1650
+ """
1651
+ Initialize the new model from the pretrained model.
1652
+
1653
+ This function uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`.
1654
+ The new model must have the same or more layers and the same or larger dimensions than the pretrained model.
1655
+
1656
+ Args:
1657
+ config (FlexBertConfig): The configuration of the new_model
1658
+ pretrained_model (FlexBertForMaskedLM): The smaller, pre-trained model
1659
+ new_model (FlexBertForMaskedLM): The larger model to be initialized from the pretrained model
1660
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
1661
+
1662
+ This function assumes that the new_model has more layers and a larger hidden size
1663
+ than the pretrained_model, but the same vocabulary size.
1664
+ """
1665
+ init_model_from_pretrained(pretrained_model.bert, new_model.bert, mode=mode)
1666
+
1667
+ # TODO: uncomment this when the repo is turned into a pip installable package
1668
+ # if not isinstance(pretrained_model.head, FlexBertPredictionHead):
1669
+ # raise ValueError(f"Pretrained model must have a prediction head: {type(pretrained_model.head)}")
1670
+ # if not isinstance(new_model.head, FlexBertPredictionHead):
1671
+ # raise ValueError(f"New model must have a prediction head: {type(new_model.head)}")
1672
+
1673
+ # tile the prediction head
1674
+ tile_linear(pretrained_model.head.dense, new_model.head.dense, linear_type=TileLinear.default, mode=mode)
1675
+ tile_norm(pretrained_model.head.norm, new_model.head.norm, mode=mode)
1676
+
1677
+ # setup weight tying
1678
+ if config.tie_word_embeddings:
1679
+ new_model.decoder.weight = new_model.bert.embeddings.tok_embeddings.weight
1680
+ tile_linear(
1681
+ pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True
1682
+ )
1683
+ else:
1684
+ tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
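The depth expansion above hinges on the `layer_mapping` computed as `round(i * pretrained_layers / new_layers)`: each layer of the larger model is tiled from the pretrained layer at the proportionally scaled index, so pretrained layers get reused when the new model is deeper. A minimal sketch of how that mapping behaves (the layer counts are illustrative only, not taken from any particular config):

def gopher_style_layer_mapping(pretrained_layers: int, new_layers: int) -> list:
    """Map each new layer index to the pretrained layer it is tiled from."""
    return [round(i * pretrained_layers / new_layers) for i in range(new_layers)]

# Growing from 12 to 16 layers: some pretrained layers are reused twice.
print(gopher_style_layer_mapping(pretrained_layers=12, new_layers=16))
# [0, 1, 2, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 10, 11]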
normalization.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation)
5
+ # License: LLAMA 2 COMMUNITY LICENSE AGREEMENT
6
+
7
+
8
+ import inspect
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import init
12
+
13
+ from configuration_bert import FlexBertConfig
14
+
15
+ try:
16
+ from flash_attn.ops.triton.layer_norm import RMSNorm as TritonRMSNorm
17
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
18
+
19
+ except ImportError:
20
+ TritonRMSNorm = None
21
+ layer_norm_fn = None
22
+
23
+
24
+ class RMSNorm(nn.Module):
25
+ """Llama2 RMSNorm implementation"""
26
+
27
+ def __init__(self, dim: int, eps: float = 1e-5):
28
+ """
29
+ Initialize the RMSNorm normalization layer.
30
+
31
+ Args:
32
+ dim (int): The dimension of the input tensor.
33
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
34
+
35
+ Attributes:
36
+ eps (float): A small value added to the denominator for numerical stability.
37
+ weight (nn.Parameter): Learnable scaling parameter.
38
+
39
+ """
40
+ super().__init__()
41
+ self.eps = eps
42
+ self.weight = nn.Parameter(torch.ones(dim))
43
+
44
+ def _norm(self, x):
45
+ """
46
+ Apply the RMSNorm normalization to the input tensor.
47
+
48
+ Args:
49
+ x (torch.Tensor): The input tensor.
50
+
51
+ Returns:
52
+ torch.Tensor: The normalized tensor.
53
+
54
+ """
55
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
56
+
57
+ def forward(self, x):
58
+ """
59
+ Forward pass through the RMSNorm layer.
60
+
61
+ Args:
62
+ x (torch.Tensor): The input tensor.
63
+
64
+ Returns:
65
+ torch.Tensor: The output tensor after applying RMSNorm.
66
+
67
+ """
68
+ output = self._norm(x.float()).type_as(x)
69
+ return output * self.weight
70
+
71
+ def reset_parameters(self):
72
+ init.ones_(self.weight)
73
+
74
+
75
+ if layer_norm_fn is not None:
76
+
77
+ class TritonLayerNorm(nn.LayerNorm):
78
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
79
+ return layer_norm_fn(
80
+ x,
81
+ self.weight,
82
+ self.bias,
83
+ residual=residual,
84
+ eps=self.eps,
85
+ prenorm=prenorm,
86
+ residual_in_fp32=residual_in_fp32,
87
+ )
88
+ else:
89
+ TritonLayerNorm = None
90
+
91
+ NORM2CLS = {
92
+ "layernorm": nn.LayerNorm,
93
+ "triton_layernorm": TritonLayerNorm if TritonLayerNorm is not None else nn.LayerNorm,
94
+ "rmsnorm": RMSNorm,
95
+ "triton_rmsnorm": TritonRMSNorm if TritonRMSNorm is not None else RMSNorm,
96
+ }
97
+
98
+
99
+ def get_norm_layer(config: FlexBertConfig, compiled_norm: bool = False) -> nn.Module:
100
+ try:
101
+ if compiled_norm:
102
+ # Use non-Triton norms when compiling
103
+ if config.normalization.startswith("triton_"):
104
+ norm = config.normalization.replace("triton_", "")
105
+ else:
106
+ norm = config.normalization
107
+ else:
108
+ norm = config.normalization
109
+ signature = inspect.signature(NORM2CLS[norm])
110
+ if hasattr(config, "norm_kwargs"):
111
+ norm_kwargs = {k: v for k, v in config.norm_kwargs.items() if k in signature.parameters}
112
+ else:
113
+ norm_kwargs = {}
114
+ return NORM2CLS[norm](config.hidden_size, **norm_kwargs)
115
+ except KeyError:
116
+ raise ValueError(f"Invalid normalization layer type: {config.normalization}, must be one of {NORM2CLS.keys()}.")
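As a numerical sanity check of the `RMSNorm` module above: it rescales each feature vector by the reciprocal root mean square of its elements (computed in fp32) and then applies the learned `weight`. A minimal standalone sketch of the same formula, written without any of this repo's optional Triton dependencies:

import torch

def rms_norm_reference(x, weight, eps=1e-5):
    # x * rsqrt(mean(x^2) + eps), computed in fp32 and cast back, as in RMSNorm.forward
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * weight

hidden_size = 8
x = torch.randn(2, 4, hidden_size)
weight = torch.ones(hidden_size)

out = rms_norm_reference(x, weight)
# after normalization, the per-vector root mean square is ~1 (up to eps)
print(out.pow(2).mean(-1).sqrt())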
options.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from normalization import NORM2CLS
2
+ from embeddings import EBB2CLS
3
+ from activation import ACT2CLS
4
+ from attention import ATTN2CLS
5
+ from mlp import MLP2CLS
6
+ from layers import LAYER2CLS
7
+
8
+
9
+ def print_layer_options():
10
+ print("Activation options:")
11
+ for option in ACT2CLS:
12
+ print(f" {option}")
13
+
14
+ print("\nAttention Layer options:")
15
+ for option in ATTN2CLS:
16
+ print(f" {option}")
17
+
18
+ print("\nEmbedding Layer options:")
19
+ for option in EBB2CLS:
20
+ print(f" {option}")
21
+
22
+ print("\nBert Layer options:")
23
+ for option in LAYER2CLS:
24
+ print(f" {option}")
25
+
26
+ print("\nMLP Layer options:")
27
+ for option in MLP2CLS:
28
+ print(f" {option}")
29
+
30
+ print("\nNormalization options:")
31
+ for option in NORM2CLS:
32
+ print(f" {option}")
padding.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ from typing import Optional, Tuple
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def unpad_input(
8
+ inputs: Tensor,
9
+ attention_mask: Tensor,
10
+ position_ids: Optional[Tensor] = None,
11
+ labels: Optional[Tensor] = None,
12
+ ) -> Tuple[Tensor, Tensor, Tensor, int, Optional[Tensor], Optional[Tensor]]:
13
+ """
14
+ Remove padding from input sequences.
15
+
16
+ Args:
17
+ inputs: (batch, seqlen, ...) or (batch, seqlen)
18
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
19
+ position_ids: (batch, seqlen), int, position ids
20
+ labels: (batch, seqlen), int, labels
21
+
22
+ Returns:
23
+ unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
24
+ indices: (total_nnz)
25
+ cu_seqlens: (batch + 1), the cumulative sequence lengths
26
+ max_seqlen_in_batch: int
27
+ unpadded_position_ids: (total_nnz) or None
28
+ unpadded_labels: (total_nnz) or None
29
+ """
30
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
31
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
32
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
33
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
34
+
35
+ if inputs.dim() == 2:
36
+ unpadded_inputs = inputs.flatten()[indices]
37
+ else:
38
+ batch, seqlen, *rest = inputs.shape
39
+ shape = batch * seqlen
40
+ unpadded_inputs = inputs.view(shape, *rest)[indices]
41
+
42
+ unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
43
+ unpadded_labels = labels.flatten()[indices] if labels is not None else None
44
+
45
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
46
+
47
+
48
+ def pad_input(
49
+ inputs: Tensor,
50
+ indices: Tensor,
51
+ batch: int,
52
+ seqlen: int,
53
+ labels: Optional[Tensor] = None,
54
+ ignore_index: int = -100,
55
+ ) -> Tuple[Tensor, Optional[Tensor]]:
56
+ """
57
+ Add padding to sequences.
58
+
59
+ Args:
60
+ inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
61
+ indices: (total_nnz)
62
+ batch: int, batch size
63
+ seqlen: int, max sequence length
64
+ labels: (total_nnz) or None
65
+ ignore_index: int, value used to fill padded label positions (default -100)
66
+
67
+ Returns:
68
+ padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
69
+ padded_labels: (batch, seqlen) or None
70
+ """
71
+ if inputs.dim() == 1:
72
+ output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
73
+ output[indices] = inputs
74
+ padded_inputs = output.view(batch, seqlen)
75
+ else:
76
+ _, *rest = inputs.shape
77
+ output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
78
+ output[indices] = inputs
79
+ padded_inputs = output.view(batch, seqlen, *rest)
80
+
81
+ padded_labels = None
82
+ if labels is not None:
83
+ padded_labels = torch.full((batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device)
84
+ padded_labels[indices] = labels
85
+ padded_labels = padded_labels.view(batch, seqlen)
86
+
87
+ return padded_inputs, padded_labels
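A small round-trip makes the contract above concrete: `unpad_input` gathers the valid tokens into a flat `(total_nnz, ...)` tensor plus `cu_seqlens`, and `pad_input` scatters them back with zeros in the padding positions. The sketch below assumes this file is importable as a flat module (e.g. running from the repository root), matching the import style used elsewhere in this commit:

import torch
from padding import unpad_input, pad_input  # flat-module import, as in this repo

batch, seqlen, hidden = 2, 5, 4
inputs = torch.randn(batch, seqlen, hidden)
# the second sequence has two padding positions at the end
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

unpadded, indices, cu_seqlens, max_seqlen, _, _ = unpad_input(inputs, attention_mask)
print(unpadded.shape)   # (8, 4): 5 + 3 valid tokens
print(cu_seqlens)       # tensor([0, 5, 8], dtype=torch.int32)
print(max_seqlen)       # 5

repadded, _ = pad_input(unpadded, indices, batch, seqlen)
# padding positions come back as zeros; valid positions are preserved
assert torch.equal(repadded * attention_mask.unsqueeze(-1), repadded)
assert torch.equal(repadded[attention_mask.bool()], unpadded)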
rotary.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+ # License: Apache-2.0
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from flash_attn.ops.triton.rotary import apply_rotary
10
+
11
+ from typing import Optional, Tuple, Union
12
+
13
+
14
+ class ApplyRotaryEmbUnpad(torch.autograd.Function):
15
+ @staticmethod
16
+ def forward(
17
+ ctx,
18
+ qkv,
19
+ cos,
20
+ sin,
21
+ interleaved=False,
22
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
23
+ cu_seqlens: Optional[torch.Tensor] = None,
24
+ max_seqlen: Optional[int] = None,
25
+ ):
26
+ # (total_nnz, 3, nheads, headdim)
27
+ total_nnz, three, nheads, headdim = qkv.shape
28
+ assert three == 3
29
+ if qkv.is_contiguous():
30
+ # Call 1 kernel instead of 2 kernels
31
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
32
+ # dimensions, we get the same tensor
33
+ # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
34
+ qk = qkv[:, :2].view(total_nnz, -1, headdim)
35
+ apply_rotary(
36
+ qk,
37
+ cos,
38
+ sin,
39
+ seqlen_offsets=seqlen_offsets,
40
+ cu_seqlens=cu_seqlens,
41
+ max_seqlen=max_seqlen,
42
+ interleaved=interleaved,
43
+ inplace=True,
44
+ )
45
+ else:
46
+ q, k = qkv[:, 0, :, :], qkv[:, 1, :, :]
47
+ apply_rotary(
48
+ q,
49
+ cos,
50
+ sin,
51
+ seqlen_offsets=seqlen_offsets,
52
+ cu_seqlens=cu_seqlens,
53
+ max_seqlen=max_seqlen,
54
+ interleaved=interleaved,
55
+ inplace=True,
56
+ )
57
+ apply_rotary(
58
+ k,
59
+ cos,
60
+ sin,
61
+ seqlen_offsets=seqlen_offsets,
62
+ cu_seqlens=cu_seqlens,
63
+ max_seqlen=max_seqlen,
64
+ interleaved=interleaved,
65
+ inplace=True,
66
+ )
67
+
68
+ if isinstance(seqlen_offsets, int):
69
+ ctx.save_for_backward(cos, sin, cu_seqlens)
70
+ ctx.seqlen_offsets = seqlen_offsets
71
+ else:
72
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
73
+ ctx.seqlen_offsets = None
74
+ ctx.interleaved = interleaved
75
+ ctx.max_seqlen = max_seqlen
76
+ return qkv
77
+
78
+ @staticmethod
79
+ def backward(ctx, do):
80
+ seqlen_offsets = ctx.seqlen_offsets
81
+ if seqlen_offsets is None:
82
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
83
+ else:
84
+ cos, sin, cu_seqlens = ctx.saved_tensors
85
+ if do.is_contiguous():
86
+ total_nnz, three, nheads, headdim = do.shape
87
+ # Call 1 kernel instead of 2 kernels
88
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
89
+ # dimensions, we get the same tensor
90
+ dqk = do[:, :2].view(total_nnz, -1, headdim)
91
+ apply_rotary(
92
+ dqk,
93
+ cos,
94
+ sin,
95
+ seqlen_offsets=seqlen_offsets,
96
+ cu_seqlens=cu_seqlens,
97
+ max_seqlen=ctx.max_seqlen,
98
+ interleaved=ctx.interleaved,
99
+ inplace=True,
100
+ conjugate=True,
101
+ )
102
+ else:
103
+ dq, dk = do[:, 0, :, :], do[:, 1, :, :]
104
+ apply_rotary(
105
+ dq,
106
+ cos,
107
+ sin,
108
+ seqlen_offsets=seqlen_offsets,
109
+ cu_seqlens=cu_seqlens,
110
+ max_seqlen=ctx.max_seqlen,
111
+ interleaved=ctx.interleaved,
112
+ inplace=True,
113
+ conjugate=True,
114
+ )
115
+ apply_rotary(
116
+ dk,
117
+ cos,
118
+ sin,
119
+ seqlen_offsets=seqlen_offsets,
120
+ cu_seqlens=cu_seqlens,
121
+ max_seqlen=ctx.max_seqlen,
122
+ interleaved=ctx.interleaved,
123
+ inplace=True,
124
+ conjugate=True,
125
+ )
126
+
127
+ return do, None, None, None, None, None, None
128
+
129
+
130
+ def apply_rotary_emb_unpad(
131
+ qkv,
132
+ cos,
133
+ sin,
134
+ interleaved=False,
135
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
136
+ cu_seqlens: Optional[torch.Tensor] = None,
137
+ max_seqlen: Optional[int] = None,
138
+ ):
139
+ """
140
+ Arguments:
141
+ qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
142
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
143
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
144
+ of 1st half and 2nd half (GPT-NeoX style).
145
+ Note: the rotary embedding is always applied to qkv in-place.
146
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
147
+ Most commonly used in inference when we have KV cache.
148
+ cu_seqlens: (batch + 1,) or None
149
+ max_seqlen: int
150
+ Return:
151
+ qkv: (total_nnz, 3, nheads, headdim), with rotary applied in-place
152
+ rotary_dim must be <= headdim
153
+ Apply rotary embedding to the first rotary_dim of q and k.
154
+ """
155
+ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen)
156
+
157
+
158
+ class UnpaddedRotaryEmbedding(torch.nn.Module):
159
+ """
160
+ The rotary position embeddings applied directly to unpadded sequences.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ dim: int,
166
+ base: float = 10000.0,
167
+ interleaved: bool = False,
168
+ max_seqlen: Optional[int] = None,
169
+ scale_base: Optional[bool] = None,
170
+ pos_idx_in_fp32: bool = True,
171
+ device: Optional[torch.device] = None,
172
+ dtype: Optional[torch.dtype] = None,
173
+ ):
174
+ """
175
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
176
+ of 1st half and 2nd half (GPT-NeoX style).
177
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
178
+ otherwise they might be in lower precision.
179
+ This option was added because previously (before 2023-07-02), when we construct
180
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
181
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
182
+ self.inv_freq would be bf16, and the position indices are also in bf16.
183
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
184
+ embeddings for some positions will coincide.
185
+ To maintain compatibility with models previously trained in pure bf16,
186
+ we add this option.
187
+ max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
188
+ up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
189
+ the cos_sin_cache will be recomputed during the forward pass.
190
+ """
191
+ super().__init__()
192
+ self.dim = dim
193
+ self.base = float(base)
194
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
195
+ # Generate and save the inverse frequency buffer (non trainable)
196
+ inv_freq = self._compute_inv_freq(device)
197
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
198
+ self.interleaved = interleaved
199
+ self.scale_base = scale_base
200
+ scale = (
201
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
202
+ if scale_base is not None
203
+ else None
204
+ )
205
+ self.register_buffer("scale", scale, persistent=False)
206
+
207
+ self._seq_len_cached = 0
208
+ self._cos_cached = None
209
+ self._sin_cached = None
210
+ self._cos_k_cached = None
211
+ self._sin_k_cached = None
212
+
213
+ if max_seqlen is not None and device is not None and dtype is not None:
214
+ self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
215
+
216
+ def _compute_inv_freq(self, device=None):
217
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
218
+
219
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
220
+ # Reset the tables if the sequence length has changed,
221
+ # if we're on a new device (possibly due to tracing for instance),
222
+ # or if we're switching from inference mode to training
223
+ if (
224
+ seqlen > self._seq_len_cached
225
+ or self._cos_cached is None
226
+ or self._cos_cached.device != device
227
+ or self._cos_cached.dtype != dtype
228
+ or (self.training and self._cos_cached.is_inference())
229
+ ):
230
+ self._seq_len_cached = seqlen
231
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
232
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
233
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
234
+ if self.pos_idx_in_fp32:
235
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
236
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
237
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
238
+ # cos & sin output to change significantly.
239
+ # We want to recompute self.inv_freq if it was not loaded in fp32
240
+ if self.inv_freq.dtype != torch.float32:
241
+ inv_freq = self._compute_inv_freq(device=device)
242
+ else:
243
+ inv_freq = self.inv_freq
244
+ else:
245
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
246
+ inv_freq = self.inv_freq
247
+ # Don't do einsum, it converts fp32 to fp16 under AMP
248
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
249
+ freqs = torch.outer(t, inv_freq)
250
+ if self.scale is None:
251
+ self._cos_cached = torch.cos(freqs).to(dtype)
252
+ self._sin_cached = torch.sin(freqs).to(dtype)
253
+ else:
254
+ power = (
255
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
256
+ ) / self.scale_base
257
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
258
+ # We want the multiplication by scale to happen in fp32
259
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
260
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
261
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
262
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
263
+
264
+ def forward(
265
+ self,
266
+ qkv: torch.Tensor,
267
+ cu_seqlens: torch.Tensor,
268
+ max_seqlen: Optional[int] = None,
269
+ seqlen_offset: Union[int, torch.Tensor] = 0,
270
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
271
+ """
272
+ qkv: (total_nnz, 3, nheads, headdim)
273
+ cu_seqlens: (batch + 1,) cumulative sequence lengths
274
+ max_seqlen: int max seq length in the batch
275
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
276
+ Most commonly used in inference when we have KV cache.
277
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
278
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
279
+ Apply rotary embedding *inplace* to qkv.
280
+ """
281
+ if max_seqlen is not None:
282
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
283
+
284
+ qkv = apply_rotary_emb_unpad(
285
+ qkv,
286
+ self._cos_cached,
287
+ self._sin_cached,
288
+ interleaved=self.interleaved,
289
+ seqlen_offsets=seqlen_offset,
290
+ cu_seqlens=cu_seqlens,
291
+ max_seqlen=max_seqlen,
292
+ )
293
+
294
+ return qkv
295
+
296
+ def extra_repr(self) -> str:
297
+ return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
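The cached tables above follow the standard RoPE recipe: inverse frequencies `1 / base ** (2i / dim)` for each pair of rotary dimensions, an outer product with the position indices, then elementwise cos and sin. A minimal standalone sketch of that cache construction, mirroring `_compute_inv_freq` and `_update_cos_sin_cache` without the `scale_base` branch or the Triton kernel:

import torch

def build_rope_cache(dim: int, seqlen: int, base: float = 10000.0,
                     dtype: torch.dtype = torch.bfloat16):
    # inverse frequencies for each pair of dimensions, kept in fp32 for precision
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(seqlen, dtype=torch.float32)   # position indices in fp32
    freqs = torch.outer(t, inv_freq)                # (seqlen, dim // 2)
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

cos, sin = build_rope_cache(dim=64, seqlen=128)
print(cos.shape, sin.shape)  # torch.Size([128, 32]) torch.Size([128, 32])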
utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Optuna, Hugging Face
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2023 OLMo Authors
5
+ # License: Apache-2.0
6
+
7
+ import functools
8
+ import logging
9
+ from enum import Enum
10
+
11
+
12
+ @functools.lru_cache(None)
13
+ def warning_once(self, *args, **kwargs):
14
+ """
15
+ This method is identical to `logger.warning()`, but will emit the warning with the same message only once
16
+
17
+ Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
18
+ The assumption here is that all warning messages are unique across the code. If they aren't, we would need to switch to
19
+ another type of cache that includes the caller frame information in the hashing function.
20
+ """
21
+ self.warning(*args, **kwargs)
22
+
23
+
24
+ logging.Logger.warning_once = warning_once
25
+ logging.Logger.warn_once = warning_once
26
+
27
+
28
+ class StrEnum(str, Enum):
29
+ """
30
+ This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
31
+ We include this here for compatibility with older versions of Python.
32
+ """
33
+
34
+ def __str__(self) -> str:
35
+ return self.value
36
+
37
+ def __repr__(self) -> str:
38
+ return f"'{str(self)}'"
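A short usage sketch for the helpers above, assuming the flat-module import style used throughout this commit (the `PaddingMode` enum is purely illustrative): `StrEnum` members compare and format as plain strings, and importing the module patches `logging.Logger` so `warning_once` deduplicates repeated messages via `functools.lru_cache`.

import logging

from utils import StrEnum  # importing utils also installs Logger.warning_once


class PaddingMode(StrEnum):
    padded = "padded"
    unpadded = "unpadded"


print(PaddingMode.unpadded == "unpadded")  # True: members behave like their string values
print(f"mode={PaddingMode.padded}")        # mode=padded

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
logger.warning_once("falling back to padded attention")  # emitted once
logger.warning_once("falling back to padded attention")  # cached, silently skipped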