# feynmodel/configuration_feynmodel.py
from transformers import PretrainedConfig
import copy
class Florence2VisionConfig(PretrainedConfig):
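    """
    Configuration for the Florence-2 vision encoder (a hierarchical, DaViT-style
    backbone). The list-valued arguments (patch_size, patch_stride, dim_embed,
    num_heads, depths, ...) give one value per stage of the backbone, and
    projection_dim is the width of the features projected toward the language model.
    """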
    model_type = "florence2_vision"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        drop_path_rate=0.1,
        patch_size=[7, 3, 3, 3],
        patch_stride=[4, 2, 2, 2],
        patch_padding=[3, 1, 1, 1],
        patch_prenorm=[False, True, True, True],
        enable_checkpoint=False,
        dim_embed=[256, 512, 1024, 2048],
        num_heads=[8, 16, 32, 64],
        num_groups=[8, 16, 32, 64],
        depths=[1, 1, 9, 1],
        window_size=12,
        projection_dim=1024,
        visual_temporal_embedding=None,
        image_pos_embed=None,
        image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
        **kwargs,
    ):
        self.drop_path_rate = drop_path_rate
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.patch_padding = patch_padding
        self.patch_prenorm = patch_prenorm
        self.enable_checkpoint = enable_checkpoint
        self.dim_embed = dim_embed
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.depths = depths
        self.window_size = window_size
        self.projection_dim = projection_dim
        self.visual_temporal_embedding = visual_temporal_embedding
        self.image_pos_embed = image_pos_embed
        self.image_feature_source = image_feature_source
        super().__init__(**kwargs)
class Gemma2Config(PretrainedConfig):
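    """
    Text-decoder configuration for the Gemma 2 backbone; it mirrors the Gemma 2
    configuration shipped with transformers. It covers the vocabulary and hidden
    sizes, the attention layout (grouped key/value heads, sliding-window attention),
    RoPE settings, and the attention / final logit soft-capping values.
    """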
    model_type = "gemma2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        query_pre_attn_scalar=224,
        sliding_window=4096,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.attn_logit_softcapping = attn_logit_softcapping
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.final_logit_softcapping = final_logit_softcapping
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.cache_implementation = "hybrid"
class FeynModelConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`FeynModel`]. It is used to instantiate a FeynModel
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma2-2B + Florence-2-Base + FeynModel V0.1.0.
```python
>>> from transformers import FeynModel, FeynModelConfig
>>> # Initializing a FeynModel style configuration
>>> configuration = FeynModelConfig()
>>> # Initializing a model
>>> model = FeynModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
# model_type = "gemma2"
# is_composition = False
model_type = "FeynModel"
keys_to_ignore_at_inference = ["past_key_values"]
    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        vocab_size=256000,
        projection_dim=1024,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.vision_config = vision_config
        self.text_config = text_config
        # self.sliding_window = text_config.sliding_window
        # Copy the text_config entries (expected as a plain dict) onto this
        # config so they are reachable as top-level attributes.
        if text_config is not None:
            for attr, value in text_config.items():
                setattr(self, attr, value)
        super().__init__(**kwargs)
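

# Minimal usage sketch: compose the two sub-configurations into a FeynModelConfig.
# This assumes the sub-configs are handed over as plain dicts, which is what the
# attribute-copy loop in FeynModelConfig.__init__ expects.
if __name__ == "__main__":
    vision_config = Florence2VisionConfig().to_dict()
    text_config = Gemma2Config().to_dict()

    config = FeynModelConfig(
        vision_config=vision_config,
        text_config=text_config,
    )

    # The text-config attributes (hidden_size, sliding_window, ...) are now
    # available directly on the composite config.
    print(config.hidden_size, config.sliding_window, config.projection_dim)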