Fixes
- __init__.py +2 -2
- __pycache__/__init__.cpython-311.pyc +0 -0
- __pycache__/activation.cpython-311.pyc +0 -0
- __pycache__/attention.cpython-311.pyc +0 -0
- __pycache__/bert_padding.cpython-311.pyc +0 -0
- __pycache__/configuration_bert.cpython-311.pyc +0 -0
- __pycache__/embeddings.cpython-311.pyc +0 -0
- __pycache__/initialization.cpython-311.pyc +0 -0
- __pycache__/layers.cpython-311.pyc +0 -0
- __pycache__/mlp.cpython-311.pyc +0 -0
- __pycache__/modeling_flexbert.cpython-311.pyc +0 -0
- __pycache__/normalization.cpython-311.pyc +0 -0
- __pycache__/padding.cpython-311.pyc +0 -0
- __pycache__/rotary.cpython-311.pyc +0 -0
- __pycache__/utils.cpython-311.pyc +0 -0
- attention.py +5 -2
- initialization.py +1 -1
- modeling_flexbert.py +13 -14
__init__.py
CHANGED

```diff
@@ -19,7 +19,7 @@ from .layers import (
     FlexBertUnpadPostNormLayer,
     FlexBertUnpadPreNormLayer,
 )
-from .
+from .modeling_flexbert import (
     BertLMPredictionHead,
     BertModel,
     BertForMaskedLM,
@@ -68,6 +68,6 @@ __all__ = [
     "FlexBertForMaskedLM",
     "FlexBertForSequenceClassification",
     "FlexBertForMultipleChoice",
-    "IndexFirstAxis,
+    "IndexFirstAxis",
     "IndexPutFirstAxis"
 ]
```
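The second hunk is a pure syntax fix: the entry `"IndexFirstAxis,` is missing its closing quote, so Python saw an unterminated string literal and importing the package raised a SyntaxError. A minimal sketch of the corrected re-export pattern; that the two names come from the package's `bert_padding` module is our assumption, since the commit only shows the `__all__` entries:

```python
# Abridged sketch of the fixed __init__.py (illustration only).
# Assumption: IndexFirstAxis and IndexPutFirstAxis are defined in the
# package's bert_padding module.
from .bert_padding import IndexFirstAxis, IndexPutFirstAxis

__all__ = [
    "IndexFirstAxis",   # the missing closing quote on this entry was the bug
    "IndexPutFirstAxis",
]
```

With the quote restored, every name listed in `__all__` (e.g. `FlexBertForMaskedLM`) is importable from the package again.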
__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/__pycache__/__init__.cpython-311.pyc and b/__pycache__/__init__.cpython-311.pyc differ

__pycache__/activation.cpython-311.pyc
CHANGED
Binary files a/__pycache__/activation.cpython-311.pyc and b/__pycache__/activation.cpython-311.pyc differ

__pycache__/attention.cpython-311.pyc
CHANGED
Binary files a/__pycache__/attention.cpython-311.pyc and b/__pycache__/attention.cpython-311.pyc differ

__pycache__/bert_padding.cpython-311.pyc
CHANGED
Binary files a/__pycache__/bert_padding.cpython-311.pyc and b/__pycache__/bert_padding.cpython-311.pyc differ

__pycache__/configuration_bert.cpython-311.pyc
CHANGED
Binary files a/__pycache__/configuration_bert.cpython-311.pyc and b/__pycache__/configuration_bert.cpython-311.pyc differ

__pycache__/embeddings.cpython-311.pyc
CHANGED
Binary files a/__pycache__/embeddings.cpython-311.pyc and b/__pycache__/embeddings.cpython-311.pyc differ

__pycache__/initialization.cpython-311.pyc
CHANGED
Binary files a/__pycache__/initialization.cpython-311.pyc and b/__pycache__/initialization.cpython-311.pyc differ

__pycache__/layers.cpython-311.pyc
CHANGED
Binary files a/__pycache__/layers.cpython-311.pyc and b/__pycache__/layers.cpython-311.pyc differ

__pycache__/mlp.cpython-311.pyc
CHANGED
Binary files a/__pycache__/mlp.cpython-311.pyc and b/__pycache__/mlp.cpython-311.pyc differ

__pycache__/modeling_flexbert.cpython-311.pyc
CHANGED
Binary files a/__pycache__/modeling_flexbert.cpython-311.pyc and b/__pycache__/modeling_flexbert.cpython-311.pyc differ

__pycache__/normalization.cpython-311.pyc
CHANGED
Binary files a/__pycache__/normalization.cpython-311.pyc and b/__pycache__/normalization.cpython-311.pyc differ

__pycache__/padding.cpython-311.pyc
CHANGED
Binary files a/__pycache__/padding.cpython-311.pyc and b/__pycache__/padding.cpython-311.pyc differ

__pycache__/rotary.cpython-311.pyc
CHANGED
Binary files a/__pycache__/rotary.cpython-311.pyc and b/__pycache__/rotary.cpython-311.pyc differ

__pycache__/utils.cpython-311.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
attention.py
CHANGED

```diff
@@ -20,12 +20,15 @@ from typing import Optional
 import importlib.metadata
 import logging
 import math
-
+import sys
+import os
+# Add src folder root to path to allow us to use relative imports regardless of what directory the script is run from
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import bert_padding
 from .configuration_bert import FlexBertConfig, maybe_add_padding
 from .normalization import get_norm_layer
 from .initialization import ModuleType, init_weights
-import
+import utils  # noqa: F401
 
 IMPL_USE_FLASH3 = False
 IMPL_USE_FLASH2 = False
```
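The appended path entry makes the top-level sibling modules (`bert_padding`, `utils`) importable regardless of the directory the script is launched from, and the `# noqa: F401` marker tells the linter the seemingly unused `import utils` is intentional. A minimal sketch of the pattern; the duplicate-entry guard is our addition, not part of the commit:

```python
import os
import sys

# Resolve this file's directory and add it to the module search path so
# sibling modules are found no matter what the current working directory is.
_here = os.path.dirname(os.path.realpath(__file__))
if _here not in sys.path:  # guard (ours): avoid stacking duplicate entries
    sys.path.append(_here)

import bert_padding  # noqa: F401  # resolved via the appended path
```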
initialization.py
CHANGED

```diff
@@ -14,7 +14,7 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 
-from
+from utils import StrEnum
 
 from .configuration_bert import FlexBertConfig
 from .normalization import RMSNorm
```
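`StrEnum` is imported from the repo's own `utils` module rather than from the standard library, where `enum.StrEnum` only exists from Python 3.11 onward. The commit does not show the definition; a plausible backport sketch (hypothetical, ours):

```python
from enum import Enum


class StrEnum(str, Enum):
    """String-valued enum, a common backport of Python 3.11's enum.StrEnum."""

    def __str__(self) -> str:
        return str(self.value)


# initialization.py's ModuleType (used in modeling_flexbert.py as
# type_of_module=ModuleType.emb) could plausibly be built on this base:
class ModuleType(StrEnum):
    emb = "emb"
```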
modeling_flexbert.py
CHANGED

```diff
@@ -69,8 +69,8 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
 from bert_padding import index_put_first_axis
 
-from src.bert_layers.activation import get_act_fn
-from src.bert_layers.attention import (
+from .activation import get_act_fn
+from .attention import (
     FlexBertPaddedAttention,
     FlexBertPaddedParallelAttention,
     FlexBertPaddedRopeAttention,
@@ -80,15 +80,15 @@ from src.bert_layers.attention import (
     FlexBertUnpadRopeAttention,
     FlexBertUnpadRopeParallelAttention,
 )
-from src.bert_layers.configuration_bert import FlexBertConfig
-from src.bert_layers.embeddings import (
+from .configuration_bert import FlexBertConfig
+from .embeddings import (
     BertAlibiEmbeddings,
     FlexBertAbsoluteEmbeddings,
     FlexBertCompiledSansPositionEmbeddings,
     FlexBertSansPositionEmbeddings,
     get_embedding_layer,
 )
-from src.bert_layers.initialization import (
+from .initialization import (
     ModuleType,
     TileLinear,
     TileMode,
@@ -97,7 +97,7 @@ from src.bert_layers.initialization import (
     tile_linear,
     tile_norm,
 )
-from src.bert_layers.layers import (
+from .layers import (
     BertAlibiEncoder,
     BertPooler,
     BertPredictionHeadTransform,
@@ -112,10 +112,9 @@ from src.bert_layers.layers import (
     FlexBertUnpadPreNormLayer,
     get_encoder_layer,
 )
-from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
-from src.bert_layers.normalization import get_norm_layer
-from
-from src.bert_layers.padding import pad_input, unpad_input
+from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
+from .normalization import get_norm_layer
+from .padding import pad_input, unpad_input
 
 logger = logging.getLogger(__name__)
 
@@ -867,14 +866,16 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
 
     def _init_module_weights(self, module: nn.Module):
         """
-        Custom weight init of modules using src.bert_layers.initialization.init_weights
+        Custom weight init of modules using .initialization.init_weights
         Currently only supports init of embedding modules
         """
         assert isinstance(module, nn.Module)
         if isinstance(module, nn.Embedding):
             init_weights(self.config, module, type_of_module=ModuleType.emb)
         else:
-            raise NotImplementedError("Custom weight init for the given module is not supported")
+            print("Custom weight init for the given module is not supported")
+            print(module)
+            # raise NotImplementedError("Custom weight init for the given module is not supported")
 
 
 class FlexBertModel(FlexBertPreTrainedModel):
@@ -1010,8 +1011,6 @@ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
         decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
         self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias)
         self.decoder.weight = decoder_weights
-
-        self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config)
         self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy"
         self.return_z_loss = config.loss_kwargs.get("return_z_loss", False)
         self.unpad_embeddings = config.unpad_embeddings
```
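Every import change follows one pattern: the absolute `src.bert_layers.` prefix becomes a package-relative `.`, so the modules resolve wherever the `bert_layers` package itself is importable, not only when the repository root happens to be on `sys.path`. A hypothetical fragment illustrating the difference (the layout and file names here are ours):

```python
# Hypothetical layout, for illustration:
#
#   src/
#     bert_layers/
#       __init__.py
#       mlp.py                 # defines FlexBertMLP
#       modeling_flexbert.py   # this fragment lives here
#
# Absolute form: only resolves when the repo root is on sys.path and the
# package is reachable as src.bert_layers:
#     from src.bert_layers.mlp import FlexBertMLP
#
# Relative form: resolved against this module's own package, so it keeps
# working if the package is moved, renamed, or vendored:
from .mlp import FlexBertMLP  # noqa: F401
```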
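The context lines in the last hunk also show how the masked-LM decoder builds its weight: a freshly created `(vocab_size, hidden_size)` weight tensor is assigned to the decoder `nn.Linear`, so both references share a single `Parameter`. A standalone sketch with hypothetical sizes:

```python
import torch.nn as nn

hidden_size, vocab_size = 768, 30522  # hypothetical config values

# Same pattern as the constructor above: create the weight, build a Linear
# of matching shape, then share the Parameter between the two.
decoder_weights = nn.Linear(hidden_size, vocab_size, bias=False).weight
decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=False)
decoder.weight = decoder_weights  # both names now point at one Parameter

assert decoder.weight is decoder_weights
assert decoder.weight.shape == (vocab_size, hidden_size)
```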