# modeling_fynmodel : Imed MAGROUNE / 2024-09
# original code from modeling_FeynModel
# add DaViT Vision Tower
#
# update generate / forward functions
#
# add LoRA adapters
#
# train on COCO OD and vision reasoning
# train on ScienceQA
#
# todo: add Mamba layer
#
# todo: train on ARC-AGI

import json
import math
from collections import OrderedDict
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint  # used by DaViT when enable_checkpoint=True
from torch import nn
from torch.nn import CrossEntropyLoss  # used for the captioning loss in FeynModelForCausalLM

from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_

from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithPast,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gemma2.modeling_gemma2 import (
    Gemma2DecoderLayer,
    Gemma2ForCausalLM,
    Gemma2Model,
    Gemma2RMSNorm,
)
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

from configuration_feynmodel import FeynModelConfig, Florence2VisionConfig

logger = logging.get_logger(__name__)


class MySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


class PreNorm(nn.Module):
    def __init__(self, norm, fn, drop_path=None):
        super().__init__()
        self.norm = norm
        self.fn = fn
        self.drop_path = drop_path

    def forward(self, x, *args, **kwargs):
        shortcut = x
        if self.norm is not None:
            x, size = self.fn(self.norm(x), *args, **kwargs)
        else:
            x, size = self.fn(x, *args, **kwargs)

        if self.drop_path:
            x = self.drop_path(x)

        x = shortcut + x
        return x, size


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.net = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(in_features, hidden_features)),
            ("act", act_layer()),
            ("fc2", nn.Linear(hidden_features, out_features))
        ]))

    def forward(self, x, size):
        return self.net(x), size


class DepthWiseConv2d(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        self.dw = nn.Conv2d(
            dim_in, dim_in,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias
        )

    def forward(self, x, size):
        B, N, C = x.shape
        H, W = size
        assert N == H * W

        x = self.dw(x.transpose(1, 2).view(B, C, H, W))
        size = (x.size(-2), x.size(-1))
        x = x.flatten(2).transpose(1, 2)
        return x, size


class ConvEmbed(nn.Module):
""" Image to Patch Embedding """ def __init__( self, patch_size=7, in_chans=3, embed_dim=64, stride=4, padding=2, norm_layer=None, pre_norm=True ): super().__init__() self.patch_size = patch_size self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding ) dim_norm = in_chans if pre_norm else embed_dim self.norm = norm_layer(dim_norm) if norm_layer else None self.pre_norm = pre_norm def forward(self, x, size): H, W = size if len(x.size()) == 3: if self.norm and self.pre_norm: x = self.norm(x) x = rearrange( x, 'b (h w) c -> b c h w', h=H, w=W ) x = self.proj(x) _, _, H, W = x.shape x = rearrange(x, 'b c h w -> b (h w) c') if self.norm and not self.pre_norm: x = self.norm(x) return x, (H, W) class ChannelAttention(nn.Module): def __init__(self, dim, groups=8, qkv_bias=True): super().__init__() self.groups = groups self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) def forward(self, x, size): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] q = q * (float(N) ** -0.5) attention = q.transpose(-1, -2) @ k attention = attention.softmax(dim=-1) x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) return x, size class ChannelBlock(nn.Module): def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True): super().__init__() drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None self.channel_attn = PreNorm( norm_layer(dim), ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), drop_path ) self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None self.ffn = PreNorm( norm_layer(dim), Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), drop_path ) def forward(self, x, size): if self.conv1: x, size = self.conv1(x, size) x, size = self.channel_attn(x, size) if self.conv2: x, size = self.conv2(x, size) x, size = self.ffn(x, size) return x, size def window_partition(x, window_size: int): B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): B = batch_size # this will cause onnx conversion failed for dynamic axis, because treated as constant # int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class WindowAttention(nn.Module): def __init__(self, dim, num_heads, window_size, qkv_bias=True): super().__init__() self.dim = dim self.window_size = window_size self.num_heads = num_heads head_dim = dim // num_heads self.scale = float(head_dim) ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.softmax = nn.Softmax(dim=-1) def forward(self, x, size): H, W = size B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) pad_l = pad_t = 0 pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size x = 
            F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        x = window_partition(x, self.window_size)
        x = x.view(-1, self.window_size * self.window_size, C)

        # W-MSA/SW-MSA
        # attn_windows = self.attn(x_windows)

        B_, N, C = x.shape

        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)

        # merge windows
        x = x.view(
            -1, self.window_size, self.window_size, C
        )
        x = window_reverse(x, B, self.window_size, Hp, Wp)

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        return x, size


class SpatialBlock(nn.Module):

    def __init__(self, dim, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        self.window_attn = PreNorm(
            norm_layer(dim),
            WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
            drop_path
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
            drop_path
        )

    def forward(self, x, size):
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.window_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)

        return x, size


class DaViT(nn.Module):
    """ DaViT: Dual-Attention Transformer

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for the classification head. Default: 1000.
        patch_size (tuple(int)): Patch size of the convolution in the different stages. Default: (7, 2, 2, 2).
        patch_stride (tuple(int)): Patch stride of the convolution in the different stages. Default: (4, 2, 2, 2).
        patch_padding (tuple(int)): Patch padding of the convolution in the different stages. Default: (3, 0, 0, 0).
        patch_prenorm (tuple(bool)): If True, apply the norm before the convolution layer. Default: (True, False, False, False).
        embed_dims (tuple(int)): Patch embedding dimension in the different stages. Default: (64, 128, 192, 256).
        num_heads (tuple(int)): Number of spatial attention heads in the different stages. Default: (4, 8, 12, 16).
        num_groups (tuple(int)): Number of channel groups in the different stages. Default: (4, 8, 12, 16).
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        enable_checkpoint (bool): If True, enable gradient checkpointing. Default: False.
        conv_at_attn (bool): If True, perform a depthwise convolution before the attention layer. Default: True.
        conv_at_ffn (bool): If True, perform a depthwise convolution before the FFN layer. Default: True.
""" def __init__( self, in_chans=3, num_classes=1000, depths=(1, 1, 3, 1), patch_size=(7, 2, 2, 2), patch_stride=(4, 2, 2, 2), patch_padding=(3, 0, 0, 0), patch_prenorm=(False, False, False, False), embed_dims=(64, 128, 192, 256), num_heads=(3, 6, 12, 24), num_groups=(3, 6, 12, 24), window_size=7, mlp_ratio=4., qkv_bias=True, drop_path_rate=0.1, norm_layer=nn.LayerNorm, enable_checkpoint=False, conv_at_attn=True, conv_at_ffn=True, ): super().__init__() self.num_classes = num_classes self.embed_dims = embed_dims self.num_heads = num_heads self.num_groups = num_groups self.num_stages = len(self.embed_dims) self.enable_checkpoint = enable_checkpoint assert self.num_stages == len(self.num_heads) == len(self.num_groups) num_stages = len(embed_dims) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)] depth_offset = 0 convs = [] blocks = [] for i in range(num_stages): conv_embed = ConvEmbed( patch_size=patch_size[i], stride=patch_stride[i], padding=patch_padding[i], in_chans=in_chans if i == 0 else self.embed_dims[i - 1], embed_dim=self.embed_dims[i], norm_layer=norm_layer, pre_norm=patch_prenorm[i] ) convs.append(conv_embed) block = MySequential( *[ MySequential(OrderedDict([ ( 'spatial_block', SpatialBlock( embed_dims[i], num_heads[i], window_size, drop_path_rate=dpr[depth_offset+j*2], qkv_bias=qkv_bias, mlp_ratio=mlp_ratio, conv_at_attn=conv_at_attn, conv_at_ffn=conv_at_ffn, ) ), ( 'channel_block', ChannelBlock( embed_dims[i], num_groups[i], drop_path_rate=dpr[depth_offset+j*2+1], qkv_bias=qkv_bias, mlp_ratio=mlp_ratio, conv_at_attn=conv_at_attn, conv_at_ffn=conv_at_ffn, ) ) ])) for j in range(depths[i]) ] ) blocks.append(block) depth_offset += depths[i]*2 self.convs = nn.ModuleList(convs) self.blocks = nn.ModuleList(blocks) self.norms = norm_layer(self.embed_dims[-1]) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) @property def dim_out(self): return self.embed_dims[-1] def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, std=0.02) for name, _ in m.named_parameters(): if name in ['bias']: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.bias, 0) def forward_features_unpool(self, x): """ forward until avg pooling Args: x (_type_): input image tensor """ input_size = (x.size(2), x.size(3)) for conv, block in zip(self.convs, self.blocks): x, input_size = conv(x, input_size) if self.enable_checkpoint: x, input_size = checkpoint.checkpoint(block, x, input_size) else: x, input_size = block(x, input_size) return x def forward_features(self, x): x = self.forward_features_unpool(x) # (batch_size, num_tokens, token_dim) x = self.avgpool(x.transpose(1, 2)) # (batch_size, 1, num_tokens) x = torch.flatten(x, 1) x = self.norms(x) return x def forward(self, x): x = self.forward_features(x) x = self.head(x) return x @classmethod def from_config(cls, config): return cls( depths=config.depths, embed_dims=config.dim_embed, num_heads=config.num_heads, num_groups=config.num_groups, patch_size=config.patch_size, patch_stride=config.patch_stride, patch_padding=config.patch_padding, patch_prenorm=config.patch_prenorm, drop_path_rate=config.drop_path_rate, 
window_size=config.window_size, ) _CONFIG_FOR_DOC = "FeynModelConfig" FEYNMODEL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`FeynModelConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ FEYNMODEL_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. 
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, min_dtype: float, cache_position: torch.Tensor, batch_size: int, ): #print(f" +++++++++ prepare 4K +++++++++++++++ rec {attention_mask.size()} sequence_length {sequence_length}") if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. #print("+++++++++++++++++ return it") #causal_mask = attention_mask # In this case we assume that the mask comes already in inverted form. causal_mask = attention_mask[:, :, -sequence_length:, :] #print(f"+++++++++++++++++ truncated causal_mask to last {sequence_length} elements, size: {causal_mask.size()}") #print(f"+++++++++++++++++ return it causal_mask {causal_mask.size()} !!!!!!!!! 
attention_mask {attention_mask.size()}") else: #print("+++++++++++++++++++++ else +++++++++++++++++") causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) #print(f"++++++++++++++++ causal_mask {causal_mask.size()} ++++++++++++++++++ sequence_length = {sequence_length} ") if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) #print(f"++++++++++++++++++ causal_mask = torch.triu ++++++++++ {causal_mask.size()} ") causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) #print(f"+++++++++++++++++++++ avant if attention_mask is not None:, causal_mask={causal_mask.size()}") if attention_mask is not None: #print(" +++++++++++++ attention_mask is None++++++++++++") causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) #print(f"+++++++++++++++++++ 4K returning causal_mask {causal_mask.size()} +++++++++++++++++++") return causal_mask class LearnedAbsolutePositionEmbedding2D(nn.Module): """ This module learns positional embeddings up to a fixed maximum size. """ def __init__(self, embedding_dim=256, num_pos=50): super().__init__() self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2)) def forward(self, pixel_values): """ pixel_values: (batch_size, height, width, num_channels) returns: (batch_size, height, width, embedding_dim * 2) """ if len(pixel_values.shape) != 4: raise ValueError('pixel_values must be a 4D tensor') height, width = pixel_values.shape[1:3] width_values = torch.arange(width, device=pixel_values.device) height_values = torch.arange(height, device=pixel_values.device) x_emb = self.column_embeddings(width_values) y_emb = self.row_embeddings(height_values) # (height, width, embedding_dim * 2) pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) # (embedding_dim * 2, height, width) pos = pos.permute(2, 0, 1) pos = pos.unsqueeze(0) # (batch_size, embedding_dim * 2, height, width) pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) # (batch_size, height, width, embedding_dim * 2) pos = pos.permute(0, 2, 3, 1) return pos class PositionalEmbeddingCosine1D(nn.Module): """ This class implements a very simple positional encoding. It follows closely the encoder from the link below: https://pytorch.org/tutorials/beginner/translation_transformer.html Args: embed_dim: The dimension of the embeddings. dropout_prob: The dropout probability. max_seq_len: The maximum length to precompute the positional encodings. """ def __init__( self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: super(PositionalEmbeddingCosine1D, self).__init__() self.embed_dim = embed_dim self.max_seq_len = max_seq_len # Generate the sinusoidal arrays. factor = math.log(10000) denominator = torch.exp( -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim) # Matrix where rows correspond to a positional embedding as a function # of the position index (i.e., the row index). 
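        # Descriptive note: this reproduces the standard transformer sinusoidal encoding,
        # i.e. pos_idx_to_embed[p, 2i] = sin(p / 10000^(2i/d)) and pos_idx_to_embed[p, 2i+1] = cos(p / 10000^(2i/d)).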
frequencies = \ torch.arange(0, self.max_seq_len) \ .reshape(self.max_seq_len, 1) * denominator pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) # Populate uneven entries. pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) # Save the positional embeddings in a constant buffer. self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: """ Args: seq_embeds: The sequence embeddings in order. Allowed size: 1. [T, D], where T is the length of the sequence, and D is the frame embedding dimension. 2. [B, T, D], where B is the batch size and T and D are the same as above. Returns a tensor of with the same dimensions as the input: i.e., [1, T, D] or [T, D]. """ shape_len = len(seq_embeds.shape) assert 2 <= shape_len <= 3 len_seq = seq_embeds.size(-2) assert len_seq <= self.max_seq_len pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] # Adapt pre-computed positional embeddings to the input. if shape_len == 3: pos_embeds = pos_embeds.view( (1, pos_embeds.size(0), pos_embeds.size(1))) return pos_embeds class LearnedAbsolutePositionEmbedding1D(nn.Module): """ Learnable absolute positional embeddings for 1D sequences. Args: embed_dim: The dimension of the embeddings. max_seq_len: The maximum length to precompute the positional encodings. """ def __init__( self, embedding_dim: int = 512, num_pos: int = 1024) -> None: super(LearnedAbsolutePositionEmbedding1D, self).__init__() self.embeddings = nn.Embedding(num_pos, embedding_dim) self.num_pos = num_pos def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: """ Args: seq_embeds: The sequence embeddings in order. Allowed size: 1. [T, D], where T is the length of the sequence, and D is the frame embedding dimension. 2. [B, T, D], where B is the batch size and T and D are the same as above. Returns a tensor of with the same dimensions as the input: i.e., [1, T, D] or [T, D]. """ shape_len = len(seq_embeds.shape) assert 2 <= shape_len <= 3 len_seq = seq_embeds.size(-2) assert len_seq <= self.num_pos # [T, D] pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device)) # Adapt pre-computed positional embeddings to the input. 
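        # For batched input [B, T, D] the embeddings are reshaped to [1, T, D] so they broadcast over the batch.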
if shape_len == 3: pos_embeds = pos_embeds.view( (1, pos_embeds.size(0), pos_embeds.size(1))) return pos_embeds def create_git_attention_mask( tgt: torch.Tensor, memory: torch.Tensor, max_length: int ) -> torch.Tensor: # Obtain the dimensions of the target text and memory batch_size = tgt.size(0) num_tgt = tgt.shape[1] num_memory = memory.shape[1] total_length = num_memory + num_tgt # Create the top left part of the attention matrix top_left = torch.zeros((num_memory, num_memory)) # Attention enabled in this region top_right = torch.full((num_memory, num_tgt), float(-3.4028e+38)) # Attention disabled here # Bottom left part of the attention matrix bottom_left = torch.zeros((num_tgt, num_memory)) # Attention enabled here # Create a lower triangular matrix for the bottom right part bottom_right = torch.tril(torch.ones(num_tgt, num_tgt)) # Transform 1s to 0 to enable attention, and 0s to -inf to block attention bottom_right = bottom_right.masked_fill(bottom_right == 0, float(-3.4028e+38)) bottom_right = bottom_right.masked_fill(bottom_right == 1, float(0)) # Concatenate matrices to form the full mask left = torch.cat((top_left, bottom_left), dim=0) right = torch.cat((top_right, bottom_right), dim=0) # Combine left and right parts full_attention_mask = torch.cat((left, right), dim=1) # Add padding to reach max_length padding = torch.full((total_length, max_length - total_length), float(-3.4028e+38)) full_attention_mask = torch.cat((full_attention_mask, padding), dim=1) # Add an axis for multi-heads and batch_size full_attention_mask = full_attention_mask[None, None, :, :] # Expand the mask to have shape (batch_size, 1, seq_length, max_length) full_attention_mask = full_attention_mask.expand(batch_size, 1, full_attention_mask.size(-2), full_attention_mask.size(-1)) return full_attention_mask def get_position_ids_from_binary_attention_mask(mask): """ Extract position IDs from a binary attention mask. Args: mask (torch.Tensor): The attention mask tensor of shape (1, 1, seq_len, seq_len), where 1 indicates allowed attention and 0 indicates blocked attention. Returns: list: A list of lists where each sublist contains the allowed position IDs for each query position. """ # Assuming the mask is of shape (1, 1, seq_len, seq_len) _, _, seq_len, _ = mask.shape # Create a tensor with position IDs from 0 to seq_len - 1 position_ids = torch.arange(seq_len, dtype=torch.long, device=mask.device) # Add a batch dimension position_ids = position_ids.unsqueeze(0) return position_ids def ensure_tensor(variable): # Check if the variable is a torch.Tensor if isinstance(variable, torch.Tensor): # print("Variable is already a tensor.") return variable else: #print("Variable is not a tensor, converting...") try: # Convert the variable to a tensor tensor = torch.tensor(variable) #print("Conversion successful.") return tensor except Exception as e: print(f"Error converting to tensor: {e}") raise @add_start_docstrings( "The bare Model outputting raw hidden-states without any specific head on top.", FEYNMODEL_START_DOCSTRING, ) class FeynModel(Gemma2Model): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
    Each layer is a [`FeynModelDecoderLayer`] plus a [`LoraLayer`] on the *proj* modules.
    NB: the LoRA layers are added and activated on the proj modules only if `pixel_values` is not None.

    Args:
        config: FeynModelConfig
    """

    def __init__(self, config: FeynModelConfig):
        super().__init__(config)
        # Initialize weights and apply final processing
        self.mode = 'llm'
        '''
        self.image_patch_tokens = int(
            (config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1
        )
        if config.num_image_with_embedding is not None:
            self.image_patch_tokens *= config.num_image_with_embedding
        '''
        self.image_patch_tokens = 577
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # print(f" self.mode = {self.mode}")
        # Ensure cache_position is initialized if not provided
        if cache_position is None:
            batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
            cache_position = torch.zeros(
                (batch_size,),
                dtype=torch.long,
                device=input_ids.device if input_ids is not None else inputs_embeds.device,
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            causal_mask = self._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
        else:
            causal_mask = ensure_tensor(causal_attention_mask)
            position_ids = get_position_ids_from_binary_attention_mask(attention_mask)
        # print(f" causal_mask = {causal_mask} ")

        if cache_position is None:
            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Convert position_ids to a tensor if not already
        if not isinstance(position_ids, torch.Tensor):
            position_ids = torch.tensor(position_ids, dtype=torch.long, device=inputs_embeds.device)

        # embed positions
        hidden_states = inputs_embeds

        # normalized
        # FeynModel downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # print(f" _start _____ _update_causal_mask attention_mask {attention_mask.size()} {attention_mask} ")
        # Flash Attention currently doesn't support a static cache, but FeynModel works only with a static cache.
        # So we pass the attention mask through as is in any case, not only when there is padding. Its shape is then
        # used to cut out the trailing zeros of the keys/values used by the static cache. This workaround should be
        # compile-compatible, as it doesn't introduce dynamic control flow.
        if self.config._attn_implementation == "flash_attention_2":
            return attention_mask

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_length()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        # In case the provided `attention_mask` is 2D, we generate a causal (4D) mask here.
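        # The helper below returns a mask of shape (batch_size, 1, sequence_length, target_length),
        # with 0 where attention is allowed and the dtype's minimum value where it is blocked.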
        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        # print(f" _end ______ _update_causal_mask causal_mask {causal_mask.size()} {causal_mask} ")
        return causal_mask


class FeynModelForCausalLM(Gemma2ForCausalLM):
    _tied_weights_keys = ["lm_head.weight"]
    config_class = FeynModelConfig

    def __init__(self, config):
        super().__init__(config)
        config.vision_config = Florence2VisionConfig.from_dict(config.vision_config)
        self.model = FeynModel(config)
        # assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
        self.vision_tower = DaViT.from_config(config=config.vision_config)
        self._build_image_projection_layers(config)
        self.__causal_attention_mask = None
        # Initialize weights and apply final processing
        self.post_init()

    ################ Vision Tower ########################
    def _build_image_projection_layers(self, config):
        image_dim_out = config.vision_config.dim_embed[-1]
        dim_projection = config.vision_config.projection_dim
        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection)
        )
        self.image_proj_norm = nn.LayerNorm(dim_projection)
        image_pos_embed_config = config.vision_config.image_pos_embed
        if image_pos_embed_config['type'] == 'learned_abs_2d':
            self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out,
                num_pos=image_pos_embed_config['max_pos_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

        self.image_feature_source = config.vision_config.image_feature_source

        # temporal embedding
        visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
        if visual_temporal_embedding_config['type'] == 'COSINE':
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim_out,
                max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds):
        batch_size, image_token_length = image_features.size()[:-1]
        device = image_features.device
        image_attention_mask = torch.ones(batch_size, image_token_length, device=device)

        if inputs_embeds is None:
            return image_features, image_attention_mask

        task_prefix_embeds = inputs_embeds
        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)

        # Make sure the attention masks are two-dimensional
        if len(task_prefix_attention_mask.shape) == 3:
            task_prefix_attention_mask = task_prefix_attention_mask.squeeze(1)

        # Check the batch dimension and adjust if necessary
        if image_features.size(0) != task_prefix_embeds.size(0):
            raise ValueError("Batch sizes of image_features and task_prefix_embeds do not match")

        # Add a dummy dimension if the dimensions are not aligned
        if image_features.dim() < task_prefix_embeds.dim():
            image_features = image_features.unsqueeze(-1)
        elif task_prefix_embeds.dim() < image_features.dim():
            task_prefix_embeds = task_prefix_embeds.unsqueeze(-1)

        # Make sure every dimension except dim=1 matches
        if image_features.size(2) != task_prefix_embeds.size(2):
            # Adjust, or raise an error if the inner dimensions are not compatible
            raise ValueError("Internal dimensions of image_features and task_prefix_embeds do not match")

        inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
        attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)

        return inputs_embeds, attention_mask

    def _encode_image(self, pixel_values):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            # Add a batch dimension in front if 'pixel_values' only has 3 dimensions (C, H, W)
            pixel_values = pixel_values.unsqueeze(0)
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
            assert h * w == num_tokens, 'only support square feature maps for now'
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict['last_frame'] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x
    #######################################################

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
                ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns: Example: ```python >>> from transformers import AutoTokenizer, GemmaForCausalLM >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") >>> prompt = "What is your favorite condiment?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" ```""" if self.training and self.config._attn_implementation != "eager": logger.warning_once( "It is strongly recommended to train FeynModel models with the `eager` attention implementation " f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is not None: self.model.mode='vlm' if input_ids is not None: inputs_embeds = self.get_input_embeddings()(input_ids) image_features = self._encode_image(pixel_values) inputs_embeds, causal_attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds ) causal_attention_mask = create_git_attention_mask(tgt=input_ids, memory=image_features,max_length=2048) causal_attention_mask=causal_attention_mask.to(input_ids.device) self.__causal_attention_mask=causal_attention_mask # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) if pixel_values is not None: outputs = self.model( input_ids=None, attention_mask=causal_attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, causal_attention_mask=causal_attention_mask, ) else: outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, causal_attention_mask=self.__causal_attention_mask, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) if self.config.final_logit_softcapping is not None: logits = logits / self.config.final_logit_softcapping logits = torch.tanh(logits) logits = logits * self.config.final_logit_softcapping logits = logits.float() loss = None if labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one num_image_tokens = self.model.image_patch_tokens shifted_logits = logits[:, num_image_tokens:-1, :].contiguous() labels = labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss() loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, 
past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, use_cache=True, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: # print(f"+-+-+-+-+-+-+++ past_key_values +-+-+++- position_ids {position_ids.size()} ================= ") position_ids = position_ids[:, -input_ids.shape[1] :] # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the # batch size = 1 case, `position_ids` is already contiguous but with varying stride # which retriggers a capture. position_ids = position_ids.clone(memory_format=torch.contiguous_format) # print(f"+-+-+-+-+-+-+++ past_key_values +-+-+++- position_ids cmlone ==> {position_ids.size()} ================= ") # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: #print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> first generation step>>>>>>>>>>>>>>>>>>>>>>>>>>>>>><") model_inputs = {"inputs_embeds": inputs_embeds} else: # The clone here is for the same reason as for `position_ids`. 
# print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> The clone here is for the same reason as for `position_ids` ==> input_ids input_ids.clone.>>>>>>>>>>>>>>>>>>>>>>>>>>>>>><") model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2: if inputs_embeds is not None and input_ids.size(1)!= 0 : ###################### V ############## add _ for _ = inputs_embeds.shape batch_size, sequence_length, _ = inputs_embeds.shape device = inputs_embeds.device #print(f"1111111 +-+-+-+-+-+-+-+-+-+- sequence_length = inputs_embeds {sequence_length}") else: batch_size, sequence_length = position_ids.shape device = input_ids.device #print(f"22222222 +-+-+-+-+-+-+-+-+-+- sequence_length = input_ids.shape {sequence_length}") dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_length(), dtype=dtype, device=device, min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, } ) return model_inputs def generate( self, input_ids, pixel_values=None, max_length=None, do_sample=True, temperature=0.7, **kwargs ): print("Fonction generate personnalisée appelée") if pixel_values is not None: if input_ids is not None: print("input") inputs_embeds = self.get_input_embeddings()(input_ids) print("pixels") image_features = self._encode_image(pixel_values) inputs_embeds, causal_attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds ) causal_attention_mask = create_git_attention_mask(tgt=input_ids, memory=image_features,max_length=max_length) causal_attention_mask=causal_attention_mask.to(input_ids.device) self.__causal_attention_mask=causal_attention_mask self.model.mode='vlm' result = super().generate( input_ids=None, inputs_embeds=inputs_embeds, max_length=max_length, do_sample=do_sample, temperature=temperature, **kwargs ) else: print("llm") self.model.mode=='llm' result = super().generate( input_ids=input_ids, #inputs_embeds=None, max_length=max_length, do_sample=do_sample, temperature=temperature, **kwargs ) self.__causal_attention_mask = None return result