refactor changes
modeling_cerule_gemma.py (+7 -7)
@@ -872,7 +872,7 @@ if is_torch_fx_available():
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "
+_CONFIG_FOR_DOC = "CeruleGemmaConfig"
 
 
 def _get_unpad_data(attention_mask):
@@ -1003,7 +1003,7 @@ class GemmaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     # Ignore copy
-    def __init__(self, config:
+    def __init__(self, config: CeruleGemmaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -1396,7 +1396,7 @@ GEMMA_ATTENTION_CLASSES = {
 
 # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
 class GemmaDecoderLayer(nn.Module):
-    def __init__(self, config:
+    def __init__(self, config: CeruleGemmaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -1480,7 +1480,7 @@ GEMMA_START_DOCSTRING = r"""
     and behavior.
 
     Parameters:
-        config ([`
+        config ([`CeruleGemmaConfig`]):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
             load the weights associated with the model, only the configuration. Check out the
             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1492,7 +1492,7 @@ GEMMA_START_DOCSTRING = r"""
     GEMMA_START_DOCSTRING,
 )
 class GemmaPreTrainedModel(PreTrainedModel):
-    config_class =
+    config_class = CeruleGemmaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
@@ -1618,7 +1618,7 @@ class GemmaModel(GemmaPreTrainedModel):
         config: GemmaConfig
     """
 
-    def __init__(self, config:
+    def __init__(self, config: CeruleGemmaConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -2155,7 +2155,7 @@ from .configuration_gemma import CeruleGemmaConfig
 class CeruleGemmaModel(CeruleMetaModel, GemmaModel):
     config_class = CeruleGemmaConfig
 
-    def __init__(self, config:
+    def __init__(self, config: CeruleGemmaConfig):
         super(CeruleGemmaModel, self).__init__(config)
 
 
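For context, every hunk above applies the same pattern: transformers resolves a model's configuration through the config_class attribute on the PreTrainedModel subclass (and documents it via _CONFIG_FOR_DOC and the [`CeruleGemmaConfig`] docstring reference), so the refactor points all of these references at CeruleGemmaConfig. A minimal sketch of that wiring, assuming a hypothetical model_type key and illustrative field defaults that are not taken from this commit:

from transformers import PretrainedConfig, PreTrainedModel

class CeruleGemmaConfig(PretrainedConfig):
    model_type = "cerule_gemma"  # assumed registry key, illustrative only

    def __init__(self, hidden_size=2048, vocab_size=256000, pad_token_id=0, **kwargs):
        # Illustrative fields; the real config carries the full set of Gemma hyperparameters.
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        super().__init__(pad_token_id=pad_token_id, **kwargs)

class CeruleGemmaPreTrainedModel(PreTrainedModel):
    # from_pretrained() uses config_class to deserialize config.json into the
    # right config type; this is why the diff swaps it to CeruleGemmaConfig.
    config_class = CeruleGemmaConfig
    base_model_prefix = "model"

With this in place, annotations like def __init__(self, config: CeruleGemmaConfig) in the hunks above serve readers and type checkers; at runtime the binding happens through config_class.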