Adding FlexBertConfig as the config_class
Browse files- modeling_flexbert.py +1 -1
modeling_flexbert.py
CHANGED
@@ -924,7 +924,7 @@ class FlexBertModel(FlexBertPreTrainedModel):
|
|
924 |
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
925 |
```
|
926 |
"""
|
927 |
-
|
928 |
def __init__(self, config: FlexBertConfig):
|
929 |
super().__init__(config)
|
930 |
self.embeddings = get_embedding_layer(config)
|
|
|
924 |
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
925 |
```
|
926 |
"""
|
927 |
+
config_class = FlexBertConfig
|
928 |
def __init__(self, config: FlexBertConfig):
|
929 |
super().__init__(config)
|
930 |
self.embeddings = get_embedding_layer(config)
|