Zerx966 committed on
Commit
3ef28b3
1 Parent(s): 6947a9d

Upload 10 files

Files changed (10)
  1. .gitignore +5 -0
  2. __init__.py +1 -0
  3. hf_utils.py +15 -0
  4. mamba_block.py +354 -0
  5. mamba_config.py +86 -0
  6. mamba_model.py +183 -0
  7. mlp.py +43 -0
  8. setup.py +159 -0
  9. switch_mlp.py +91 -0
  10. utils.py +82 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ *__pycache__/
+ *.egg-info/
+ build/
+ **.so
+ **.ipynb
__init__.py ADDED
@@ -0,0 +1 @@
+
hf_utils.py ADDED
@@ -0,0 +1,15 @@
+ import json
+ import torch
+ import transformers
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
+ from transformers.utils.hub import cached_file
+
+
+ def load_config_hf(model_name):
+     resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
+     return json.load(open(resolved_archive_file))
+
+
+ def load_state_dict_hf(model_name, device="cpu"):
+     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
+     return torch.load(resolved_archive_file, map_location=device)
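
These two helpers just resolve config.json and pytorch_model.bin through the Hub cache. A minimal usage sketch (not part of the commit; the repository id below is a placeholder, not a real checkpoint):

# Hypothetical usage of the helpers above; the repo id is a placeholder.
from hf_utils import load_config_hf, load_state_dict_hf

repo_id = "some-org/some-mamba-checkpoint"
config_dict = load_config_hf(repo_id)        # parsed config.json as a plain dict
state_dict = load_state_dict_hf(repo_id)     # tensors, mapped to CPU by default
print(len(config_dict), len(state_dict))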
mamba_block.py ADDED
@@ -0,0 +1,354 @@
+
+ import math
+ from typing import Optional, Union
+ import re
+ from contextlib import nullcontext
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ import functools
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from einops import rearrange, repeat
+
+ try:
+     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+ except ImportError:
+     causal_conv1d_fn, causal_conv1d_update = None, None
+
+ try:
+     from ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
+ except ImportError:
+     selective_scan_fn, mamba_inner_fn = None, None
+
+ try:
+     from ops.triton.selective_state_update import selective_state_update
+ except ImportError:
+     selective_state_update = None
+
+ try:
+     from ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
+ except ImportError:
+     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+
+ from mamba_layer import MambaLayer
+ from mamba_config import MambaConfig
+ from mlp import MLP
+ from switch_mlp import SwitchMLP
+
+
+ class MambaBlock(nn.Module):
+     def __init__(
+         self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
+     ):
+         super().__init__()
+         self.config = config
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.mixer = mixer_cls(config)
+
+         if not config.rms_norm:
+             self.norm = norm_cls
+         else:
+             self.norm = norm_cls(config.hidden_size)
+
+         if self.fused_add_norm:
+             assert RMSNorm is not None, "RMSNorm import fails"
+             assert isinstance(
+                 self.norm, (nn.LayerNorm, RMSNorm)
+             ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+         if moe_cls is not None:
+             self.moe = moe_cls(config)
+         else:
+             self.moe = None
+
+     def forward(
+         self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
+     ):
+
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+             if self.residual_in_fp32:
+                 residual = residual.to(torch.float32)
+         else:
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
+             hidden_states, residual = fused_add_norm_fn(
+                 hidden_states,
+                 self.norm.weight,
+                 self.norm.bias,
+                 residual=residual,
+                 prenorm=True,
+                 residual_in_fp32=self.residual_in_fp32,
+                 eps=self.norm.eps,
+             )
+         hidden_states = self.mixer(hidden_states, inference_params=inference_params)
+         return hidden_states, residual
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+ class MambaBlockParallelMoe(nn.Module):
+     def __init__(
+         self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, norm_moe=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
+     ):
+
+         super().__init__()
+         self.config = config
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.mixer = mixer_cls(config)
+         if not config.rms_norm:
+             self.norm = norm_cls
+             self.norm_moe = norm_moe
+         else:
+             self.norm = norm_cls(config.hidden_size)
+             self.norm_moe = norm_moe(config.hidden_size)
+         if self.fused_add_norm:
+             assert RMSNorm is not None, "RMSNorm import fails"
+             assert isinstance(
+                 self.norm, (nn.LayerNorm, RMSNorm)
+             ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+             assert isinstance(
+                 self.norm_moe, (nn.LayerNorm, RMSNorm)
+             ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+         if moe_cls is not None:
+             self.moe = moe_cls(config)
+         else:
+             self.moe = None
+
+     def forward(
+         self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
+     ):
+
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+             hidden_states_moe = self.norm_moe(residual.to(dtype=self.norm.weight.dtype))
+             if self.residual_in_fp32:
+                 residual = residual.to(torch.float32)
+         else:
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
+             hidden_states, residual = fused_add_norm_fn(
+                 hidden_states,
+                 self.norm.weight,
+                 self.norm.bias,
+                 residual=residual,
+                 prenorm=True,
+                 residual_in_fp32=self.residual_in_fp32,
+                 eps=self.norm.eps,
+             )
+             hidden_states_moe, _ = fused_add_norm_fn(
+                 hidden_states,
+                 self.norm_moe.weight,
+                 self.norm_moe.bias,
+                 residual=residual,
+                 prenorm=True,
+                 residual_in_fp32=self.residual_in_fp32,
+                 eps=self.norm_moe.eps,
+             )
+
+         hidden_states = self.mixer(hidden_states, inference_params=inference_params)
+
+         hidden_states_moe = self.moe(hidden_states_moe)
+         hidden_states += hidden_states_moe
+         return hidden_states, residual
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+
+ class MoEBlock(nn.Module):
+     def __init__(
+         self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
+     ):
+
+         super().__init__()
+         self.config = config
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.mixer = mixer_cls(config)
+         if not config.rms_norm:
+             self.norm = norm_cls
+         else:
+             self.norm = norm_cls(config.hidden_size)
+         if self.fused_add_norm:
+             assert RMSNorm is not None, "RMSNorm import fails"
+             assert isinstance(
+                 self.norm, (nn.LayerNorm, RMSNorm)
+             ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+         if moe_cls is not None:
+             self.moe = moe_cls(config)
+         else:
+             self.moe = None
+
+     def forward(
+         self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
+     ):
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+             if self.residual_in_fp32:
+                 residual = residual.to(torch.float32)
+         else:
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
+             hidden_states, residual = fused_add_norm_fn(
+                 hidden_states,
+                 self.norm.weight,
+                 self.norm.bias,
+                 residual=residual,
+                 prenorm=True,
+                 residual_in_fp32=self.residual_in_fp32,
+                 eps=self.norm.eps,
+             )
+         hidden_states = self.mixer(hidden_states)
+         return hidden_states, residual
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+
+ def create_block(config, layer_idx):
+
+     if config.rms_norm:
+         norm_cls = partial(RMSNorm, eps=config.layernorm_epsilon)
+     else:
+         norm_cls = partial(nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon)
+
+     if (not config.mamba_moe_layers) or config.mamba_moe_layers[layer_idx-1][0] == 'r':
+         if (not config.mamba_moe_layers) or len(config.mamba_moe_layers[layer_idx-1]) == 1:
+             mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
+             block = MambaBlock(
+                 config,
+                 mixer_cls=mixer_cls,
+                 norm_cls=norm_cls,
+                 fused_add_norm=config.fused_add_norm,
+                 residual_in_fp32=config.residual_in_fp32,
+             )
+         else:
+             if config.mamba_moe_layers[layer_idx-1][1] == '1':
+                 if config.rms_norm:
+                     norm_moe = partial(RMSNorm, eps=config.layernorm_epsilon)
+                 else:
+                     norm_moe = partial(
+                         nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon
+                     )
+                 mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
+                 moe_cls = partial(MLP, layer_idx=layer_idx)
+                 block = MambaBlockParallelMoe(
+                     config,
+                     mixer_cls=mixer_cls,
+                     moe_cls=moe_cls,
+                     norm_cls=norm_cls,
+                     norm_moe=norm_moe,
+                     fused_add_norm=config.fused_add_norm,
+                     residual_in_fp32=config.residual_in_fp32,
+                 )
+             else:
+                 if config.rms_norm:
+                     norm_moe = partial(RMSNorm, eps=config.layernorm_epsilon)
+                 else:
+                     norm_moe = partial(
+                         nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon
+                     )
+                 mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
+                 moe_cls = partial(SwitchMLP, layer_idx=layer_idx)
+                 block = MambaBlockParallelMoe(
+                     config,
+                     mixer_cls=mixer_cls,
+                     moe_cls=moe_cls,
+                     norm_cls=norm_cls,
+                     norm_moe=norm_moe,
+                     fused_add_norm=config.fused_add_norm,
+                     residual_in_fp32=config.residual_in_fp32,
+                 )
+     else:
+         if config.mamba_moe_layers[layer_idx-1][0] == '1':
+             mixer_cls = partial(MLP, layer_idx=layer_idx)
+             block = MoEBlock(
+                 config,
+                 mixer_cls=mixer_cls,
+                 norm_cls=norm_cls,
+                 fused_add_norm=config.fused_add_norm,
+                 residual_in_fp32=config.residual_in_fp32,
+             )
+         else:
+             mixer_cls = partial(SwitchMLP, layer_idx=layer_idx)
+             block = MoEBlock(
+                 config,
+                 mixer_cls=mixer_cls,
+                 norm_cls=norm_cls,
+                 fused_add_norm=config.fused_add_norm,
+                 residual_in_fp32=config.residual_in_fp32,
+             )
+     block.layer_idx = layer_idx
+     return block
+
+ class MambaDecoder(nn.Module):
+     """Class wrapping a decoder stack of mamba blocks."""
+
+     def __init__(
+         self,
+         config: MambaConfig,
+         post_layer_norm=True,
+         pre_process=True,
+         post_process=True,
+     ):
+         super().__init__()
+
+         self.config: MambaConfig = config
+         self.post_layer_norm = post_layer_norm
+         self.pre_process = pre_process
+         self.post_process = post_process
+         self.norm_cls = partial(nn.LayerNorm, eps=self.config.layernorm_epsilon)
+
+         self._build_layers()
+
+     def _build_layers(self):
+
+         num_layers_to_build = self.config.num_layers
+         # build the actual mamba layers
+         self.layers = torch.nn.ModuleList([create_block(self.config, i + 1) for i in range(num_layers_to_build)])
+
+         if self.post_process and self.post_layer_norm:
+             # Final layer norm before output.
+             self.final_layernorm = self.norm_cls(self.config.hidden_size, bias = True)
+
+     def _get_layer(self, layer_number):
+         return self.layers[layer_number]
+
+     def forward(self, hidden_states, residual = None, inference_params=None):
+
+         if not self.pre_process:
+             # See set_input_tensor()
+             hidden_states = self.input_tensor
+
+         residual = None
+         for i, layer in enumerate(self.layers):
+             hidden_states, residual = layer(
+                 hidden_states=hidden_states,
+                 residual = residual,
+                 inference_params=inference_params,
+             )
+
+         # Final layer norm.
+         if self.post_process and self.post_layer_norm:
+             if not self.config.fused_add_norm:
+                 residual = (hidden_states + residual) if residual is not None else hidden_states
+                 hidden_states = self.final_layernorm(residual.to(dtype=self.final_layernorm.weight.dtype))
+             else:
+                 # Set prenorm=False here since we don't need the residual
+                 fused_add_norm_fn = rms_norm_fn if isinstance(self.final_layernorm, RMSNorm) else layer_norm_fn
+                 hidden_states = fused_add_norm_fn(
+                     hidden_states,
+                     self.final_layernorm.weight,
+                     self.final_layernorm.bias,
+                     eps=self.final_layernorm.eps,
+                     residual=residual,
+                     prenorm=False,
+                     residual_in_fp32=self.config.residual_in_fp32,
+                 )
+         return hidden_states
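
All three block classes follow the same pre-norm residual convention: accumulate the residual, normalize it, run the mixer, and thread both tensors to the next block. A stripped-down sketch of that non-fused flow (mine, not from the repo), with nn.Identity standing in for the MambaLayer mixer:

# Toy sketch of the (non-fused) add-norm flow used by MambaBlock.forward;
# nn.Identity() is a stand-in for the real Mamba mixer.
import torch
import torch.nn as nn

norm = nn.LayerNorm(16)
mixer = nn.Identity()

def block_step(hidden_states, residual=None):
    residual = hidden_states + residual if residual is not None else hidden_states
    hidden_states = norm(residual)          # pre-norm on the accumulated residual
    hidden_states = mixer(hidden_states)    # Mamba mixing would happen here
    return hidden_states, residual          # both are handed to the next block

x = torch.randn(2, 5, 16)
h, r = block_step(x)
h, r = block_step(h, r)                     # the same chaining MambaDecoder.forward performs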
mamba_config.py ADDED
@@ -0,0 +1,86 @@
+ from dataclasses import dataclass
+ from typing import Callable
+ import torch
+ import torch.nn.functional as F
+ from utils import init_method_normal, scaled_init_method_normal
+
+
+ @dataclass
+ class MambaConfig():
+     base_model_type: str = "mamba"
+     num_layers: int = 0
+     hidden_size: int = 0
+     state_size: int = 0
+     vocab_size: int = 50000
+     expansion_factor: int = 2
+     conv_dimension: int = 0
+     conv_bias: bool = True
+     bias: bool = True
+     use_fast_path: bool = True
+     dt_rank: str = "auto"
+     dt_min: float = 0.001
+     dt_max: float = 0.1
+     dt_init: str = "random"
+     dt_scale: float = 1.0
+     dt_init_floor: float = 1e-4
+     rms_norm: bool = True
+     fused_add_norm: bool = False
+     residual_in_fp32: bool = True
+     hidden_dropout: float = 0.0
+     ffn_hidden_size: int = None
+     gated_linear_unit: bool = False
+     mamba_moe_layers: str = ""
+     routing_mode: str = "sinkhorn"
+     device: str = "cuda"
+     fp32_residual_connection: bool = False
+     layernorm_epsilon: float = 1e-5
+     layernorm_zero_centered_gamma: bool = False
+     add_bias_linear: bool = True
+     activation_func: Callable = F.gelu
+     num_moe_experts: int = None
+
+     # initialization
+     init_method: Callable = None
+     output_layer_init_method: Callable = None
+     init_method_std: float = 0.02
+
+     # mixed-precision
+     apply_query_key_layer_scaling: bool = True
+     attention_softmax_in_fp32: bool = True
+
+     # fusion
+     gated_linear_unit: bool = False
+     bias_gelu_fusion: bool = False
+     persist_layer_norm: bool = False
+     bias_dropout_fusion: bool = False
+
+
+     def __post_init__(self):
+         """ Python dataclass method that is used to modify attributes after initialization.
+         See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
+         """
+         if self.apply_query_key_layer_scaling:
+             self.attention_softmax_in_fp32 = True
+
+         if self.ffn_hidden_size is None:
+             self.ffn_hidden_size = 4 * self.hidden_size
+
+         if self.apply_query_key_layer_scaling:
+             self.attention_softmax_in_fp32 = True
+
+         if self.bias_gelu_fusion:
+             if not self.add_bias_linear:
+                 raise ValueError(
+                     "When bias_gelu_fusion is True, add_bias_linear must also be True."
+                 )
+
+             if self.activation_func != F.gelu:
+                 raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.')
+
+         if self.init_method is None:
+             self.init_method = init_method_normal(self.init_method_std)
+
+         if self.output_layer_init_method is None:
+             self.output_layer_init_method = scaled_init_method_normal(
+                 self.init_method_std, self.num_layers
+             )
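
For reference, a small example (illustrative field values only, not a released model configuration) showing what __post_init__ fills in:

# Illustrative config; shows the derived defaults set by __post_init__.
from mamba_config import MambaConfig

config = MambaConfig(
    num_layers=4,        # must be > 0: the scaled output init divides by sqrt(2 * num_layers)
    hidden_size=256,
    state_size=16,
    conv_dimension=4,
    vocab_size=50000,
    device="cpu",
)
print(config.ffn_hidden_size)              # 1024, i.e. 4 * hidden_size
print(config.init_method)                  # normal(0, init_method_std) initializer closure
print(config.attention_softmax_in_fp32)    # True (forced when apply_query_key_layer_scaling is set)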
mamba_model.py ADDED
@@ -0,0 +1,183 @@
+ import logging
+ from typing import Literal, Optional, Union
+ import functools
+ from functools import partial
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ import math
+ import os
+ from mamba_block import MambaBlock, MambaDecoder
+ from mamba_config import MambaConfig
+ from hf_utils import *
+ import os, json
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
+ from transformers.utils.hub import cached_file
+
+
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
+ def _init_weights(
+     module,
+     n_layer,
+     initializer_range=0.02,  # Now only used for embedding layer.
+     rescale_prenorm_residual=True,
+     n_residuals_per_layer=1,  # Change to 2 if we have MLP
+ ):
+     if isinstance(module, nn.Linear):
+         if module.bias is not None:
+             if not getattr(module.bias, "_no_reinit", False):
+                 nn.init.zeros_(module.bias)
+     elif isinstance(module, nn.Embedding):
+         nn.init.normal_(module.weight, std=initializer_range)
+
+     if rescale_prenorm_residual:
+         # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+         #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+         #   > the weights of residual layers at initialization by a factor of 1/sqrt(N) where N is the # of residual layers.
+         #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+         #
+         # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+         for name, p in module.named_parameters():
+             if name in ["out_proj.weight", "fc2.weight"]:
+                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                 # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                 # We need to reinit p since this code could be called multiple times
+                 # Having just p *= scale would repeatedly scale it down
+                 nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                 with torch.no_grad():
+                     p /= math.sqrt(n_residuals_per_layer * n_layer)
+
+
+ class MambaModel(nn.Module):
+     def __init__(
+         self,
+         config: MambaConfig,
+         max_sequence_length: int,
+         pre_process: bool = True,
+         post_process: bool = True,
+         fp16_lm_cross_entropy: bool = False,
+         parallel_output: bool = True,
+         share_embeddings_and_output_weights: bool = True,
+         initializer_cfg = None,
+     ) -> None:
+         super().__init__()
+
+         self.config: MambaConfig = config
+         self.max_sequence_length = max_sequence_length
+         self.pre_process = pre_process
+         self.post_process = post_process
+         self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+         self.parallel_output = parallel_output
+         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+
+         if self.pre_process:
+             self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+
+
+         self.decoder = MambaDecoder(
+             config = self.config,
+             pre_process = self.pre_process,
+             post_process = self.post_process,
+         )
+
+         if post_process:
+             self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias = self.config.add_bias_linear)
+             if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
+                 self.initialize_last_stage_with_word_embeddings()
+
+         # apply weight initialization
+         self.apply(
+             partial(
+                 _init_weights,
+                 n_layer=self.config.num_layers,
+                 **(initializer_cfg if initializer_cfg is not None else {}),
+             )
+         )
+
+     def initialize_last_stage_with_word_embeddings(self):
+         with torch.no_grad():
+             self.output_layer.weight = self.embedding.weight
+
+     def forward(
+         self,
+         input_ids,
+         position_ids = None,
+         decoder_input: Tensor = None,
+         labels: Tensor = None,
+         inference_params=None,
+     ) -> Tensor:
+         if decoder_input is not None:
+             pass
+         elif self.pre_process:
+             decoder_input = self.embedding(input_ids)
+         else:
+             decoder_input = None
+
+         hidden_states = self.decoder(
+             hidden_states=decoder_input,
+             residual=None,
+             inference_params=inference_params,
+         )
+
+         if not self.post_process:
+             return hidden_states
+
+         logits = self.output_layer(hidden_states)
+
+         return logits.contiguous()
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name = None, checkpoint_name=None, config_name=None, **kwargs):
+         if pretrained_model_name is not None:
+             json_config = load_config_hf(pretrained_model_name)
+             loaded = load_state_dict_hf(pretrained_model_name)
+         elif checkpoint_name is not None and config_name is not None:
+             with open(config_name, 'r') as f:
+                 jsonstr = f.read()
+             json_config = json.loads(jsonstr)
+             loaded = torch.load(checkpoint_name, map_location='cpu')
+         else:
+             return
+         model_state_dict = loaded["model"]
+
+         config = MambaConfig(
+             num_layers=json_config['num_layers'],
+             hidden_size=json_config['hidden_size'],
+             state_size=json_config['state_size'],
+             conv_dimension=json_config['conv_dimension'],
+             vocab_size=json_config['vocab_size'],
+             expansion_factor=json_config['expansion_factor'],
+             mamba_moe_layers=json_config['mamba_moe_layers'],
+             ffn_hidden_size=json_config['ffn_hidden_size'],
+             bias = json_config['add_bias_linear'],
+             add_bias_linear = json_config['add_bias_linear'],
+             gated_linear_unit = json_config['swiglu']
+         )
+
+         model = MambaModel(config=config, max_sequence_length=json_config['max_sequence_length'], **kwargs)
+
+         # make keys match
+         model_state_dict["embedding.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
+         model_state_dict["output_layer.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
+         model_state_dict["embedding.word_embeddings.weight"] = None
+         model_state_dict.pop("embedding.word_embeddings.weight")
+         model.load_state_dict(model_state_dict)
+         return model
+
+     def save_pretrained(self, save_directory):
+         """
+         Minimal implementation of save_pretrained for MambaModel.
+         Save the model and its configuration file to a directory.
+         """
+         # Ensure save_directory exists
+         if not os.path.exists(save_directory):
+             os.makedirs(save_directory)
+
+         # Save the model's state_dict
+         model_path = os.path.join(save_directory, 'pytorch_model.bin')
+         torch.save(self.state_dict(), model_path)
+
+         # Save the configuration of the model
+         config_path = os.path.join(save_directory, 'config.json')
+         with open(config_path, 'w') as f:
+             json.dump(self.config.__dict__, f, default=str)  # default=str keeps callable fields serializable
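
A usage sketch only (not part of the commit): the file names below are hypothetical, the checkpoint must be a dict with a "model" state dict plus a config carrying the keys from_pretrained reads (num_layers, hidden_size, state_size, conv_dimension, vocab_size, expansion_factor, mamba_moe_layers, ffn_hidden_size, add_bias_linear, swiglu, max_sequence_length), and the forward pass assumes the CUDA Mamba kernels are installed.

# Hypothetical local checkpoint + config in the expected layout.
import torch
from mamba_model import MambaModel

model = MambaModel.from_pretrained(checkpoint_name="checkpoint.pt", config_name="config.json")
model = model.cuda().eval()                      # the Mamba mixer kernels require CUDA
tokens = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
with torch.no_grad():
    logits = model(tokens)                       # shape (1, 16, vocab_size)
model.save_pretrained("./exported")              # writes pytorch_model.bin and config.json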
mlp.py ADDED
@@ -0,0 +1,43 @@
+ from dataclasses import dataclass
+ from typing import Union
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from utils import bias_gelu_impl
+ from mamba_config import MambaConfig
+
+ class MLP(nn.Module):
+     def __init__(
+         self, config: MambaConfig, is_expert: bool = False, layer_idx=None
+     ):
+         super().__init__()
+
+         self.config: MambaConfig = config
+         self.layer = layer_idx
+         ffn_hidden_size_1 = self.config.ffn_hidden_size
+         ffn_hidden_size_2 = self.config.ffn_hidden_size
+
+         # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+         if self.config.gated_linear_unit:
+             ffn_hidden_size_1 *= 2
+
+         self.linear_fc1 = nn.Linear(self.config.hidden_size, ffn_hidden_size_1, bias = self.config.add_bias_linear, device = self.config.device)
+         self.linear_fc1.is_expert = is_expert
+
+         if self.config.gated_linear_unit:
+
+             def glu(x):
+                 x = torch.chunk(x, 2, dim=-1)
+                 return self.config.activation_func(x[0]) * x[1]
+
+             self.activation_func = glu
+         else:
+             self.activation_func = self.config.activation_func
+
+         self.linear_fc2 = nn.Linear(ffn_hidden_size_2, self.config.hidden_size, bias = self.config.add_bias_linear, device = self.config.device)
+
+     def forward(self, hidden_states, inference_params=None):
+         intermediate = self.linear_fc1(hidden_states)
+         intermediate = self.activation_func(intermediate)
+         output = self.linear_fc2(intermediate)
+         return output
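
A quick CPU check of the gated-linear-unit width doubling described in the comment above (illustrative config values, not from a released model):

# fc1 projects to 2 * ffn_hidden_size when gated_linear_unit is on; half is the gate.
import torch
from mamba_config import MambaConfig
from mlp import MLP

config = MambaConfig(num_layers=2, hidden_size=64, state_size=16,
                     conv_dimension=4, gated_linear_unit=True, device="cpu")
mlp = MLP(config)
x = torch.randn(2, 8, 64)
print(mlp.linear_fc1.weight.shape)   # torch.Size([512, 64]) -- 2 * ffn_hidden_size (256)
print(mlp(x).shape)                  # torch.Size([2, 8, 64])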
setup.py ADDED
@@ -0,0 +1,159 @@
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
+ import warnings
+ import os
+ from pathlib import Path
+
+ from packaging.version import parse, Version
+ from setuptools import setup, find_packages
+ import subprocess
+
+
+ import torch
+ from torch.utils.cpp_extension import (
+     BuildExtension,
+     CppExtension,
+     CUDAExtension,
+     CUDA_HOME,
+ )
+
+ PACKAGE_NAME = "blackmamba"
+ VERSION = "0.0.1"
+
+ with open("README.md", "r", encoding="utf-8") as fh:
+     long_description = fh.read()
+
+
+ # ninja build does not work unless include_dirs are abs path
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+
+ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
+ # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
+ FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
+ SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+ # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
+ FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+
+
+ def get_cuda_bare_metal_version(cuda_dir):
+     raw_output = subprocess.check_output(
+         [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+     )
+     output = raw_output.split()
+     release_idx = output.index("release") + 1
+     bare_metal_version = parse(output[release_idx].split(",")[0])
+
+     return raw_output, bare_metal_version
+
+
+ def check_if_cuda_home_none(global_option: str) -> None:
+     if CUDA_HOME is not None:
+         return
+     # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+     # in that case.
+     warnings.warn(
+         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+         "only images whose names contain 'devel' will provide nvcc."
+     )
+
+
+ def append_nvcc_threads(nvcc_extra_args):
+     return nvcc_extra_args + ["--threads", "4"]
+
+
+ ext_modules = []
+ if not SKIP_CUDA_BUILD:
+     print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+     TORCH_MAJOR = int(torch.__version__.split(".")[0])
+     TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+     check_if_cuda_home_none(PACKAGE_NAME)
+     # Check, if CUDA11 is installed for compute capability 8.0
+     cc_flag = []
+     if CUDA_HOME is not None:
+         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+         if bare_metal_version < Version("11.6"):
+             raise RuntimeError(
+                 f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
+                 "Note: make sure nvcc has a supported version by running nvcc -V."
+             )
+
+     cc_flag.append("-gencode")
+     cc_flag.append("arch=compute_70,code=sm_70")
+     cc_flag.append("-gencode")
+     cc_flag.append("arch=compute_80,code=sm_80")
+     if bare_metal_version >= Version("11.8"):
+         cc_flag.append("-gencode")
+         cc_flag.append("arch=compute_90,code=sm_90")
+
+     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
+     # torch._C._GLIBCXX_USE_CXX11_ABI
+     # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
+     if FORCE_CXX11_ABI:
+         torch._C._GLIBCXX_USE_CXX11_ABI = True
+
+     ext_modules.append(
+         CUDAExtension(
+             name="selective_scan_cuda",
+             sources=[
+                 "csrc/selective_scan/selective_scan.cpp",
+                 "csrc/selective_scan/selective_scan_fwd_fp32.cu",
+                 "csrc/selective_scan/selective_scan_fwd_fp16.cu",
+                 "csrc/selective_scan/selective_scan_fwd_bf16.cu",
+                 "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
+                 "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
+                 "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
+                 "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
+                 "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
+                 "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
+             ],
+             extra_compile_args={
+                 "cxx": ["-O3", "-std=c++17"],
+                 "nvcc": append_nvcc_threads(
+                     [
+                         "-O3",
+                         "-std=c++17",
+                         "-U__CUDA_NO_HALF_OPERATORS__",
+                         "-U__CUDA_NO_HALF_CONVERSIONS__",
+                         "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                         "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                         "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                         "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                         "--expt-relaxed-constexpr",
+                         "--expt-extended-lambda",
+                         "--use_fast_math",
+                         "--ptxas-options=-v",
+                         "-lineinfo",
+                     ]
+                     + cc_flag
+                 ),
+             },
+             include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
+         )
+     )
+
+
+ setup(
+     name=PACKAGE_NAME,
+     version=VERSION,
+     description="Blackmamba state-space + MoE model",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     packages=find_packages(include=['ops'],),
+     exclude=(
+         "csrc",
+         "blackmamba.egg-info",
+     ),
+     ext_modules=ext_modules,
+     cmdclass={"build_ext": BuildExtension},
+     python_requires=">=3.7",
+     install_requires=[
+         "torch",
+         "packaging",
+         "ninja",
+         "einops",
+         "triton",
+         "transformers",
+         "causal_conv1d>=1.1.0",
+     ],
+ )
switch_mlp.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import torch.nn as nn
+ import pickle
+ import os
+ import torch.nn.functional as F
+
+ from mamba_config import MambaConfig
+ from mlp import MLP
+
+ def sinkhorn(cost, tol=0.0001):
+     "Sinkhorn based MoE routing function"
+     cost = torch.exp(2.0 * cost)
+     d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
+     # d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
+     d1 = 1 / (cost.size(1) * torch.sum(cost, 0))
+
+     eps = 0.00000001
+     error = 1e9
+     d1_old = d1
+     while error > tol:
+         d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
+         d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
+         error = torch.mean(torch.abs(d1_old - d1))
+         d1_old = d1
+     return d1 * cost * d0.unsqueeze(1)
+
+
+ class SwitchMLP(nn.Module):
+     """
+     Top-1 Mixture of Experts Layer. Routes input to one of N MLP "experts".
+     Currently supports Sinkhorn based expert routing.
+     """
+
+     def __init__(self, config: MambaConfig, layer_idx=None):
+         super().__init__()
+
+         self.layer = layer_idx
+         self.config: MambaConfig = config
+         if config.mamba_moe_layers:
+             self.num_moe_experts = int(config.mamba_moe_layers[layer_idx-1][-1])
+         else:
+             self.num_moe_experts = self.config.num_moe_experts
+         self.router = torch.nn.Linear(self.config.hidden_size, self.num_moe_experts)
+         self.add_bias = config.add_bias_linear
+         self.routing = config.routing_mode  # 'sinkhorn', 'top1', 'top2', 'sinkhorn_top2'
+         self.route_algo = sinkhorn
+         self.router_activation = torch.sigmoid
+
+         self.num_local_experts = self.num_moe_experts
+         self.local_expert_indices = [i for i in range(self.num_local_experts)]
+
+         self.local_experts = torch.nn.ModuleList()
+         for _ in range(self.num_local_experts):
+             expert = MLP(self.config, is_expert=True, layer_idx=layer_idx)
+             self.local_experts.append(expert)
+
+     def gather_indices(self, local_indices):
+         return local_indices
+
+     def forward(self, hidden_states, inference_params=None):
+
+         hidden_shape = hidden_states.shape
+         route = self.router(hidden_states)
+         route = route.view(-1, self.num_moe_experts)
+
+         if self.routing == 'sinkhorn':
+             route = self.router_activation(route)
+             max_prob, max_ind = torch.max(route, dim=1)
+         else:
+             route = torch.softmax(route, dim=1)
+             max_prob, max_ind = torch.max(route, dim=1)
+
+         max_prob = torch.unsqueeze(max_prob, 1)
+         hidden_states = hidden_states.view(-1, hidden_shape[-1])
+
+         global_hidden_states = hidden_states
+         global_indices = max_ind
+         output_total = torch.zeros_like(global_hidden_states)
+
+
+         for expert_num, expert in enumerate(self.local_experts):
+             local_expert_index = self.local_expert_indices[expert_num]
+             local_indices = (global_indices == local_expert_index).nonzero()
+             hidden = global_hidden_states[local_indices, :]
+             output = expert(hidden)
+             output_total[local_indices, :] = output
+
+         output_total = output_total * max_prob
+         output_total = output_total.view(hidden_shape)
+
+         return output_total
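
For orientation (not part of the commit), a small sanity check of the sinkhorn helper defined at the top of this file: it turns raw router logits into a balanced token-to-expert assignment matrix whose per-expert (column) mass is roughly uniform.

# Random router logits: 16 tokens, 4 experts (shapes are arbitrary).
import torch
from switch_mlp import sinkhorn

logits = torch.randn(16, 4)
assignment = sinkhorn(logits)
print(assignment.shape)          # torch.Size([16, 4])
print(assignment.sum(dim=0))     # each column sums to roughly 1 / num_experts
expert_choice = assignment.argmax(dim=1)   # a load-balanced top-1 choice per token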
utils.py ADDED
@@ -0,0 +1,82 @@
+ from operator import itemgetter
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union
+ import math
+ import torch
+
+
+ def attention_mask_func(attention_scores, attention_mask):
+     attention_scores.masked_fill_(attention_mask, -10000.0)
+     return attention_scores
+
+
+ @torch.jit.script
+ def gelu_impl(x):
+     """OpenAI's gelu implementation."""
+     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
+
+
+ def openai_gelu(x):
+     return gelu_impl(x)
+
+
+ @torch.jit.script
+ def bias_gelu(bias, y):
+     x = bias + y
+     return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+
+ # gradient of tanh approximation of gelu
+ # gradient of actual gelu is:
+ # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+ @torch.jit.script
+ def bias_gelu_back(g, bias, y):
+     x = bias + y
+     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+     # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+     ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+         1 + tanh_out
+     )
+     return ff * g
+
+
+ class GeLUFunction(torch.autograd.Function):
+     @staticmethod
+     # bias is an optional argument
+     def forward(ctx, input, bias):
+         ctx.save_for_backward(input, bias)
+         return bias_gelu(bias, input)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         input, bias = ctx.saved_tensors
+         tmp = bias_gelu_back(grad_output, bias, input)
+         return tmp, tmp
+
+
+ bias_gelu_impl = GeLUFunction.apply
+
+
+
+ # This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
+ @torch.jit.script
+ def erf_gelu(x):
+     return (
+         x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
+     )
+
+
+ def init_method_normal(sigma):
+
+     def init_(tensor):
+         return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
+
+     return init_
+
+
+ def scaled_init_method_normal(sigma, num_layers):
+     std = sigma / math.sqrt(2.0 * num_layers)
+
+     def init_(tensor):
+         return torch.nn.init.normal_(tensor, mean=0.0, std=std)
+
+     return init_
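
A quick numerical check (mine, not part of the commit): the tanh-based GeLU variants above use the same approximation as torch.nn.functional.gelu with approximate="tanh", so the outputs agree to well within float tolerance.

import torch
import torch.nn.functional as F
from utils import openai_gelu, bias_gelu_impl

x = torch.randn(4, 8)
bias = torch.randn(8)
print(torch.allclose(openai_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6))                    # True
print(torch.allclose(bias_gelu_impl(x, bias), F.gelu(x + bias, approximate="tanh"), atol=1e-6))    # True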