fix integration with huggingface

#2
README.md CHANGED
@@ -39,6 +39,16 @@ The training setup was `4xA100's 80GB` and took ~6 hours to pretrain and ~13 hou
39
  | ![extreme_ironing](examples/extreme_ironing.jpg) | **What's funny about this image?**<br>The image is quite humorous as it depicts a man ironing clothes on the back of a yellow taxi cab. This is not a typical sight you'd expect to see in everyday life. |
40
  ---
41
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  ## Training:
44
  We will release the training code in some time.
 
39
  | ![extreme_ironing](examples/extreme_ironing.jpg) | **What's funny about this image?**<br>The image is quite humorous as it depicts a man ironing clothes on the back of a yellow taxi cab. This is not a typical sight you'd expect to see in everyday life. |
40
  ---
41
 
42
+ ## Loading the model
43
+
44
+ ```
45
+ pip install -qr https://huggingface.co/Tensoic/Cerule-v0.1/resolve/main/requirements.txt
46
+ ```
47
+
48
+ ```python
49
+ from transformers import AutoModelForCausalLM
50
+ model = AutoModelForCausalLM.from_pretrained("Tensoic/Cerule-v0.1", trust_remote_code=True)
51
+ ```
52
 
53
  ## Training:
54
  We will release the training code in some time.
__init__.py CHANGED
@@ -3,5 +3,5 @@ from .modeling_cerule_gemma import CeruleGemmaForCausalLM
3
 
4
  from transformers import AutoConfig, AutoModelForCausalLM
5
 
6
- AutoConfig.register("cerule-gemma", CeruleGemmaConfig)
7
  AutoModelForCausalLM.register(CeruleGemmaConfig, CeruleGemmaForCausalLM)
 
3
 
4
  from transformers import AutoConfig, AutoModelForCausalLM
5
 
6
+ AutoConfig.register("phi-msft", CeruleGemmaConfig)
7
  AutoModelForCausalLM.register(CeruleGemmaConfig, CeruleGemmaForCausalLM)
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "Tensoic/Cerule",
3
  "architectures": [
4
  "CeruleGemmaForCausalLM"
5
  ],
@@ -23,7 +23,7 @@
23
  "mm_projector_lr": null,
24
  "mm_projector_type": "mlp2x_gelu",
25
  "mm_vision_tower": "google/siglip-so400m-patch14-384",
26
- "model_type": "cerule-gemma",
27
  "num_attention_heads": 8,
28
  "num_hidden_layers": 18,
29
  "num_key_value_heads": 1,
 
1
  {
2
+ "_name_or_path": "Tensoic/Cerule-v0.1",
3
  "architectures": [
4
  "CeruleGemmaForCausalLM"
5
  ],
 
23
  "mm_projector_lr": null,
24
  "mm_projector_type": "mlp2x_gelu",
25
  "mm_vision_tower": "google/siglip-so400m-patch14-384",
26
+ "model_type": "phi-msft",
27
  "num_attention_heads": 8,
28
  "num_hidden_layers": 18,
29
  "num_key_value_heads": 1,
configuration_gemma.py CHANGED
@@ -25,8 +25,8 @@ GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
  }
26
 
27
 
28
- class GemmaConfig(PretrainedConfig):
29
- model_type = "gemma"
30
  keys_to_ignore_at_inference = ["past_key_values"]
31
 
32
  def __init__(
@@ -162,10 +162,3 @@ class SigLipVisionConfig(PretrainedConfig):
162
 
163
  return cls.from_dict(config_dict, **kwargs)
164
 
165
-
166
- class CeruleGemmaConfig(GemmaConfig):
167
- model_type = "cerule-gemma"
168
-
169
- def __init__(self, **kwargs):
170
- self.gemma_config = GemmaConfig(**kwargs)
171
- super().__init__(**kwargs)
 
25
  }
26
 
27
 
28
+ class CeruleGemmaConfig(PretrainedConfig):
29
+ model_type = "phi-msft"
30
  keys_to_ignore_at_inference = ["past_key_values"]
31
 
32
  def __init__(
 
162
 
163
  return cls.from_dict(config_dict, **kwargs)
164
 
 
 
 
 
 
 
 
modeling_cerule_gemma.py CHANGED
@@ -853,7 +853,7 @@ from transformers.utils import (
853
  replace_return_docstrings,
854
  )
855
  from transformers.utils.import_utils import is_torch_fx_available
856
- from .configuration_gemma import GemmaConfig
857
 
858
 
859
  if is_flash_attn_2_available():
@@ -872,7 +872,7 @@ if is_torch_fx_available():
872
 
873
  logger = logging.get_logger(__name__)
874
 
875
- _CONFIG_FOR_DOC = "GemmaConfig"
876
 
877
 
878
  def _get_unpad_data(attention_mask):
@@ -1003,7 +1003,7 @@ class GemmaAttention(nn.Module):
1003
  """Multi-headed attention from 'Attention Is All You Need' paper"""
1004
 
1005
  # Ignore copy
1006
- def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
1007
  super().__init__()
1008
  self.config = config
1009
  self.layer_idx = layer_idx
@@ -1396,7 +1396,7 @@ GEMMA_ATTENTION_CLASSES = {
1396
 
1397
  # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
1398
  class GemmaDecoderLayer(nn.Module):
1399
- def __init__(self, config: GemmaConfig, layer_idx: int):
1400
  super().__init__()
1401
  self.hidden_size = config.hidden_size
1402
 
@@ -1480,7 +1480,7 @@ GEMMA_START_DOCSTRING = r"""
1480
  and behavior.
1481
 
1482
  Parameters:
1483
- config ([`GemmaConfig`]):
1484
  Model configuration class with all the parameters of the model. Initializing with a config file does not
1485
  load the weights associated with the model, only the configuration. Check out the
1486
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1492,7 +1492,7 @@ GEMMA_START_DOCSTRING = r"""
1492
  GEMMA_START_DOCSTRING,
1493
  )
1494
  class GemmaPreTrainedModel(PreTrainedModel):
1495
- config_class = GemmaConfig
1496
  base_model_prefix = "model"
1497
  supports_gradient_checkpointing = True
1498
  _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
@@ -1618,7 +1618,7 @@ class GemmaModel(GemmaPreTrainedModel):
1618
  config: GemmaConfig
1619
  """
1620
 
1621
- def __init__(self, config: GemmaConfig):
1622
  super().__init__(config)
1623
  self.padding_idx = config.pad_token_id
1624
  self.vocab_size = config.vocab_size
@@ -2155,7 +2155,7 @@ from .configuration_gemma import CeruleGemmaConfig
2155
  class CeruleGemmaModel(CeruleMetaModel, GemmaModel):
2156
  config_class = CeruleGemmaConfig
2157
 
2158
- def __init__(self, config: GemmaConfig):
2159
  super(CeruleGemmaModel, self).__init__(config)
2160
 
2161
 
@@ -2264,5 +2264,5 @@ class CeruleGemmaForCausalLM(GemmaForCausalLM, CeruleMetaForCausalLM):
2264
  return new_images
2265
 
2266
 
2267
- AutoConfig.register("cerule-gemma", CeruleGemmaConfig)
2268
  AutoModelForCausalLM.register(CeruleGemmaConfig, CeruleGemmaForCausalLM)
 
853
  replace_return_docstrings,
854
  )
855
  from transformers.utils.import_utils import is_torch_fx_available
856
+ from .configuration_gemma import CeruleGemmaConfig
857
 
858
 
859
  if is_flash_attn_2_available():
 
872
 
873
  logger = logging.get_logger(__name__)
874
 
875
+ _CONFIG_FOR_DOC = "CeruleGemmaConfig"
876
 
877
 
878
  def _get_unpad_data(attention_mask):
 
1003
  """Multi-headed attention from 'Attention Is All You Need' paper"""
1004
 
1005
  # Ignore copy
1006
+ def __init__(self, config: CeruleGemmaConfig, layer_idx: Optional[int] = None):
1007
  super().__init__()
1008
  self.config = config
1009
  self.layer_idx = layer_idx
 
1396
 
1397
  # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
1398
  class GemmaDecoderLayer(nn.Module):
1399
+ def __init__(self, config: CeruleGemmaConfig, layer_idx: int):
1400
  super().__init__()
1401
  self.hidden_size = config.hidden_size
1402
 
 
1480
  and behavior.
1481
 
1482
  Parameters:
1483
+ config ([`CeruleGemmaConfig`]):
1484
  Model configuration class with all the parameters of the model. Initializing with a config file does not
1485
  load the weights associated with the model, only the configuration. Check out the
1486
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 
1492
  GEMMA_START_DOCSTRING,
1493
  )
1494
  class GemmaPreTrainedModel(PreTrainedModel):
1495
+ config_class = CeruleGemmaConfig
1496
  base_model_prefix = "model"
1497
  supports_gradient_checkpointing = True
1498
  _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
 
1618
  config: GemmaConfig
1619
  """
1620
 
1621
+ def __init__(self, config: CeruleGemmaConfig):
1622
  super().__init__(config)
1623
  self.padding_idx = config.pad_token_id
1624
  self.vocab_size = config.vocab_size
 
2155
  class CeruleGemmaModel(CeruleMetaModel, GemmaModel):
2156
  config_class = CeruleGemmaConfig
2157
 
2158
+ def __init__(self, config: CeruleGemmaConfig):
2159
  super(CeruleGemmaModel, self).__init__(config)
2160
 
2161
 
 
2264
  return new_images
2265
 
2266
 
2267
+ AutoConfig.register("phi-msft", CeruleGemmaConfig)
2268
  AutoModelForCausalLM.register(CeruleGemmaConfig, CeruleGemmaForCausalLM)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ flash_attn
2
+ transformers>=4.39.1