Text Generation
Transformers
PyTorch
mosaic_gpt
custom_code
daking commited on
Commit
2d4410c
1 Parent(s): ec4ad69

Upload MosaicGPT

Browse files
attention.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Attention layers."""
5
+
6
+ import math
7
+ import warnings
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+ from torch import nn
14
+
15
+ from .low_precision_layernorm import LPLayerNorm
16
+
17
+
18
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
19
+ original_is_causal: bool):
20
+ if original_is_causal and num_query_tokens != num_key_tokens:
21
+ if num_query_tokens != 1:
22
+ raise NotImplementedError(
23
+ 'MosaicGPT does not support query and key with different number of tokens, unless number of query tokens is 1.'
24
+ )
25
+ else:
26
+ return False
27
+ return original_is_causal
28
+
29
+
30
+ def scaled_multihead_dot_product_attention(
31
+ query,
32
+ key,
33
+ value,
34
+ n_heads,
35
+ softmax_scale=None,
36
+ attn_bias=None,
37
+ key_padding_mask=None,
38
+ is_causal=False,
39
+ dropout_p=0.0,
40
+ training=False,
41
+ needs_weights=False,
42
+ ):
43
+
44
+ q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
45
+ k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads) # includes key.t()
46
+ v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
47
+
48
+ min_val = torch.finfo(q.dtype).min
49
+
50
+ b, _, s_q, d = q.shape
51
+ s_k = k.size(-1)
52
+
53
+ if softmax_scale is None:
54
+ softmax_scale = 1 / math.sqrt(d)
55
+
56
+ attn_weight = q.matmul(k) * softmax_scale
57
+
58
+ if attn_bias is not None:
59
+ if (attn_bias.size(-1) != 1 and
60
+ attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
61
+ attn_bias.size(-2) != s_q):
62
+ raise RuntimeError(
63
+ f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
64
+ )
65
+ attn_weight = attn_weight + attn_bias
66
+
67
+ if key_padding_mask is not None:
68
+ if attn_bias is not None:
69
+ warnings.warn(
70
+ 'Propogating key_padding_mask to the attention module ' +\
71
+ 'and applying it within the attention module can cause ' +\
72
+ 'unneccessary computation/memory usage. Consider integrating ' +\
73
+ 'into attn_bias once and passing that to each attention ' +\
74
+ 'module instead.'
75
+ )
76
+ attn_weight = attn_weight.masked_fill(
77
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val)
78
+
79
+ if is_causal:
80
+ s = max(s_q, s_k)
81
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
82
+ causal_mask = causal_mask.tril()
83
+ causal_mask = causal_mask.to(torch.bool)
84
+ causal_mask = ~causal_mask
85
+ causal_mask = causal_mask[-s_q:, -s_k:]
86
+ attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
87
+ min_val)
88
+
89
+ attn_weight = torch.softmax(attn_weight, dim=-1)
90
+
91
+ if dropout_p:
92
+ attn_weight = torch.nn.functional.dropout(attn_weight,
93
+ p=dropout_p,
94
+ training=training,
95
+ inplace=True)
96
+
97
+ out = attn_weight.matmul(v)
98
+ out = rearrange(out, 'b h s d -> b s (h d)')
99
+
100
+ if needs_weights:
101
+ return out, attn_weight
102
+ return out, None
103
+
104
+
105
+ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
106
+ for tensor in tensors:
107
+ if tensor.dtype not in valid_dtypes:
108
+ raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
109
+ if not tensor.is_cuda:
110
+ raise TypeError(f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
111
+
112
+
113
+ def flash_attn_fn(
114
+ query,
115
+ key,
116
+ value,
117
+ n_heads,
118
+ softmax_scale=None,
119
+ attn_bias=None,
120
+ key_padding_mask=None,
121
+ is_causal=False,
122
+ dropout_p=0.0,
123
+ training=False,
124
+ needs_weights=False,
125
+ ):
126
+ try:
127
+ from flash_attn import bert_padding, flash_attn_interface
128
+ except:
129
+ raise RuntimeError('Please install flash_attn==0.2.8')
130
+
131
+ check_valid_inputs(query, key, value)
132
+
133
+ if attn_bias is not None:
134
+ raise NotImplementedError(f'attn_bias not implemented for flash attn.')
135
+
136
+ batch_size, seqlen = query.shape[:2]
137
+
138
+ if key_padding_mask is None:
139
+ key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
140
+ query_padding_mask = key_padding_mask[:, -query.size(1):]
141
+
142
+ query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
143
+ query, query_padding_mask)
144
+ query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
145
+
146
+ key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
147
+ key, key_padding_mask)
148
+ key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
149
+
150
+ value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
151
+ value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
152
+
153
+ dropout_p = dropout_p if training else 0.0
154
+
155
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
156
+
157
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
158
+ query_unpad,
159
+ key_unpad,
160
+ value_unpad,
161
+ cu_seqlens_q,
162
+ cu_seqlens_k,
163
+ max_seqlen_q,
164
+ max_seqlen_k,
165
+ dropout_p,
166
+ softmax_scale=softmax_scale,
167
+ causal=reset_is_causal,
168
+ return_attn_probs=needs_weights)
169
+
170
+ output = bert_padding.pad_input(
171
+ rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
172
+ seqlen)
173
+ return output, None
174
+
175
+
176
+ def triton_flash_attn_fn(
177
+ query,
178
+ key,
179
+ value,
180
+ n_heads,
181
+ softmax_scale=None,
182
+ attn_bias=None,
183
+ key_padding_mask=None,
184
+ is_causal=False,
185
+ dropout_p=0.0,
186
+ training=False,
187
+ needs_weights=False,
188
+ ):
189
+ try:
190
+ from flash_attn import flash_attn_triton # type: ignore
191
+ except:
192
+ raise RuntimeError('Please install flash_attn==0.2.8 and triton==2.0.0.dev20221202.')
193
+
194
+ check_valid_inputs(query, key, value)
195
+
196
+ if dropout_p:
197
+ raise NotImplementedError(
198
+ f'Dropout not implemented for attn_impl: triton.')
199
+
200
+ if needs_weights:
201
+ raise NotImplementedError(
202
+ f'attn_impl: triton cannot return attn weights.')
203
+
204
+ if key_padding_mask is not None:
205
+ warnings.warn(
206
+ 'Propagating key_padding_mask to the attention module ' +\
207
+ 'and applying it within the attention module can cause ' +\
208
+ 'unnecessary computation/memory usage. Consider integrating ' +\
209
+ 'into attn_bias once and passing that to each attention ' +\
210
+ 'module instead.'
211
+ )
212
+ b_size, s_k = key_padding_mask.shape[:2]
213
+
214
+ if attn_bias is None:
215
+ attn_bias = query.new_zeros(b_size, 1, 1, s_k)
216
+
217
+ attn_bias = attn_bias.masked_fill(
218
+ ~key_padding_mask.view((b_size, 1, 1, s_k)),
219
+ torch.finfo(query.dtype).min)
220
+
221
+ query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
222
+ key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
223
+ value = rearrange(value, 'b s (h d) -> b s h d', h=n_heads)
224
+
225
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
226
+ attn_output = flash_attn_triton.flash_attn_func(query, key, value,
227
+ attn_bias, reset_is_causal,
228
+ softmax_scale)
229
+
230
+ output = attn_output.view(*attn_output.shape[:2], -1)
231
+
232
+ return output, None
233
+
234
+
235
+ class MultiheadAttention(nn.Module):
236
+ """Multi-head self attention.
237
+
238
+ Using torch or triton attention implemetation enables user to also use
239
+ additive bias.
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ d_model: int,
245
+ n_heads: int,
246
+ attn_impl: str = 'triton',
247
+ attn_clip_qkv: Optional[float] = None,
248
+ attn_qk_ln: bool = False,
249
+ softmax_scale: Optional[float] = None,
250
+ attn_pdrop: float = 0.0,
251
+ low_precision_layernorm: bool = False,
252
+ device: Optional[str] = None,
253
+ ):
254
+ super().__init__()
255
+
256
+ self.attn_impl = attn_impl
257
+ self.clip_qkv = attn_clip_qkv
258
+ self.attn_qk_ln = attn_qk_ln
259
+
260
+ self.d_model = d_model
261
+ self.n_heads = n_heads
262
+ self.softmax_scale = softmax_scale
263
+ if self.softmax_scale is None:
264
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
265
+ self.attn_dropout_p = attn_pdrop
266
+
267
+ self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
268
+ # for param init fn; enables shape based init of fused layers
269
+ fuse_splits = (d_model, 2 * d_model)
270
+ self.Wqkv._fused = (0, fuse_splits) # type: ignore
271
+
272
+ if self.attn_qk_ln:
273
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
274
+ self.q_ln = layernorm_class(self.d_model, device=device)
275
+ self.k_ln = layernorm_class(self.d_model, device=device)
276
+
277
+ if self.attn_impl == 'flash':
278
+ self.attn_fn = flash_attn_fn
279
+ elif self.attn_impl == 'triton':
280
+ self.attn_fn = triton_flash_attn_fn
281
+ warnings.warn(
282
+ 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\
283
+ 'it uses more memory. When training larger models this can trigger ' +\
284
+ 'alloc retries which hurts performance. If encountered, we recommend ' +\
285
+ 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
286
+ elif self.attn_impl == 'torch':
287
+ self.attn_fn = scaled_multihead_dot_product_attention
288
+ if torch.cuda.is_available():
289
+ warnings.warn(
290
+ 'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\
291
+ '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\
292
+ 'we recommend using `attn_impl: triton`.'
293
+ )
294
+ else:
295
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
296
+
297
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
298
+ self.out_proj._is_residual = True # type: ignore
299
+
300
+ def forward(self,
301
+ x,
302
+ past_key_value=None,
303
+ attn_bias=None,
304
+ attention_mask=None,
305
+ is_causal=True,
306
+ needs_weights=False):
307
+ qkv = self.Wqkv(x)
308
+
309
+ if self.clip_qkv:
310
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
311
+
312
+ query, key, value = qkv.chunk(3, dim=2)
313
+
314
+ key_padding_mask = attention_mask
315
+
316
+ if self.attn_qk_ln:
317
+ # Applying layernorm to qk
318
+ dtype = query.dtype
319
+ query = self.q_ln(query).to(dtype)
320
+ key = self.k_ln(key).to(dtype)
321
+
322
+ if past_key_value is not None:
323
+ if len(past_key_value) != 0:
324
+ key = torch.cat([past_key_value[0], key], dim=1)
325
+ value = torch.cat([past_key_value[1], value], dim=1)
326
+
327
+ past_key_value = (key, value)
328
+
329
+ if attn_bias is not None:
330
+ attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
331
+
332
+ context, attn_weights = self.attn_fn(
333
+ query,
334
+ key,
335
+ value,
336
+ self.n_heads,
337
+ softmax_scale=self.softmax_scale,
338
+ attn_bias=attn_bias,
339
+ key_padding_mask=key_padding_mask,
340
+ is_causal=is_causal,
341
+ dropout_p=self.attn_dropout_p,
342
+ training=self.training,
343
+ needs_weights=needs_weights,
344
+ )
345
+
346
+ return self.out_proj(context), attn_weights, past_key_value
347
+
348
+
349
+ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
350
+ use_sequence_id):
351
+ if attn_impl == 'flash':
352
+ return None
353
+ elif attn_impl in ['torch', 'triton']:
354
+ if alibi:
355
+ if (prefix_lm or not causal) or use_sequence_id:
356
+ return (1, n_heads, seq_len, seq_len)
357
+ return (1, n_heads, 1, seq_len)
358
+ elif prefix_lm or use_sequence_id:
359
+ return (1, 1, seq_len, seq_len)
360
+ return None
361
+ else:
362
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
363
+
364
+
365
+ def attn_bias(attn_impl,
366
+ attn_bias,
367
+ n_heads,
368
+ seq_len,
369
+ causal=False,
370
+ alibi=False,
371
+ alibi_bias_max=8):
372
+ if attn_impl == 'flash':
373
+ return None
374
+ elif attn_impl in ['torch', 'triton']:
375
+ if alibi:
376
+ # in place add alibi to attn bias
377
+ device, dtype = attn_bias.device, attn_bias.dtype
378
+ attn_bias = attn_bias.add(
379
+ alibi_bias(n_heads,
380
+ seq_len,
381
+ full=not causal,
382
+ alibi_bias_max=alibi_bias_max,
383
+ device=device,
384
+ dtype=dtype))
385
+ return attn_bias
386
+ else:
387
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
388
+
389
+
390
+ def alibi_bias(n_heads,
391
+ seq_len,
392
+ full=False,
393
+ alibi_bias_max=8,
394
+ device=None,
395
+ dtype=None):
396
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=dtype,
397
+ device=device).view(1, 1, 1, seq_len)
398
+ if full:
399
+ # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
400
+ # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
401
+ alibi_bias = alibi_bias - torch.arange(
402
+ 1 - seq_len, 1, dtype=dtype, device=device).view(1, 1, seq_len, 1)
403
+ alibi_bias = alibi_bias.abs().mul(-1)
404
+
405
+ m = torch.arange(1, n_heads + 1, dtype=dtype, device=device)
406
+ m = m.mul(alibi_bias_max / n_heads)
407
+ alibi_bias = alibi_bias * (1. / (2**m.view(1, n_heads, 1, 1)))
408
+ return alibi_bias
config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "mosaicml/mosaic-llama-redpajama-final-candidate",
3
+ "alibi": true,
4
+ "alibi_bias_max": 8,
5
+ "architectures": [
6
+ "MosaicGPT"
7
+ ],
8
+ "attn_clip_qkv": null,
9
+ "attn_impl": "torch",
10
+ "attn_pdrop": 0,
11
+ "attn_qk_ln": true,
12
+ "attn_uses_sequence_id": false,
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_mosaic_gpt.MosaicGPTConfig",
15
+ "AutoModelForCausalLM": "mosaic_gpt.MosaicGPT"
16
+ },
17
+ "d_model": 2048,
18
+ "emb_init_std": null,
19
+ "emb_init_uniform_lim": null,
20
+ "emb_pdrop": 0,
21
+ "embedding_fraction": 1.0,
22
+ "fan_mode": "fan_in",
23
+ "init_device": "cpu",
24
+ "init_div_is_residual": true,
25
+ "init_gain": 0,
26
+ "init_nonlinearity": "relu",
27
+ "init_std": 0.02,
28
+ "logit_scale": null,
29
+ "low_precision_layernorm": true,
30
+ "max_seq_len": 2048,
31
+ "mlp_ratio": 4,
32
+ "model_type": "mosaic_gpt",
33
+ "n_heads": 16,
34
+ "n_layers": 24,
35
+ "no_bias": true,
36
+ "param_init_fn": "kaiming_normal_",
37
+ "prefix_lm": false,
38
+ "resid_pdrop": 0,
39
+ "softmax_scale": null,
40
+ "tokenizer_name": "EleutherAI/gpt-neox-20b",
41
+ "torch_dtype": "float32",
42
+ "transformers_version": "4.27.4",
43
+ "use_cache": false,
44
+ "verbose": 0,
45
+ "vocab_size": 50432
46
+ }
configuration_mosaic_gpt.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """A HuggingFace-style model configuration."""
5
+
6
+ from typing import Optional, Tuple, Union
7
+
8
+ from transformers import PretrainedConfig
9
+
10
+
11
+ class MosaicGPTConfig(PretrainedConfig):
12
+ model_type = 'mosaic_gpt'
13
+
14
+ def __init__(
15
+ self,
16
+ d_model: int = 2048,
17
+ n_heads: int = 16,
18
+ n_layers: int = 24,
19
+ mlp_ratio: int = 4,
20
+ max_seq_len: int = 2048,
21
+ vocab_size: int = 50368,
22
+ attn_pdrop: float = 0.0,
23
+ resid_pdrop: float = 0.0,
24
+ emb_pdrop: float = 0.0,
25
+ attn_impl: str = 'triton',
26
+ attn_qk_ln: bool = False,
27
+ attn_clip_qkv: Optional[float] = None,
28
+ softmax_scale: Optional[float] = None,
29
+ prefix_lm: Optional[bool] = False,
30
+ attn_uses_sequence_id: Optional[bool] = False,
31
+ alibi: bool = False,
32
+ alibi_bias_max: int = 8,
33
+ init_device: str = 'cpu',
34
+ logit_scale: Optional[Union[float, str]] = None,
35
+ no_bias: bool = False,
36
+ verbose: int = 0,
37
+ param_init_fn: str = 'kaiming_normal_',
38
+ init_div_is_residual: Union[int, float, str, bool] = True,
39
+ init_std: float = 0.02,
40
+ emb_init_std: Optional[float] = None,
41
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float],
42
+ float]] = None,
43
+ init_gain: float = 0,
44
+ fan_mode: str = 'fan_in',
45
+ init_nonlinearity: str = 'relu',
46
+ embedding_fraction: float = 1.0,
47
+ low_precision_layernorm: bool = True,
48
+ use_cache: bool = False,
49
+ **kwargs,
50
+ ):
51
+ """The MosaicGPT configuration class.
52
+
53
+ Args:
54
+ d_model (int): The size of the embedding dimension of the model.
55
+ n_heads (int): The number of attention heads.
56
+ n_layers (int): The number of layers in the model.
57
+ mlp_ratio (int): The ratio of the up/down scale in the MLP.
58
+ max_seq_len (int): The maximum sequence length of the model.
59
+ vocab_size (int): The size of the vocabulary.
60
+ attn_pdrop (float): The dropout probability for the attention layers.
61
+ resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
62
+ emb_pdrop (float): The dropout probability for the embedding layer.
63
+ attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
64
+ attn_qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
65
+ attn_clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
66
+ this value.
67
+ softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
68
+ use the default scale of ``1/sqrt(d_keys)``.
69
+ prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
70
+ extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
71
+ can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
72
+ attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
73
+ When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
74
+ which sub-sequence each token belongs to.
75
+ Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
76
+ alibi (bool): Whether to use the alibi bias instead of position embeddings.
77
+ alibi_bias_max (int): The maximum value of the alibi bias.
78
+ init_device (str): The device to use for parameter initialization.
79
+ logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
80
+ no_bias (bool): Whether to use bias in all layers.
81
+ verbose (int): The verbosity level. 0 is silent.
82
+ param_init_fn (str): The parameter initialization scheme to use. One of 'default_', 'baseline_', 'kaiming_uniform_',
83
+ 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 'xavier_normal_'.
84
+ init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
85
+ init_std (float): The standard deviation of the normal distribution used to initialize the model,
86
+ if using the baseline_ parameter initialization scheme.
87
+ emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
88
+ emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
89
+ used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
90
+ init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
91
+ fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
92
+ init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
93
+ embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
94
+ low_precision_layernorm (bool): Whether to use low precision layer normalization.
95
+ use_cache (bool): Whether or not the model should return the last key/values attentions
96
+ """
97
+ self.d_model = d_model
98
+ self.n_heads = n_heads
99
+ self.n_layers = n_layers
100
+ self.mlp_ratio = mlp_ratio
101
+ self.max_seq_len = max_seq_len
102
+ self.vocab_size = vocab_size
103
+ self.attn_pdrop = attn_pdrop
104
+ self.resid_pdrop = resid_pdrop
105
+ self.emb_pdrop = emb_pdrop
106
+ self.attn_impl = attn_impl
107
+ self.attn_qk_ln = attn_qk_ln
108
+ self.attn_clip_qkv = attn_clip_qkv
109
+ self.softmax_scale = softmax_scale
110
+ self.prefix_lm = prefix_lm
111
+ self.attn_uses_sequence_id = attn_uses_sequence_id
112
+ self.alibi = alibi
113
+ self.alibi_bias_max = alibi_bias_max
114
+ self.init_device = init_device
115
+ self.logit_scale = logit_scale
116
+ self.no_bias = no_bias
117
+ self.verbose = verbose
118
+ self.param_init_fn = param_init_fn
119
+ self.init_div_is_residual = init_div_is_residual
120
+ self.init_std = init_std
121
+ self.emb_init_std = emb_init_std
122
+ self.emb_init_uniform_lim = emb_init_uniform_lim
123
+ self.init_std = init_std
124
+ self.init_gain = init_gain
125
+ self.fan_mode = fan_mode
126
+ self.init_nonlinearity = init_nonlinearity
127
+ self.embedding_fraction = embedding_fraction
128
+ self.low_precision_layernorm = low_precision_layernorm
129
+ self.use_cache = use_cache
130
+ if 'name' in kwargs:
131
+ del kwargs['name']
132
+ if 'loss_fn' in kwargs:
133
+ del kwargs['loss_fn']
134
+ super().__init__(**kwargs)
135
+
136
+ self._validate_config()
137
+
138
+ def _validate_config(self):
139
+ if self.d_model % self.n_heads != 0:
140
+ raise ValueError('d_model must be divisible by n_heads')
141
+ if any(prob < 0 or prob > 1
142
+ for prob in [self.attn_pdrop, self.resid_pdrop, self.emb_pdrop]):
143
+ raise ValueError(
144
+ 'attn_pdrop, resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1'
145
+ )
146
+ if self.attn_impl not in ['torch', 'flash', 'triton']:
147
+ raise ValueError(f'Unknown attn_impl={self.attn_impl}')
148
+ if self.prefix_lm and self.attn_impl not in ['torch', 'triton']:
149
+ raise NotImplementedError(
150
+ 'prefix_lm only implemented with torch and triton attention.')
151
+ if self.alibi and self.attn_impl not in ['torch', 'triton']:
152
+ raise NotImplementedError(
153
+ 'alibi only implemented with torch and triton attention.')
154
+ if self.attn_uses_sequence_id and self.attn_impl not in [
155
+ 'torch', 'triton'
156
+ ]:
157
+ raise NotImplementedError(
158
+ 'attn_uses_sequence_id only implemented with torch and triton attention.'
159
+ )
160
+ if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
161
+ raise ValueError(
162
+ 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
163
+ )
164
+ if isinstance(self.logit_scale,
165
+ str) and self.logit_scale != 'inv_sqrt_d_model':
166
+ raise ValueError(
167
+ f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
168
+ )
generation_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.27.4",
4
+ "use_cache": false
5
+ }
gpt_blocks.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """GPT Blocks used for the GPT Model."""
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from .attention import MultiheadAttention
12
+ from .low_precision_layernorm import LPLayerNorm
13
+
14
+
15
+ class GPTMLP(nn.Module):
16
+
17
+ def __init__(self,
18
+ d_model: int,
19
+ mlp_ratio: int,
20
+ device: Optional[str] = None):
21
+ super().__init__()
22
+ self.mlp_up = nn.Linear(d_model, mlp_ratio * d_model, device=device)
23
+ self.mlp_act = nn.GELU(approximate='none')
24
+ self.mlp_down = nn.Linear(mlp_ratio * d_model, d_model, device=device)
25
+ self.mlp_down._is_residual = True # type: ignore
26
+
27
+ def forward(self, x):
28
+ return self.mlp_down(self.mlp_act(self.mlp_up(x)))
29
+
30
+
31
+ class GPTBlock(nn.Module):
32
+
33
+ def __init__(self,
34
+ attn_impl: str,
35
+ d_model: int,
36
+ n_heads: int,
37
+ mlp_ratio: int,
38
+ attn_clip_qkv: Optional[float] = None,
39
+ attn_qk_ln: bool = False,
40
+ softmax_scale: Optional[float] = None,
41
+ attn_pdrop: float = 0.0,
42
+ alibi: bool = False,
43
+ resid_pdrop: float = 0.0,
44
+ low_precision_layernorm: bool = False,
45
+ device: Optional[str] = None,
46
+ **kwargs):
47
+ del kwargs # unused, just to capture any extra args from the config
48
+ super().__init__()
49
+
50
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
51
+
52
+ self.ln_1 = layernorm_class(d_model, device=device)
53
+ self.attn = MultiheadAttention(
54
+ attn_impl=attn_impl,
55
+ attn_clip_qkv=attn_clip_qkv,
56
+ attn_qk_ln=attn_qk_ln,
57
+ softmax_scale=softmax_scale,
58
+ attn_pdrop=attn_pdrop,
59
+ d_model=d_model,
60
+ n_heads=n_heads,
61
+ device=device,
62
+ )
63
+ self.ln_2 = layernorm_class(d_model, device=device)
64
+ self.mlp = GPTMLP(
65
+ d_model=d_model,
66
+ mlp_ratio=mlp_ratio,
67
+ device=device,
68
+ )
69
+ self.resid_attn_dropout = nn.Dropout(resid_pdrop)
70
+ self.resid_mlp_dropout = nn.Dropout(resid_pdrop)
71
+
72
+ def forward(
73
+ self,
74
+ x: torch.Tensor,
75
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
76
+ attn_bias: Optional[torch.Tensor] = None,
77
+ attention_mask: Optional[torch.ByteTensor] = None,
78
+ is_causal: bool = True,
79
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
80
+ a = self.ln_1(x)
81
+ b, _, past_key_value = self.attn(a,
82
+ past_key_value=past_key_value,
83
+ attn_bias=attn_bias,
84
+ attention_mask=attention_mask,
85
+ is_causal=is_causal)
86
+ x = x + self.resid_attn_dropout(b)
87
+ m = self.ln_2(x)
88
+ n = self.mlp(m)
89
+ x = x + self.resid_mlp_dropout(n)
90
+ return x, past_key_value
low_precision_layernorm.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ class LPLayerNorm(torch.nn.LayerNorm):
5
+ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
6
+ super().__init__(
7
+ normalized_shape=normalized_shape,
8
+ eps=eps,
9
+ elementwise_affine=elementwise_affine,
10
+ device=device,
11
+ dtype=dtype,
12
+ )
13
+
14
+ def forward(self, x):
15
+ module_device = x.device
16
+ downcast_x = _cast_if_autocast_enabled(x)
17
+ downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
18
+ downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
19
+ with torch.autocast(enabled=False, device_type=module_device.type):
20
+ return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
21
+
22
+ def _cast_if_autocast_enabled(tensor):
23
+ if torch.is_autocast_enabled():
24
+ if tensor.device.type == 'cuda':
25
+ dtype = torch.get_autocast_gpu_dtype()
26
+ elif tensor.device.type == 'cpu':
27
+ dtype = torch.get_autocast_cpu_dtype()
28
+ else:
29
+ raise NotImplementedError()
30
+ return tensor.to(dtype=dtype)
31
+ return tensor
mosaic_gpt.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """A simple, flexible implementation of a GPT model.
5
+
6
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
7
+ """
8
+
9
+ import math
10
+ import warnings
11
+ from typing import List, Optional, Tuple
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from transformers import AutoTokenizer, PreTrainedModel
17
+ from transformers.modeling_outputs import CausalLMOutputWithPast
18
+
19
+ from .attention import attn_bias as module_attn_bias, attn_bias_shape as module_attn_bias_shape
20
+ from .gpt_blocks import GPTBlock
21
+ from .configuration_mosaic_gpt import \
22
+ MosaicGPTConfig
23
+ from .param_init_fns import MODEL_INIT_REGISTRY
24
+ from .low_precision_layernorm import LPLayerNorm
25
+
26
+
27
+ class MosaicGPT(PreTrainedModel):
28
+ config_class = MosaicGPTConfig
29
+ base_model_prefix = 'mosaic_gpt'
30
+
31
+ def __init__(self, config: MosaicGPTConfig):
32
+ super().__init__(config)
33
+
34
+ self.attn_impl = config.attn_impl
35
+ self.prefix_lm = config.prefix_lm
36
+ self.attn_uses_sequence_id = config.attn_uses_sequence_id
37
+ self.alibi = config.alibi
38
+ self.alibi_bias_max = config.alibi_bias_max
39
+
40
+ layernorm_class = LPLayerNorm if config.low_precision_layernorm else nn.LayerNorm
41
+
42
+ # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
43
+ # both report this helping with stabilizing training
44
+ self.embedding_fraction = config.embedding_fraction
45
+
46
+ self.transformer = nn.ModuleDict({
47
+ 'wte':
48
+ nn.Embedding(config.vocab_size,
49
+ config.d_model,
50
+ device=config.init_device)
51
+ })
52
+ if not self.alibi:
53
+ self.transformer.update({
54
+ 'wpe':
55
+ nn.Embedding(config.max_seq_len,
56
+ config.d_model,
57
+ device=config.init_device)
58
+ })
59
+ self.transformer.update({'emb_drop': nn.Dropout(config.emb_pdrop)})
60
+ self.transformer.update({
61
+ 'blocks':
62
+ nn.ModuleList([
63
+ GPTBlock(device=config.init_device,
64
+ **config.to_dict())
65
+ for _ in range(config.n_layers)
66
+ ])
67
+ })
68
+ self.transformer.update({
69
+ 'ln_f': layernorm_class(config.d_model, device=config.init_device)
70
+ })
71
+
72
+ # enables scaling output logits; similar to a softmax "temperature"
73
+ # PaLM paper uses scale 1/sqrt(config.d_model)
74
+ self.logit_scale = None
75
+ if config.logit_scale is not None:
76
+ logit_scale = config.logit_scale
77
+ if isinstance(logit_scale, str):
78
+ if logit_scale == 'inv_sqrt_d_model':
79
+ logit_scale = 1 / math.sqrt(config.d_model)
80
+ else:
81
+ raise ValueError(
82
+ f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
83
+ )
84
+ self.logit_scale = logit_scale
85
+
86
+ if config.init_device != 'meta':
87
+ print(
88
+ f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
89
+ )
90
+ self.apply(self.param_init_fn)
91
+
92
+ self.is_causal = not self.prefix_lm
93
+
94
+ # define attn mask
95
+ self._attn_bias_initialized = False
96
+ self.attn_bias = None
97
+ self.attn_bias_shape = module_attn_bias_shape(
98
+ self.attn_impl,
99
+ config.n_heads,
100
+ config.max_seq_len,
101
+ self.alibi,
102
+ prefix_lm=self.prefix_lm,
103
+ causal=self.is_causal,
104
+ use_sequence_id=self.attn_uses_sequence_id)
105
+
106
+ if config.no_bias:
107
+ for module in self.modules():
108
+ if hasattr(module, 'bias') and isinstance(
109
+ module.bias, nn.Parameter):
110
+ if config.verbose:
111
+ print(f'Removing bias ({module.bias}) from {module}.')
112
+ module.register_parameter('bias', None)
113
+
114
+ if config.verbose and config.verbose > 2:
115
+ print(self)
116
+
117
+ @torch.no_grad()
118
+ def _attn_bias(self,
119
+ device,
120
+ dtype,
121
+ attention_mask: Optional[torch.ByteTensor] = None,
122
+ prefix_mask: Optional[torch.ByteTensor] = None,
123
+ sequence_id: Optional[torch.LongTensor] = None):
124
+ if not self._attn_bias_initialized:
125
+ if self.attn_bias_shape:
126
+ self.attn_bias = torch.zeros(self.attn_bias_shape,
127
+ device=device,
128
+ dtype=dtype)
129
+ self.attn_bias = module_attn_bias(
130
+ self.attn_impl,
131
+ self.attn_bias,
132
+ self.config.n_heads,
133
+ self.config.max_seq_len,
134
+ causal=self.is_causal,
135
+ alibi=self.alibi,
136
+ alibi_bias_max=self.alibi_bias_max)
137
+ self._attn_bias_initialized = True
138
+
139
+ # flash does not support prefix_lm and will incorporate any
140
+ # attention_mask inside the attention module
141
+ if self.attn_impl == 'flash':
142
+ return self.attn_bias, attention_mask
143
+
144
+ attn_bias = self.attn_bias
145
+
146
+ # If using torch or triton, we incorporate the prefix_mask (if appropriate)
147
+ if self.prefix_lm:
148
+ assert isinstance(attn_bias, torch.Tensor) # pyright
149
+ assert isinstance(prefix_mask, torch.Tensor) # pyright
150
+ attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
151
+
152
+ # If using torch or triton, we incorporate sequence_id (if appropriate)
153
+ if self.attn_uses_sequence_id and sequence_id is not None:
154
+ assert isinstance(attn_bias, torch.Tensor) # pyright
155
+ attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
156
+
157
+ # If using torch or triton, we incorporate attention_mask. This will output
158
+ # None in place of attention_mask since it will not be further needed in the
159
+ # attention modules.
160
+ if attention_mask is not None:
161
+ s_k = attention_mask.shape[-1]
162
+ if attn_bias is None:
163
+ attn_bias = torch.zeros((1, 1, 1, s_k),
164
+ device=device,
165
+ dtype=dtype)
166
+ else:
167
+ attn_bias = attn_bias[:, :, :, -s_k:]
168
+ if prefix_mask is not None and (attention_mask.shape !=
169
+ prefix_mask.shape):
170
+ raise ValueError(
171
+ f'attention_mask shape={attention_mask.shape} ' +\
172
+ f'and prefix_mask shape={prefix_mask.shape} are not equal.'
173
+ )
174
+ min_val = torch.finfo(attn_bias.dtype).min
175
+ attn_bias = attn_bias.masked_fill(
176
+ ~attention_mask.view(-1, 1, 1, s_k), min_val)
177
+
178
+ return attn_bias, None
179
+
180
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor,
181
+ prefix_mask: torch.Tensor):
182
+ s_k, s_q = attn_bias.shape[-2:]
183
+ if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
184
+ raise ValueError(
185
+ 'attn_bias does not match the expected shape. ' +\
186
+ f'The last two dimensions should both be {self.config.max_length} ' +\
187
+ f'but are {s_k} and {s_q}.'
188
+ )
189
+ seq_len = prefix_mask.shape[-1]
190
+ if seq_len > self.config.max_seq_len:
191
+ raise ValueError(
192
+ f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
193
+ )
194
+
195
+ # select seq_len subset of attn mask
196
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
197
+
198
+ # Mix the causal max and the bidirectional mask to get the full
199
+ # allowable attention (i.e. full = not accounting for padding yet)
200
+ causal = torch.tril(
201
+ torch.ones((seq_len, seq_len),
202
+ dtype=torch.bool,
203
+ device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
204
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
205
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
206
+
207
+ min_val = torch.finfo(attn_bias.dtype).min
208
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
209
+
210
+ return attn_bias
211
+
212
+ def _apply_sequence_id(self, attn_bias: torch.Tensor,
213
+ sequence_id: torch.LongTensor):
214
+ seq_len = sequence_id.shape[-1]
215
+ if seq_len > self.config.max_seq_len:
216
+ raise ValueError(
217
+ f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
218
+ )
219
+
220
+ # select seq_len subset of attn mask
221
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
222
+
223
+ # Restrict attention to tokens that share the same value
224
+ # in sequence_id
225
+ cannot_attend = torch.logical_not(
226
+ torch.eq(sequence_id.view(-1, seq_len, 1),
227
+ sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
228
+ min_val = torch.finfo(attn_bias.dtype).min
229
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
230
+
231
+ return attn_bias
232
+
233
+ def forward(
234
+ self,
235
+ input_ids: torch.LongTensor,
236
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
237
+ attention_mask: Optional[torch.ByteTensor] = None,
238
+ prefix_mask: Optional[torch.ByteTensor] = None,
239
+ sequence_id: Optional[torch.LongTensor] = None,
240
+ return_dict: Optional[bool] = None,
241
+ output_attentions: Optional[bool] = None,
242
+ output_hidden_states: Optional[bool] = None,
243
+ use_cache: Optional[bool] = None):
244
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
245
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
246
+
247
+ # These args are passed in by keyword in huggingface's generate function
248
+ # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
249
+ # but have not yet been fully implemented in MosaicGPT
250
+ if not return_dict:
251
+ raise NotImplementedError(
252
+ 'return_dict False is not implemented yet for MosaicGPT')
253
+ if output_attentions:
254
+ raise NotImplementedError(
255
+ 'output_attentions is not implemented yet for MosaicGPT')
256
+
257
+ if attention_mask is not None and attention_mask[:, 0].sum(
258
+ ) != attention_mask.shape[0] and self.training:
259
+ raise NotImplementedError(
260
+ 'MosaicGPT does not support training with left padding.')
261
+
262
+ if self.prefix_lm and prefix_mask is None:
263
+ raise ValueError(
264
+ 'prefix_mask is a required argument when MosaicGPT is configured with prefix_lm=True.'
265
+ )
266
+
267
+ if self.training:
268
+ if self.attn_uses_sequence_id and sequence_id is None:
269
+ raise ValueError(
270
+ 'sequence_id is a required argument when MosaicGPT is configured with attn_uses_sequence_id=True ' +\
271
+ 'and the model is in train mode.'
272
+ )
273
+ elif (self.attn_uses_sequence_id is False) and (sequence_id
274
+ is not None):
275
+ warnings.warn(
276
+ 'MosaicGPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' +\
277
+ 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
278
+ )
279
+
280
+ S = input_ids.size(1)
281
+
282
+ assert (
283
+ S <= self.config.max_seq_len
284
+ ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
285
+
286
+ tok_emb = self.transformer.wte(input_ids) # type: ignore
287
+ if self.alibi:
288
+ x = tok_emb
289
+ else:
290
+ past_position = 0
291
+ if past_key_values is not None:
292
+ if len(past_key_values) != self.config.n_layers:
293
+ raise ValueError(
294
+ f'past_key_values must provide a past_key_value for each attention ' +\
295
+ f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
296
+ )
297
+ # get the key tensor whose spec should be (batch, seq, dim), and
298
+ # collect the `seq`, so that the position embedding is shifted
299
+ past_position = past_key_values[0][0].size(1)
300
+
301
+ if S + past_position > self.config.max_seq_len:
302
+ raise ValueError(
303
+ f'Cannot forward input with past sequence length {past_position} and current sequence length '
304
+ f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.'
305
+ )
306
+ pos = torch.arange(past_position,
307
+ S + past_position,
308
+ dtype=torch.long,
309
+ device=input_ids.device).unsqueeze(0)
310
+ if attention_mask is not None:
311
+ # adjust the position indices to account for padding tokens
312
+ pos = torch.clamp(pos - torch.cumsum(
313
+ (~attention_mask).to(torch.int32), dim=1)[:,
314
+ past_position:],
315
+ min=0)
316
+
317
+ pos_emb = self.transformer.wpe(pos) # type: ignore
318
+ x = tok_emb + pos_emb
319
+
320
+ if self.embedding_fraction == 1:
321
+ x = self.transformer.emb_drop(x) # type: ignore
322
+ else:
323
+ # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
324
+ x_shrunk = (x * self.embedding_fraction) + (
325
+ x.detach() * (1 - self.embedding_fraction))
326
+ assert isinstance(self.transformer.emb_drop, nn.Module) # pyright
327
+ x = self.transformer.emb_drop(x_shrunk)
328
+
329
+ attn_bias, attention_mask = self._attn_bias(
330
+ device=x.device,
331
+ dtype=x.dtype,
332
+ attention_mask=attention_mask,
333
+ prefix_mask=prefix_mask,
334
+ sequence_id=sequence_id)
335
+
336
+ # initialize the past key values cache if it should be used
337
+ if use_cache and past_key_values is None:
338
+ past_key_values = [() for _ in range(self.config.n_layers)
339
+ ] # type: ignore
340
+
341
+ all_hidden_states = () if output_hidden_states else None
342
+ for b_idx, block in enumerate(self.transformer.blocks): # type: ignore
343
+ if output_hidden_states:
344
+ assert all_hidden_states is not None # pyright
345
+ all_hidden_states = all_hidden_states + (x,)
346
+ past_key_value = past_key_values[
347
+ b_idx] if past_key_values is not None else None
348
+ x, past_key_value = block(x,
349
+ past_key_value=past_key_value,
350
+ attn_bias=attn_bias,
351
+ attention_mask=attention_mask,
352
+ is_causal=self.is_causal)
353
+ if past_key_values is not None:
354
+ past_key_values[b_idx] = past_key_value
355
+
356
+ x = self.transformer.ln_f(x) # type: ignore
357
+
358
+ # output embedding weight tied to input embedding
359
+ assert isinstance(self.transformer.wte, nn.Module) # pyright
360
+ assert isinstance(self.transformer.wte.weight, torch.Tensor) # pyright
361
+ logits = F.linear(x, self.transformer.wte.weight, None)
362
+
363
+ if self.logit_scale is not None:
364
+ if self.logit_scale == 0:
365
+ warnings.warn(
366
+ f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
367
+ )
368
+ logits *= self.logit_scale
369
+
370
+ return CausalLMOutputWithPast(logits=logits,
371
+ past_key_values=past_key_values,
372
+ hidden_states=all_hidden_states)
373
+
374
+ # Param Initialization, needed for device='meta' fast initialization
375
+ def param_init_fn(self, module):
376
+ init_fn_name = self.config.param_init_fn
377
+ if self.config.verbose > 1:
378
+ warnings.warn(f'Using {init_fn_name} initialization.')
379
+ MODEL_INIT_REGISTRY[init_fn_name](module=module,
380
+ **self.config.to_dict())
381
+
382
+ # FSDP Wrap function
383
+ def fsdp_wrap_fn(self, module):
384
+ return isinstance(module, GPTBlock)
385
+
386
+ # Activation Checkpointing
387
+ def activation_checkpointing_fn(self, module):
388
+ return isinstance(module, GPTBlock)
389
+
390
+ def prepare_inputs_for_generation(self,
391
+ input_ids,
392
+ past_key_values=None,
393
+ inputs_embeds=None,
394
+ **kwargs):
395
+ if inputs_embeds is not None:
396
+ raise NotImplementedError(
397
+ 'inputs_embeds is not implemented for MosaicGPT yet')
398
+
399
+ attention_mask = kwargs['attention_mask'].bool()
400
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
401
+ raise NotImplementedError(
402
+ 'MosaicGPT does not support generation with right padding.')
403
+
404
+ if self.attn_uses_sequence_id and self.training:
405
+ sequence_id = torch.zeros_like(input_ids[:1])
406
+ else:
407
+ sequence_id = None
408
+
409
+ if past_key_values is not None:
410
+ input_ids = input_ids[:, -1].unsqueeze(-1)
411
+
412
+ if self.prefix_lm:
413
+ # Leverage a convenience of sequential generation!
414
+ prefix_mask = torch.ones_like(attention_mask)
415
+ # This requires that we're using the cache
416
+ if kwargs.get('use_cache') == False:
417
+ raise NotImplementedError(
418
+ 'MosaicGPT with prefix_lm=True does not support use_cache=False.'
419
+ )
420
+ else:
421
+ prefix_mask = None
422
+
423
+ return {
424
+ 'input_ids': input_ids,
425
+ 'attention_mask': attention_mask,
426
+ 'prefix_mask': prefix_mask,
427
+ 'sequence_id': sequence_id,
428
+ 'past_key_values': past_key_values,
429
+ 'use_cache': kwargs.get('use_cache', True),
430
+ }
431
+
432
+ @staticmethod
433
+ def _reorder_cache(past_key_values, beam_idx):
434
+ """Used by HuggingFace generate when using beam search with kv-caching.
435
+
436
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
437
+ for an example in transformers.
438
+ """
439
+ reordered_past = []
440
+ for layer_past in past_key_values:
441
+ reordered_past += [
442
+ tuple(
443
+ past_state.index_select(0, beam_idx)
444
+ for past_state in layer_past)
445
+ ]
446
+ return reordered_past
param_init_fns.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import math
4
+ import warnings
5
+ from collections.abc import Sequence
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ def torch_default_param_init_fn_(
14
+ module: nn.Module,
15
+ verbose: int = 0,
16
+ **kwargs,
17
+ ):
18
+ del kwargs # unused, just to capture any extra args from the config
19
+ if verbose > 1:
20
+ warnings.warn(
21
+ f"Initializing network using module's reset_parameters attribute")
22
+
23
+ if hasattr(module, 'reset_parameters'):
24
+ module.reset_parameters() # type: ignore
25
+
26
+
27
+ def fused_init_helper_(module: nn.Module, init_fn_):
28
+ # parameter initialization is often based on the parameters shape.
29
+ # If a layer is fused, initialization should be based on the shapes
30
+ # of the original tensor instead of the shape of the fused tensor.
31
+ # Layers which are fused should have the _fused attibute defined.
32
+ # The first element of _fused is the dimension along which the tensor is fused.
33
+ # This is followed by an iterable of split indices."
34
+
35
+ _fused = getattr(module, '_fused', None)
36
+
37
+ if _fused is None:
38
+ raise RuntimeError(f'Internal logic error')
39
+
40
+ dim, splits = _fused
41
+ splits = (0, *splits, module.weight.size(dim)) # type: ignore
42
+ for s, e in zip(splits[:-1], splits[1:]):
43
+ slice_indices = [slice(None)] * module.weight.ndim # type: ignore
44
+ slice_indices[dim] = slice(s, e)
45
+ init_fn_(module.weight[slice_indices]) # type: ignore
46
+
47
+
48
+ def generic_param_init_fn_(
49
+ module: nn.Module,
50
+ init_fn_,
51
+ n_layers: int,
52
+ d_model: Optional[int] = None,
53
+ init_div_is_residual: Union[int, float, str, bool] = True,
54
+ emb_init_std: Optional[float] = None,
55
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
56
+ verbose: int = 0,
57
+ **kwargs,
58
+ ):
59
+ del kwargs # unused, just to capture any extra args from the config
60
+ if verbose > 1:
61
+ warnings.warn(
62
+ f'If model has bias parameters they are initialized to 0.')
63
+
64
+ # enable user to divide _is_residual weights by
65
+ # a value which defaults to math.sqrt(2 * cfg.n_layers)
66
+ init_div_is_residual = init_div_is_residual
67
+
68
+ if init_div_is_residual is False:
69
+ # not used, for pyright
70
+ div_is_residual = 1.0
71
+ elif init_div_is_residual is True:
72
+ div_is_residual = math.sqrt(2 * n_layers)
73
+ elif isinstance(init_div_is_residual, float) or isinstance(
74
+ init_div_is_residual, int):
75
+ div_is_residual = init_div_is_residual
76
+ elif isinstance(init_div_is_residual,
77
+ str) and init_div_is_residual.isnumeric():
78
+ # do not trust YAML parsing to always convert numbers to numbers
79
+ div_is_residual = float(init_div_is_residual)
80
+ else:
81
+ # not used, for pyright
82
+ div_is_residual = 1.0
83
+ raise ValueError(
84
+ f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
85
+ )
86
+
87
+ if init_div_is_residual is not False:
88
+ if verbose > 1:
89
+ warnings.warn(
90
+ f'Initializing _is_residual layers then dividing them by {div_is_residual}.' +\
91
+ f'set `init_div_is_residual: false` in model config to disable this.'
92
+ )
93
+
94
+ if isinstance(module, nn.Linear):
95
+ # Linear
96
+ if hasattr(module, '_fused'):
97
+ fused_init_helper_(module, init_fn_)
98
+ else:
99
+ init_fn_(module.weight)
100
+ if module.bias is not None:
101
+ torch.nn.init.zeros_(module.bias)
102
+
103
+ if init_div_is_residual is not False and getattr(
104
+ module, '_is_residual', False):
105
+ with torch.no_grad():
106
+ module.weight.div_(div_is_residual)
107
+
108
+ elif isinstance(module, nn.Embedding):
109
+ # Embedding
110
+ if emb_init_std is not None:
111
+ std = emb_init_std
112
+ if std == 0:
113
+ warnings.warn(f'Embedding layer initialized to 0.')
114
+ emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
115
+ if verbose > 1:
116
+ warnings.warn(
117
+ f'Embedding layer initialized using normal distribution with mean=0 and {std=}.'
118
+ )
119
+ elif emb_init_uniform_lim is not None:
120
+ lim = emb_init_uniform_lim
121
+ if isinstance(lim, Sequence):
122
+ if len(lim) > 2:
123
+ raise ValueError(
124
+ f'Uniform init requires a min and a max limit. User input: {lim}.'
125
+ )
126
+ if lim[0] == lim[1]:
127
+ warnings.warn(f'Embedding layer initialized to {lim[0]}.')
128
+ else:
129
+ if lim == 0:
130
+ warnings.warn(f'Embedding layer initialized to 0.')
131
+ lim = [-lim, lim]
132
+ a, b = lim
133
+ emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
134
+ if verbose > 1:
135
+ warnings.warn(
136
+ f'Embedding layer initialized using uniform distribution in range {lim}.'
137
+ )
138
+ else:
139
+ emb_init_fn_ = init_fn_
140
+
141
+ emb_init_fn_(module.weight)
142
+
143
+ elif isinstance(module, nn.LayerNorm):
144
+ # LayerNorm
145
+ if verbose > 1:
146
+ warnings.warn(
147
+ f'LayerNorm gamma weights are set to 1. If the layer has a bias it is initialized to 0.'
148
+ )
149
+ torch.nn.init.ones_(module.weight)
150
+ if module.bias is not None:
151
+ torch.nn.init.zeros_(module.bias)
152
+
153
+ elif isinstance(module, nn.MultiheadAttention):
154
+ # torch's MultiheadAttention
155
+ if module._qkv_same_embed_dim:
156
+ assert module.in_proj_weight is not None
157
+ assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
158
+ assert d_model is not None
159
+ # in_proj_weight is actually 3 layers and should be split up for width based init
160
+ _d = d_model
161
+ splits = (0, _d, 2 * _d, 3 * _d)
162
+ for s, e in zip(splits[:-1], splits[1:]):
163
+ init_fn_(module.in_proj_weight[s:e])
164
+ else:
165
+ assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
166
+ assert module.in_proj_weight is None
167
+ init_fn_(module.q_proj_weight)
168
+ init_fn_(module.k_proj_weight)
169
+ init_fn_(module.v_proj_weight)
170
+
171
+ # bias
172
+ if module.in_proj_bias is not None:
173
+ torch.nn.init.zeros_(module.in_proj_bias)
174
+ if module.bias_k is not None:
175
+ torch.nn.init.zeros_(module.bias_k)
176
+ if module.bias_v is not None:
177
+ torch.nn.init.zeros_(module.bias_v)
178
+
179
+ # out proj
180
+ init_fn_(module.out_proj.weight)
181
+ if init_div_is_residual is not False and getattr(
182
+ module.out_proj, '_is_residual', False):
183
+ with torch.no_grad():
184
+ module.out_proj.weight.div_(div_is_residual)
185
+ if module.out_proj.bias is not None:
186
+ torch.nn.init.zeros_(module.out_proj.bias)
187
+
188
+ else:
189
+ for _ in module.parameters(recurse=False):
190
+ # raise error if uninitialized module has any parameters
191
+ raise NotImplementedError(
192
+ f'{module.__class__.__name__} parameters are not initialized by param_init_fn.'
193
+ )
194
+
195
+
196
+ def _normal_init_(std, mean=0.0):
197
+ return partial(torch.nn.init.normal_, mean=mean, std=std)
198
+
199
+
200
+ def _normal_param_init_fn_(
201
+ module: nn.Module,
202
+ std: float,
203
+ n_layers: int,
204
+ d_model: Optional[int] = None,
205
+ init_div_is_residual: Union[int, float, str, bool] = True,
206
+ emb_init_std: Optional[float] = None,
207
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
208
+ verbose: int = 0,
209
+ **kwargs,
210
+ ):
211
+ del kwargs # unused, just to capture any extra args from the config
212
+ init_fn_ = _normal_init_(std=std)
213
+
214
+ if verbose > 1:
215
+ warnings.warn(
216
+ f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
217
+
218
+ generic_param_init_fn_(
219
+ module=module,
220
+ init_fn_=init_fn_,
221
+ d_model=d_model,
222
+ n_layers=n_layers,
223
+ init_div_is_residual=init_div_is_residual,
224
+ emb_init_std=emb_init_std,
225
+ emb_init_uniform_lim=emb_init_uniform_lim,
226
+ verbose=verbose,
227
+ )
228
+
229
+
230
+ def baseline_param_init_fn_(
231
+ module: nn.Module,
232
+ init_std: float,
233
+ n_layers: int,
234
+ d_model: Optional[int] = None,
235
+ init_div_is_residual: Union[int, float, str, bool] = True,
236
+ emb_init_std: Optional[float] = None,
237
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
238
+ verbose: int = 0,
239
+ **kwargs,
240
+ ):
241
+ del kwargs # unused, just to capture any extra args from the config
242
+ if init_std is None:
243
+ raise ValueError(
244
+ 'You must set model.init_std to a float value to use the default initialization scheme.'
245
+ )
246
+ _normal_param_init_fn_(
247
+ module=module,
248
+ std=init_std,
249
+ d_model=d_model,
250
+ n_layers=n_layers,
251
+ init_div_is_residual=init_div_is_residual,
252
+ emb_init_std=emb_init_std,
253
+ emb_init_uniform_lim=emb_init_uniform_lim,
254
+ verbose=verbose,
255
+ )
256
+
257
+
258
+ def small_param_init_fn_(
259
+ module: nn.Module,
260
+ n_layers: int,
261
+ d_model: int,
262
+ init_div_is_residual: Union[int, float, str, bool] = True,
263
+ emb_init_std: Optional[float] = None,
264
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
265
+ verbose: int = 0,
266
+ **kwargs,
267
+ ):
268
+ del kwargs # unused, just to capture any extra args from the config
269
+ # very close to kaiming normal
270
+ # from Transformers without Tears (2019) - Nguyen & Salazar
271
+ std = math.sqrt(2 / (5 * d_model))
272
+ _normal_param_init_fn_(
273
+ module=module,
274
+ std=std,
275
+ d_model=d_model,
276
+ n_layers=n_layers,
277
+ init_div_is_residual=init_div_is_residual,
278
+ emb_init_std=emb_init_std,
279
+ emb_init_uniform_lim=emb_init_uniform_lim,
280
+ verbose=verbose,
281
+ )
282
+
283
+
284
+ def neox_param_init_fn_(
285
+ module: nn.Module,
286
+ n_layers: int,
287
+ d_model: int,
288
+ emb_init_std: Optional[float] = None,
289
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
290
+ verbose: int = 0,
291
+ **kwargs,
292
+ ):
293
+ """From section 2.3.1 of GPT-NeoX-20B:
294
+
295
+ An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
296
+ see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
297
+ and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
298
+ """
299
+ del kwargs # unused, just to capture any extra args from the config
300
+ residual_div = n_layers / math.sqrt(10) # small std / wang std
301
+
302
+ if verbose > 1:
303
+ warnings.warn(f'setting init_div_is_residual to {residual_div}')
304
+
305
+ small_param_init_fn_(
306
+ module=module,
307
+ d_model=d_model,
308
+ n_layers=n_layers,
309
+ init_div_is_residual=residual_div,
310
+ emb_init_std=emb_init_std,
311
+ emb_init_uniform_lim=emb_init_uniform_lim,
312
+ verbose=verbose,
313
+ )
314
+
315
+
316
+ def kaiming_uniform_param_init_fn_(
317
+ module: nn.Module,
318
+ n_layers: int,
319
+ d_model: Optional[int] = None,
320
+ init_div_is_residual: Union[int, float, str, bool] = True,
321
+ emb_init_std: Optional[float] = None,
322
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
323
+ init_gain: float = 0,
324
+ fan_mode: str = 'fan_in',
325
+ init_nonlinearity: str = 'leaky_relu',
326
+ verbose: int = 0,
327
+ **kwargs,
328
+ ):
329
+ del kwargs # unused, just to capture any extra args from the config
330
+
331
+ if verbose > 1:
332
+ warnings.warn(
333
+ f'Using nn.init.kaiming_uniform_ init fn with parameters: ' +\
334
+ f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
335
+ )
336
+
337
+ kaiming_uniform_ = partial(nn.init.kaiming_uniform_,
338
+ a=init_gain,
339
+ mode=fan_mode,
340
+ nonlinearity=init_nonlinearity)
341
+
342
+ generic_param_init_fn_(
343
+ module=module,
344
+ init_fn_=kaiming_uniform_,
345
+ d_model=d_model,
346
+ n_layers=n_layers,
347
+ init_div_is_residual=init_div_is_residual,
348
+ emb_init_std=emb_init_std,
349
+ emb_init_uniform_lim=emb_init_uniform_lim,
350
+ verbose=verbose,
351
+ )
352
+
353
+
354
+ def kaiming_normal_param_init_fn_(
355
+ module: nn.Module,
356
+ n_layers: int,
357
+ d_model: Optional[int] = None,
358
+ init_div_is_residual: Union[int, float, str, bool] = True,
359
+ emb_init_std: Optional[float] = None,
360
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
361
+ init_gain: float = 0,
362
+ fan_mode: str = 'fan_in',
363
+ init_nonlinearity: str = 'leaky_relu',
364
+ verbose: int = 0,
365
+ **kwargs,
366
+ ):
367
+ del kwargs # unused, just to capture any extra args from the config
368
+
369
+ if verbose > 1:
370
+ warnings.warn(
371
+ f'Using nn.init.kaiming_normal_ init fn with parameters: ' +\
372
+ f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
373
+ )
374
+
375
+ kaiming_normal_ = partial(torch.nn.init.kaiming_normal_,
376
+ a=init_gain,
377
+ mode=fan_mode,
378
+ nonlinearity=init_nonlinearity)
379
+
380
+ generic_param_init_fn_(
381
+ module=module,
382
+ init_fn_=kaiming_normal_,
383
+ d_model=d_model,
384
+ n_layers=n_layers,
385
+ init_div_is_residual=init_div_is_residual,
386
+ emb_init_std=emb_init_std,
387
+ emb_init_uniform_lim=emb_init_uniform_lim,
388
+ verbose=verbose,
389
+ )
390
+
391
+
392
+ def xavier_uniform_param_init_fn_(
393
+ module: nn.Module,
394
+ n_layers: int,
395
+ d_model: Optional[int] = None,
396
+ init_div_is_residual: Union[int, float, str, bool] = True,
397
+ emb_init_std: Optional[float] = None,
398
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
399
+ init_gain: float = 0,
400
+ verbose: int = 0,
401
+ **kwargs,
402
+ ):
403
+ del kwargs # unused, just to capture any extra args from the config
404
+ xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
405
+
406
+ if verbose > 1:
407
+ warnings.warn(
408
+ f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' +\
409
+ f'gain={init_gain}'
410
+ )
411
+
412
+ generic_param_init_fn_(
413
+ module=module,
414
+ init_fn_=xavier_uniform_,
415
+ d_model=d_model,
416
+ n_layers=n_layers,
417
+ init_div_is_residual=init_div_is_residual,
418
+ emb_init_std=emb_init_std,
419
+ emb_init_uniform_lim=emb_init_uniform_lim,
420
+ verbose=verbose,
421
+ )
422
+
423
+
424
+ def xavier_normal_param_init_fn_(
425
+ module: nn.Module,
426
+ n_layers: int,
427
+ d_model: Optional[int] = None,
428
+ init_div_is_residual: Union[int, float, str, bool] = True,
429
+ emb_init_std: Optional[float] = None,
430
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
431
+ init_gain: float = 0,
432
+ verbose: int = 0,
433
+ **kwargs,
434
+ ):
435
+ xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
436
+
437
+ if verbose > 1:
438
+ warnings.warn(
439
+ f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' +\
440
+ f'gain={init_gain}'
441
+ )
442
+
443
+ generic_param_init_fn_(
444
+ module=module,
445
+ init_fn_=xavier_normal_,
446
+ d_model=d_model,
447
+ n_layers=n_layers,
448
+ init_div_is_residual=init_div_is_residual,
449
+ emb_init_std=emb_init_std,
450
+ emb_init_uniform_lim=emb_init_uniform_lim,
451
+ verbose=verbose,
452
+ )
453
+
454
+
455
+ MODEL_INIT_REGISTRY = {
456
+ 'default_': torch_default_param_init_fn_,
457
+ 'baseline_': baseline_param_init_fn_,
458
+ 'kaiming_uniform_': kaiming_uniform_param_init_fn_,
459
+ 'kaiming_normal_': kaiming_normal_param_init_fn_,
460
+ 'neox_init_': neox_param_init_fn_,
461
+ 'small_init_': small_param_init_fn_,
462
+ 'xavier_uniform_': xavier_uniform_param_init_fn_,
463
+ 'xavier_normal_': xavier_normal_param_init_fn_,
464
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f195ac04c4300f0c0cf51f97d1e77580353699d0f56285072e38f555dbd68c1
3
+ size 5245834073