NohTow committed
Commit 6088290
1 Parent(s): 995cbf0
__init__.py CHANGED
@@ -19,7 +19,7 @@ from .layers import (
     FlexBertUnpadPostNormLayer,
     FlexBertUnpadPreNormLayer,
 )
-from .model import (
+from .modeling_flexbert import (
     BertLMPredictionHead,
     BertModel,
     BertForMaskedLM,
@@ -68,6 +68,6 @@ __all__ = [
     "FlexBertForMaskedLM",
     "FlexBertForSequenceClassification",
     "FlexBertForMultipleChoice",
-    "IndexFirstAxis,
+    "IndexFirstAxis",
     "IndexPutFirstAxis"
 ]
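
Besides pointing the re-export at `.modeling_flexbert` instead of the non-existent `.model` module, the second hunk restores the closing quote on `"IndexFirstAxis"`, whose unterminated string literal previously made `__init__.py` fail to parse. A minimal usage sketch, assuming the folder is importable locally as a package named `bert_layers` (the package name is an assumption, not part of the commit):

# Hypothetical: names listed in __all__ now resolve through .modeling_flexbert.
from bert_layers import BertModel, FlexBertForMaskedLM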
__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/__pycache__/__init__.cpython-311.pyc and b/__pycache__/__init__.cpython-311.pyc differ
 
__pycache__/activation.cpython-311.pyc CHANGED
Binary files a/__pycache__/activation.cpython-311.pyc and b/__pycache__/activation.cpython-311.pyc differ
 
__pycache__/attention.cpython-311.pyc CHANGED
Binary files a/__pycache__/attention.cpython-311.pyc and b/__pycache__/attention.cpython-311.pyc differ
 
__pycache__/bert_padding.cpython-311.pyc CHANGED
Binary files a/__pycache__/bert_padding.cpython-311.pyc and b/__pycache__/bert_padding.cpython-311.pyc differ
 
__pycache__/configuration_bert.cpython-311.pyc CHANGED
Binary files a/__pycache__/configuration_bert.cpython-311.pyc and b/__pycache__/configuration_bert.cpython-311.pyc differ
 
__pycache__/embeddings.cpython-311.pyc CHANGED
Binary files a/__pycache__/embeddings.cpython-311.pyc and b/__pycache__/embeddings.cpython-311.pyc differ
 
__pycache__/initialization.cpython-311.pyc CHANGED
Binary files a/__pycache__/initialization.cpython-311.pyc and b/__pycache__/initialization.cpython-311.pyc differ
 
__pycache__/layers.cpython-311.pyc CHANGED
Binary files a/__pycache__/layers.cpython-311.pyc and b/__pycache__/layers.cpython-311.pyc differ
 
__pycache__/mlp.cpython-311.pyc CHANGED
Binary files a/__pycache__/mlp.cpython-311.pyc and b/__pycache__/mlp.cpython-311.pyc differ
 
__pycache__/modeling_flexbert.cpython-311.pyc CHANGED
Binary files a/__pycache__/modeling_flexbert.cpython-311.pyc and b/__pycache__/modeling_flexbert.cpython-311.pyc differ
 
__pycache__/normalization.cpython-311.pyc CHANGED
Binary files a/__pycache__/normalization.cpython-311.pyc and b/__pycache__/normalization.cpython-311.pyc differ
 
__pycache__/padding.cpython-311.pyc CHANGED
Binary files a/__pycache__/padding.cpython-311.pyc and b/__pycache__/padding.cpython-311.pyc differ
 
__pycache__/rotary.cpython-311.pyc CHANGED
Binary files a/__pycache__/rotary.cpython-311.pyc and b/__pycache__/rotary.cpython-311.pyc differ
 
__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
 
attention.py CHANGED
@@ -20,12 +20,15 @@ from typing import Optional
 import importlib.metadata
 import logging
 import math
-
+import sys
+import os
+# Add src folder root to path to allow us to use relative imports regardless of what directory the script is run from
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import bert_padding
 from .configuration_bert import FlexBertConfig, maybe_add_padding
 from .normalization import get_norm_layer
 from .initialization import ModuleType, init_weights
-import src.utils  # noqa: F401
+import utils  # noqa: F401
 
 IMPL_USE_FLASH3 = False
 IMPL_USE_FLASH2 = False
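
The `sys.path.append` line is what lets this flat repository mix relative imports (`.configuration_bert`, `.normalization`) with top-level sibling imports (`bert_padding`, `utils`): the file's own directory is added as an import root at import time, regardless of the working directory. A sketch of the same pattern with a duplicate-entry guard; the guard is an embellishment, not part of the commit:

import os
import sys

# Put this file's directory on the import path so sibling modules (utils.py,
# bert_padding.py, ...) resolve as top-level imports from any working directory.
_module_dir = os.path.dirname(os.path.realpath(__file__))
if _module_dir not in sys.path:
    sys.path.append(_module_dir)

import utils  # noqa: F401  # the sibling utils.py, no longer src.utils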
initialization.py CHANGED
@@ -14,7 +14,7 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 
-from src.utils import StrEnum
+from utils import StrEnum
 
 from .configuration_bert import FlexBertConfig
 from .normalization import RMSNorm
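
Only the import path changes here: `StrEnum` now comes from the sibling `utils.py`, which relies on the directory being on `sys.path` (e.g. via the tweak added in attention.py). For context, a `StrEnum` helper of this kind is usually the plain `str`/`Enum` mixin sketched below; this illustrates the common pattern and is not necessarily the repo's exact definition:

from enum import Enum

class StrEnum(str, Enum):
    """Enum whose members are also strings, so they compare equal to their values."""
    def __str__(self) -> str:
        return self.value

# Illustrative only: members behave like the underlying strings.
class InitStyle(StrEnum):
    default = "default"
    tile = "tile"

assert InitStyle.tile == "tile"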
modeling_flexbert.py CHANGED
@@ -69,8 +69,8 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
 from bert_padding import index_put_first_axis
 
-from src.bert_layers.activation import get_act_fn
-from src.bert_layers.attention import (
+from .activation import get_act_fn
+from .attention import (
     FlexBertPaddedAttention,
     FlexBertPaddedParallelAttention,
     FlexBertPaddedRopeAttention,
@@ -80,15 +80,15 @@ from src.bert_layers.attention import (
     FlexBertUnpadRopeAttention,
     FlexBertUnpadRopeParallelAttention,
 )
-from src.bert_layers.configuration_bert import FlexBertConfig
-from src.bert_layers.embeddings import (
+from .configuration_bert import FlexBertConfig
+from .embeddings import (
     BertAlibiEmbeddings,
     FlexBertAbsoluteEmbeddings,
     FlexBertCompiledSansPositionEmbeddings,
     FlexBertSansPositionEmbeddings,
     get_embedding_layer,
 )
-from src.bert_layers.initialization import (
+from .initialization import (
     ModuleType,
     TileLinear,
     TileMode,
@@ -97,7 +97,7 @@ from src.bert_layers.initialization import (
     tile_linear,
     tile_norm,
 )
-from src.bert_layers.layers import (
+from .layers import (
     BertAlibiEncoder,
     BertPooler,
     BertPredictionHeadTransform,
@@ -112,10 +112,9 @@ from src.bert_layers.layers import (
     FlexBertUnpadPreNormLayer,
     get_encoder_layer,
 )
-from src.bert_layers.loss import get_loss_fn
-from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
-from src.bert_layers.normalization import get_norm_layer
-from src.bert_layers.padding import pad_input, unpad_input
+from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
+from .normalization import get_norm_layer
+from .padding import pad_input, unpad_input
 
 logger = logging.getLogger(__name__)
 
@@ -867,14 +866,16 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
 
     def _init_module_weights(self, module: nn.Module):
         """
-        Custom weight init of modules using src.bert_layers.initialization.init_weights
+        Custom weight init of modules using .initialization.init_weights
         Currently only supports init of embedding modules
         """
         assert isinstance(module, nn.Module)
         if isinstance(module, nn.Embedding):
             init_weights(self.config, module, type_of_module=ModuleType.emb)
         else:
-            raise NotImplementedError("Custom weight init for the given module is not supported")
+            print("Custom weight init for the given module is not supported")
+            print(module)
+            # raise NotImplementedError("Custom weight init for the given module is not supported")
 
 
 class FlexBertModel(FlexBertPreTrainedModel):
@@ -1010,8 +1011,6 @@ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
         decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
         self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias)
         self.decoder.weight = decoder_weights
-
-        self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config)
         self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy"
         self.return_z_loss = config.loss_kwargs.get("return_z_loss", False)
         self.unpad_embeddings = config.unpad_embeddings
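
Taken together, swapping the `src.bert_layers.*` absolute imports for package-relative ones (and dropping the `get_loss_fn` / custom `loss_fn` wiring) makes this folder self-contained, so it can be pulled in as remote code rather than as part of the original training repository. Note also that `_init_module_weights` now prints a message instead of raising `NotImplementedError`, so initialization no longer aborts when a non-embedding module reaches it. A minimal loading sketch, assuming the Hub repo's config carries the usual `auto_map` entries; the repo id is a placeholder:

from transformers import AutoModelForMaskedLM

# Placeholder repo id: trust_remote_code fetches modeling_flexbert.py and its
# relative imports straight from the Hub repo instead of a local src/ package.
model = AutoModelForMaskedLM.from_pretrained(
    "some-org/some-flexbert-checkpoint",  # hypothetical
    trust_remote_code=True,
)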