# feynmodel / modeling_feynmodel.py
# Imagroune's picture
# Pushing 1
# a3f9aa4 verified
# raw
# history blame
# 61.6 kB
# modeling_feynmodel: Imed MAGROUNE / 2024-09
# Original code from modeling_FeynModel
# Add DaViT vision tower
#
# Update generate / forward function
#
# Add LoRA adapters
#
# Train on COCO OD and vision reasoning
# Train on ScienceQA
#
# TODO: add Mamba layer
#
# TODO: train on ARC-AGI
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from typing import List, Optional, Tuple, Union
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2ForCausalLM,Gemma2DecoderLayer,Gemma2RMSNorm
from .configuration_feynmodel import FeynModelConfig,Florence2VisionConfig
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
import json
import math
import torch
from torch import nn
import torch.nn.functional as F
import logging
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)
from transformers.modeling_utils import PreTrainedModel
from collections import OrderedDict
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
# `logging` resolves to `transformers.utils.logging` here: the second
# `from transformers.utils import (..., logging, ...)` rebinds the name after
# the stdlib `import logging`. This gives a transformers logger, which provides
# the `warning_once` used in FeynModel.forward.
logger = logging.get_logger(__name__)
class MySequential(nn.Sequential):
    """Sequential container that threads multi-value outputs between stages.

    Unlike ``nn.Sequential``, a stage's result that is exactly a ``tuple`` is
    unpacked into positional arguments for the next stage. Modules with a
    ``forward(x, size) -> (x, size)`` signature can therefore be chained.
    """

    def forward(self, *inputs):
        carried = inputs
        for stage in self._modules.values():
            # Only an exact tuple is splatted; any other value (including tuple
            # subclasses) is forwarded as a single argument.
            carried = stage(*carried) if type(carried) is tuple else stage(carried)
        return carried
class PreNorm(nn.Module):
    """Residual wrapper computing ``x + drop_path(fn(norm(x)))``.

    Args:
        norm: optional normalisation module applied before ``fn``; ``None``
            skips normalisation.
        fn: module called as ``fn(x, *args, **kwargs)`` that returns an
            ``(output, size)`` pair.
        drop_path: optional module applied to ``fn``'s output before the
            residual addition; a falsy value skips it.

    ``forward`` returns ``(x + branch, size)``, where ``size`` is whatever ``fn``
    returned.
    """

    def __init__(self, norm, fn, drop_path=None):
        super().__init__()
        self.norm = norm
        self.fn = fn
        self.drop_path = drop_path

    def forward(self, x, *args, **kwargs):
        residual = x
        branch_in = x if self.norm is None else self.norm(x)
        branch_out, size = self.fn(branch_in, *args, **kwargs)
        if self.drop_path:
            branch_out = self.drop_path(branch_out)
        return residual + branch_out, size
class Mlp(nn.Module):
    """Two-layer feed-forward network computing ``fc2(act(fc1(x)))``.

    ``hidden_features`` and ``out_features`` default to ``in_features``. The
    ``size`` passed to :meth:`forward` is returned unchanged, so the module
    fits the ``(x, size)`` convention used by :class:`PreNorm`.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        # The layer names become state-dict keys (net.fc1.*, net.fc2.*).
        layers = OrderedDict()
        layers["fc1"] = nn.Linear(in_features, hidden)
        layers["act"] = act_layer()
        layers["fc2"] = nn.Linear(hidden, out)
        self.net = nn.Sequential(layers)

    def forward(self, x, size):
        return self.net(x), size
class DepthWiseConv2d(nn.Module):
    """Depth-wise 2-D convolution over a flattened token sequence.

    Tokens of shape ``(B, H*W, C)`` are folded into a ``(B, C, H, W)`` map,
    convolved with one filter per channel (``groups=dim_in``), then flattened
    back. The new spatial size is returned with the tokens.
    """

    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        self.dw = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias,
        )

    def forward(self, x, size):
        batch, tokens, channels = x.shape
        height, width = size
        assert tokens == height * width
        feature_map = self.dw(x.transpose(1, 2).view(batch, channels, height, width))
        new_size = (feature_map.size(-2), feature_map.size(-1))
        return feature_map.flatten(2).transpose(1, 2), new_size
class ConvEmbed(nn.Module):
    """Image/token to patch embedding via a strided convolution.

    Accepts either a 4D ``(B, C, H, W)`` image or 3D ``(B, H*W, C)`` tokens,
    which are folded back into a map using ``size``. Returns
    ``(B, H'*W', embed_dim)`` tokens and the new ``(H', W')``.

    When ``pre_norm`` is true, ``norm_layer`` is applied to the input tokens
    (3D path only). Otherwise it is applied to the output tokens.
    """

    def __init__(
        self,
        patch_size=7,
        in_chans=3,
        embed_dim=64,
        stride=4,
        padding=2,
        norm_layer=None,
        pre_norm=True
    ):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding,
        )
        # Pre-norm normalises the input channels, post-norm the embedded ones.
        norm_dim = in_chans if pre_norm else embed_dim
        self.norm = None if not norm_layer else norm_layer(norm_dim)
        self.pre_norm = pre_norm

    def forward(self, x, size):
        height, width = size
        if x.dim() == 3:
            if self.norm and self.pre_norm:
                x = self.norm(x)
            # (B, H*W, C) -> (B, C, H, W)
            batch, _, channels = x.shape
            x = x.transpose(1, 2).reshape(batch, channels, height, width)

        x = self.proj(x)
        height, width = x.shape[-2:]
        # (B, C, H, W) -> (B, H*W, C)
        x = x.flatten(2).transpose(1, 2)
        if self.norm and not self.pre_norm:
            x = self.norm(x)
        return x, (height, width)
class ChannelAttention(nn.Module):
    """Grouped channel self-attention (DaViT "channel" attention).

    The channels are split into ``groups``. Inside each group, attention is
    computed between channels (a ``C/groups x C/groups`` map) rather than
    between tokens. Queries are scaled by ``N ** -0.5``, where ``N`` is the
    token count.
    """

    def __init__(self, dim, groups=8, qkv_bias=True):
        super().__init__()
        self.groups = groups
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        batch, tokens, channels = x.shape
        per_group = channels // self.groups
        # (3, B, groups, N, C/groups)
        q, k, v = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.groups, per_group)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        q = q * (float(tokens) ** -0.5)
        # (B, groups, C/groups, C/groups): channel-to-channel weights
        weights = torch.softmax(q.transpose(-1, -2) @ k, dim=-1)
        mixed = (weights @ v.transpose(-1, -2)).transpose(-1, -2)
        mixed = mixed.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(mixed), size
class ChannelBlock(nn.Module):
    """DaViT channel-attention block.

    Layout: optional depth-wise conv -> pre-norm channel attention (residual)
    -> optional depth-wise conv -> pre-norm MLP (residual). Every sub-module
    takes and returns an ``(x, size)`` pair.
    """

    def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
                 drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 conv_at_attn=True, conv_at_ffn=True):
        super().__init__()
        # One stochastic-depth module is shared by both residual branches.
        drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        # Creation order is kept fixed so parameter init consumes the RNG identically.
        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        self.channel_attn = PreNorm(
            norm_layer(dim),
            ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
            drop_path,
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
            drop_path,
        )

    def forward(self, x, size):
        stages = ((self.conv1, self.channel_attn), (self.conv2, self.ffn))
        for optional_conv, residual_block in stages:
            if optional_conv is not None:
                x, size = optional_conv(x, size)
            x, size = residual_block(x, size)
        return x, size
def window_partition(x, window_size: int):
    """Split a ``(B, H, W, C)`` map into non-overlapping square windows.

    ``H`` and ``W`` must be multiples of ``window_size`` (the ``view`` below
    requires it). Returns ``(B * H/ws * W/ws, ws, ws, C)``, with windows
    ordered row-major within each image.
    """
    batch, height, width, channels = x.shape
    ws = window_size
    grid = x.view(batch, height // ws, ws, width // ws, ws, channels)
    # (B, nH, ws, nW, ws, C) -> (B, nH, nW, ws, ws, C)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, ws, ws, channels)
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    """Inverse of ``window_partition``: stitch windows back into ``(B, H, W, C)``.

    ``batch_size`` is passed explicitly rather than derived from
    ``windows.shape[0]``. Deriving it would be treated as a constant during
    ONNX export and break the dynamic batch axis.
    """
    rows, cols = H // window_size, W // window_size
    tiles = windows.view(batch_size, rows, cols, window_size, window_size, -1)
    # (B, nH, nW, ws, ws, C) -> (B, nH, ws, nW, ws, C)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(batch_size, H, W, -1)
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping local windows.

    The token map is zero-padded on the right and bottom to a multiple of
    ``window_size``. Attention then runs independently inside each window,
    and the padding is cropped off afterwards. Padded positions are not
    masked out of the attention.
    """

    def __init__(self, dim, num_heads, window_size, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = float(head_dim) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        ws = self.window_size

        # Pad W on the right and H at the bottom up to a multiple of ws.
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        x = F.pad(x.view(B, H, W, C), (0, 0, 0, pad_r, 0, pad_b))
        Hp, Wp = x.shape[1], x.shape[2]

        # (num_windows * B, ws * ws, C)
        tokens = window_partition(x, ws).view(-1, ws * ws, C)
        n_win, n_tok, _ = tokens.shape
        head_dim = C // self.num_heads

        q, k, v = (
            self.qkv(tokens)
            .reshape(n_win, n_tok, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        attn = self.softmax((q * self.scale) @ k.transpose(-2, -1))
        out = (attn @ v).transpose(1, 2).reshape(n_win, n_tok, C)
        out = self.proj(out)

        # Stitch the windows back together, then crop away the padding.
        out = window_reverse(out.view(-1, ws, ws, C), B, ws, Hp, Wp)
        if pad_r > 0 or pad_b > 0:
            out = out[:, :H, :W, :].contiguous()
        return out.view(B, H * W, C), size
class SpatialBlock(nn.Module):
    """DaViT spatial (window) attention block.

    Layout: optional depth-wise conv -> pre-norm window attention (residual)
    -> optional depth-wise conv -> pre-norm MLP (residual). Every sub-module
    takes and returns an ``(x, size)`` pair.
    """

    def __init__(self, dim, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
        super().__init__()
        # One stochastic-depth module is shared by both residual branches.
        drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        # Creation order is kept fixed so parameter init consumes the RNG identically.
        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        self.window_attn = PreNorm(
            norm_layer(dim),
            WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
            drop_path,
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
            drop_path,
        )

    def forward(self, x, size):
        stages = ((self.conv1, self.window_attn), (self.conv2, self.ffn))
        for optional_conv, residual_block in stages:
            if optional_conv is not None:
                x, size = optional_conv(x, size)
            x, size = residual_block(x, size)
        return x, size
class DaViT(nn.Module):
    """ DaViT: Dual-Attention Transformer backbone.

    Each stage is a patch-embedding convolution (``ConvEmbed``) followed by
    ``depths[i]`` pairs of (``SpatialBlock``, ``ChannelBlock``).

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for the classification head; a value <= 0
            replaces the head with ``nn.Identity``. Default: 1000.
        depths (tuple(int)): Number of spatial/channel block pairs per stage. Default: (1, 1, 3, 1).
        patch_size (tuple(int)): Patch size of the convolution in each stage. Default: (7, 2, 2, 2).
        patch_stride (tuple(int)): Patch stride of the convolution in each stage. Default: (4, 2, 2, 2).
        patch_padding (tuple(int)): Patch padding of the convolution in each stage. Default: (3, 0, 0, 0).
        patch_prenorm (tuple(bool)): If True, normalise before the convolution. Default: (False, False, False, False).
        embed_dims (tuple(int)): Patch embedding dimension in each stage. Default: (64, 128, 192, 256).
        num_heads (tuple(int)): Number of spatial attention heads in each stage. Default: (3, 6, 12, 24).
        num_groups (tuple(int)): Number of channel groups in each stage. Default: (3, 6, 12, 24).
        window_size (int): Window size of the spatial attention. Default: 7.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
        drop_path_rate (float): Maximum stochastic-depth rate; per-block rates grow linearly
            from 0 to this value over all spatial and channel blocks. Default: 0.1.
        norm_layer (nn.Module): Normalization layer used by the patch embeddings and the final
            feature norm (the attention blocks use their own default). Default: nn.LayerNorm.
        enable_checkpoint (bool): If True, apply activation checkpointing to each stage's
            blocks. Default: False.
        conv_at_attn (bool): If True, apply a depth-wise convolution before each attention layer. Default: True.
        conv_at_ffn (bool): If True, apply a depth-wise convolution before each FFN layer. Default: True.
    """
    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)
        num_stages = len(embed_dims)
        # One stochastic-depth rate per block (spatial and channel blocks are
        # counted separately), increasing linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)]
        depth_offset = 0
        convs = []
        blocks = []
        for i in range(num_stages):
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i]
            )
            convs.append(conv_embed)
            # depths[i] pairs of (spatial_block, channel_block); the module names
            # below are part of the state-dict keys.
            block = MySequential(
                *[
                    MySequential(OrderedDict([
                        (
                            'spatial_block', SpatialBlock(
                                embed_dims[i],
                                num_heads[i],
                                window_size,
                                drop_path_rate=dpr[depth_offset+j*2],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                            )
                        ),
                        (
                            'channel_block', ChannelBlock(
                                embed_dims[i],
                                num_groups[i],
                                drop_path_rate=dpr[depth_offset+j*2+1],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                            )
                        )
                    ])) for j in range(depths[i])
                ]
            )
            blocks.append(block)
            depth_offset += depths[i]*2
        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)
        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @property
    def dim_out(self):
        """Feature dimension of the last stage (``embed_dims[-1]``)."""
        return self.embed_dims[-1]

    def _init_weights(self, m):
        """Initialise Linear/Conv2d/LayerNorm/BatchNorm2d parameters (used via ``self.apply``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward_features_unpool(self, x):
        """
        Run all stages and return the token features before average pooling.

        Args:
            x: input image tensor; its dims 2 and 3 are read as the spatial size
                (presumably ``(B, C, H, W)``).

        Returns:
            Last-stage token features of shape ``(batch_size, num_tokens, embed_dims[-1])``.
        """
        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            if self.enable_checkpoint:
                # `torch.utils.checkpoint` is not imported at module level, so
                # bring it into scope here before using it.
                from torch.utils import checkpoint as torch_checkpoint
                x, input_size = torch_checkpoint.checkpoint(
                    block, x, input_size, use_reentrant=False
                )
            else:
                x, input_size = block(x, input_size)
        return x

    def forward_features(self, x):
        """Return pooled, normalised features of shape ``(batch_size, embed_dims[-1])``."""
        x = self.forward_features_unpool(x)
        # (batch_size, num_tokens, token_dim)
        x = self.avgpool(x.transpose(1, 2))
        # (batch_size, token_dim, 1)
        x = torch.flatten(x, 1)
        # (batch_size, token_dim)
        x = self.norms(x)
        return x

    def forward(self, x):
        """Return classification logits (or pooled features when the head is ``nn.Identity``)."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

    @classmethod
    def from_config(cls, config):
        """Build a DaViT from a vision config.

        Reads ``depths``, ``dim_embed``, ``num_heads``, ``num_groups``,
        ``patch_size``, ``patch_stride``, ``patch_padding``, ``patch_prenorm``,
        ``drop_path_rate`` and ``window_size`` from ``config``.
        NOTE(review): ``num_classes`` is not forwarded, so the default
        1000-way linear head is created.
        """
        return cls(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )
# Name of the configuration class, following the transformers `_CONFIG_FOR_DOC`
# convention (not referenced in the visible part of this file).
_CONFIG_FOR_DOC = "FeynModelConfig"
# Class-level documentation prepended to `FeynModel`'s docstring through
# `add_start_docstrings`.
FEYNMODEL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`FeynModelConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation injected into `FeynModel.forward` through
# `add_start_docstrings_to_model_forward`.
# NOTE(review): the "- 1 indicates the head is **not masked**" lines under
# `attention_mask` look like leftovers of a `head_mask` entry; the
# `causal_attention_mask` argument of `FeynModel.forward` is not documented here.
FEYNMODEL_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
#print(f" +++++++++ prepare 4K +++++++++++++++ rec {attention_mask.size()} sequence_length {sequence_length}")
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
#print("+++++++++++++++++ return it")
#causal_mask = attention_mask
# In this case we assume that the mask comes already in inverted form.
causal_mask = attention_mask[:, :, -sequence_length:, :]
#print(f"+++++++++++++++++ truncated causal_mask to last {sequence_length} elements, size: {causal_mask.size()}")
#print(f"+++++++++++++++++ return it causal_mask {causal_mask.size()} !!!!!!!!! attention_mask {attention_mask.size()}")
else:
#print("+++++++++++++++++++++ else +++++++++++++++++")
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
#print(f"++++++++++++++++ causal_mask {causal_mask.size()} ++++++++++++++++++ sequence_length = {sequence_length} ")
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
#print(f"++++++++++++++++++ causal_mask = torch.triu ++++++++++ {causal_mask.size()} ")
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
#print(f"+++++++++++++++++++++ avant if attention_mask is not None:, causal_mask={causal_mask.size()}")
if attention_mask is not None:
#print(" +++++++++++++ attention_mask is None++++++++++++")
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
#print(f"+++++++++++++++++++ 4K returning causal_mask {causal_mask.size()} +++++++++++++++++++")
return causal_mask
class LearnedAbsolutePositionEmbedding2D(nn.Module):
    """
    Learned absolute 2D positional embeddings for grids of up to ``num_pos x num_pos``.

    The embedding of grid cell ``(h, w)`` is the concatenation of two parts:
    - a learned column embedding for ``w`` (``embedding_dim - embedding_dim // 2``
      features), followed by
    - a learned row embedding for ``h`` (``embedding_dim // 2`` features).
    That gives ``embedding_dim`` features in total.
    """
    def __init__(self, embedding_dim=256, num_pos=50):
        super().__init__()
        self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
        self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2))

    def forward(self, pixel_values):
        """
        Args:
            pixel_values: 4D tensor read as ``(batch_size, height, width, num_channels)``.
                Only its shape and device are used. ``height`` and ``width`` must
                not exceed ``num_pos`` (the embedding lookup fails otherwise).

        Returns:
            Tensor of shape ``(batch_size, height, width, embedding_dim)``.

        Raises:
            ValueError: if ``pixel_values`` is not 4D.
        """
        if len(pixel_values.shape) != 4:
            raise ValueError('pixel_values must be a 4D tensor')
        height, width = pixel_values.shape[1:3]
        width_values = torch.arange(width, device=pixel_values.device)
        height_values = torch.arange(height, device=pixel_values.device)
        x_emb = self.column_embeddings(width_values)  # (width, embedding_dim - embedding_dim // 2)
        y_emb = self.row_embeddings(height_values)    # (height, embedding_dim // 2)
        # (height, width, embedding_dim)
        pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
        # (embedding_dim, height, width)
        pos = pos.permute(2, 0, 1)
        pos = pos.unsqueeze(0)
        # (batch_size, embedding_dim, height, width)
        pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
        # (batch_size, height, width, embedding_dim)
        pos = pos.permute(0, 2, 3, 1)
        return pos
class PositionalEmbeddingCosine1D(nn.Module):
    """
    Fixed sinusoidal positional encoding. It closely follows the encoder from
    https://pytorch.org/tutorials/beginner/translation_transformer.html

    Even feature indices ``2k`` hold ``sin(pos * w_k)`` and odd indices ``2k + 1``
    hold ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / embed_dim)``.

    Args:
        embed_dim: The dimension of the embeddings (odd values are supported).
        max_seq_len: The maximum length to precompute the positional encodings.
    """
    def __init__(
            self,
            embed_dim: int = 512,
            max_seq_len: int = 1024) -> None:
        super(PositionalEmbeddingCosine1D, self).__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        # Generate the sinusoidal arrays.
        factor = math.log(10000)
        denominator = torch.exp(
            -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim)
        # Matrix where rows correspond to a positional embedding as a function
        # of the position index (i.e., the row index).
        frequencies = \
            torch.arange(0, self.max_seq_len) \
            .reshape(self.max_seq_len, 1) * denominator
        pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
        # Even columns get sin, odd columns get cos. There are ceil(D/2) frequencies
        # but only floor(D/2) odd columns, so for an odd embed_dim the last
        # frequency has no cos column.
        pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
        pos_idx_to_embed[:, 1::2] = torch.cos(frequencies[:, : self.embed_dim // 2])
        # Save the positional embeddings in a constant buffer.
        self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order; only the shape is used.
                Allowed sizes:
                1. [T, D], where T is the length of the sequence and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.
        Returns the encodings of the first T positions: [T, embed_dim] for a 2D
        input, or [1, T, embed_dim] (broadcastable over B) for a 3D input.
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.max_seq_len
        pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]
        # Adapt pre-computed positional embeddings to the input.
        if shape_len == 3:
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
class LearnedAbsolutePositionEmbedding1D(nn.Module):
    """
    Learnable absolute positional embeddings for 1D sequences.
    Args:
        embedding_dim: The dimension of the embeddings.
        num_pos: The maximum number of positions (sequence length) that can be embedded.
    """
    def __init__(
            self,
            embedding_dim: int = 512,
            num_pos: int = 1024) -> None:
        super(LearnedAbsolutePositionEmbedding1D, self).__init__()
        self.embeddings = nn.Embedding(num_pos, embedding_dim)
        self.num_pos = num_pos

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order; only the shape and device
                are used. Allowed sizes:
                1. [T, D], where T is the length of the sequence and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.
        Returns the embeddings of positions 0..T-1: [T, embedding_dim] for a 2D
        input, or [1, T, embedding_dim] (broadcastable over B) for a 3D input.
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.num_pos
        # [T, embedding_dim]
        pos_embeds = self.embeddings(torch.arange(len_seq, device=seq_embeds.device))
        # Adapt the embeddings to the input rank.
        if shape_len == 3:
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
def create_git_attention_mask(
    tgt: torch.Tensor,
    memory: torch.Tensor,
    max_length: int
) -> torch.Tensor:
    """Build a GIT-style additive attention mask for ``memory`` followed by ``tgt`` tokens.

    Memory rows attend to all memory tokens and to no target token. Target
    rows attend to all memory tokens and causally to target tokens. The key
    axis is right-padded with blocked columns up to ``max_length``.

    Args:
        tgt: target tensor; dims 0 and 1 are read as batch size and target length.
        memory: memory tensor; dim 1 is read as the memory length.
        max_length: key length of the returned mask; must be at least
            ``memory.shape[1] + tgt.shape[1]``.

    Returns:
        Mask of shape ``(batch_size, 1, num_memory + num_tgt, max_length)`` in
        the default float dtype. It holds ``0`` where attention is allowed and
        ``-3.4028e+38`` where it is blocked, and it is created on ``tgt``'s device.
    """
    batch_size = tgt.size(0)
    num_tgt = tgt.shape[1]
    num_memory = memory.shape[1]
    total_length = num_memory + num_tgt
    blocked = float(-3.4028e+38)
    # Build on the caller's device instead of always on the CPU.
    device = tgt.device

    # Memory rows: attention enabled over memory, disabled over targets.
    top_left = torch.zeros((num_memory, num_memory), device=device)
    top_right = torch.full((num_memory, num_tgt), blocked, device=device)
    # Target rows: attention enabled over memory, causal over targets
    # (strict upper triangle blocked, diagonal and below allowed).
    bottom_left = torch.zeros((num_tgt, num_memory), device=device)
    bottom_right = torch.triu(torch.full((num_tgt, num_tgt), blocked, device=device), diagonal=1)

    left = torch.cat((top_left, bottom_left), dim=0)
    right = torch.cat((top_right, bottom_right), dim=0)
    full_attention_mask = torch.cat((left, right), dim=1)

    # Block the padding columns up to max_length.
    padding = torch.full((total_length, max_length - total_length), blocked, device=device)
    full_attention_mask = torch.cat((full_attention_mask, padding), dim=1)

    # (batch_size, 1, total_length, max_length)
    return full_attention_mask[None, None, :, :].expand(batch_size, 1, total_length, max_length)
def get_position_ids_from_binary_attention_mask(mask):
    """
    Build sequential position IDs matching the query length of an attention mask.

    The mask values are not inspected; only its shape and device are used.

    Args:
        mask (torch.Tensor): one of two forms.
            - A 4D mask of shape ``(batch, 1, q_len, k_len)``; ``q_len`` is used.
            - A 2D padding mask of shape ``(batch, seq_len)``; ``seq_len`` is used.

    Returns:
        torch.LongTensor: ``[[0, 1, ..., seq_len - 1]]`` of shape ``(1, seq_len)``,
        on ``mask``'s device.

    Raises:
        ValueError: if ``mask`` is neither 2D nor 4D.
    """
    if mask.dim() == 4:
        seq_len = mask.shape[2]
    elif mask.dim() == 2:
        seq_len = mask.shape[1]
    else:
        raise ValueError(f"Expected a 2D or 4D attention mask, got shape {tuple(mask.shape)}")
    return torch.arange(seq_len, dtype=torch.long, device=mask.device).unsqueeze(0)
def ensure_tensor(variable):
    """Return ``variable`` unchanged if it is a ``torch.Tensor``, else ``torch.tensor(variable)``.

    Conversion errors raised by ``torch.tensor`` propagate to the caller
    unchanged. They are no longer echoed to stdout first, since library code
    should not print and the exception already carries the message.
    """
    if isinstance(variable, torch.Tensor):
        return variable
    return torch.tensor(variable)
@add_start_docstrings(
"The bare Model outputting raw hidden-states without any specific head on top.",
FEYNMODEL_START_DOCSTRING,
)
class FeynModel(Gemma2Model):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`FeynModelDecoderLayer`] + ['LoraLayer'] for the *proj* modules.
NB: LoRA layers are added and activated on the proj modules only if pixel_values is not None.
Args:
config: FeynModelConfig
"""
def __init__(self, config: FeynModelConfig):
    """Build the Gemma2 decoder stack and set the FeynModel-specific attributes.

    Args:
        config: model configuration, forwarded to ``Gemma2Model.__init__``.
    """
    super().__init__(config)
    # Operating-mode flag, starting as 'llm'.
    # NOTE(review): presumably switched by callers when image inputs are used.
    self.mode = 'llm'
    # Fixed count of image patch tokens: 577 == (384 / 16) ** 2 + 1. It replaces an
    # earlier derivation from config.vision_config.image_size / patch_size
    # (optionally times config.num_image_with_embedding).
    self.image_patch_tokens = 577
    # Initialize weights and apply final processing
    self.post_init()
def get_input_embeddings(self):
    """Return the token embedding module stored in ``self.embed_tokens``."""
    token_embedding = self.embed_tokens
    return token_embedding
def set_input_embeddings(self, value):
    """Replace the token embedding module (``self.embed_tokens``) with ``value``."""
    new_embedding = value
    self.embed_tokens = new_embedding
@add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    causal_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """
        causal_attention_mask (`torch.Tensor`, *optional*):
            Precomputed attention mask, converted with `ensure_tensor` and used as-is when `inputs_embeds`
            is passed. If it is `None`, the mask is built by `self._update_causal_mask`, as for `input_ids`.

    When `inputs_embeds` is passed together with `attention_mask`, `position_ids` is replaced by
    `get_position_ids_from_binary_attention_mask(attention_mask)`. A missing `cache_position` defaults to
    `0 .. sequence_length - 1`, and a missing `position_ids` defaults to `cache_position.unsqueeze(0)`.
    Extra `**kwargs` are accepted and ignored.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    embeds_given = inputs_embeds is not None
    if not embeds_given:
        inputs_embeds = self.embed_tokens(input_ids)

    # One cache slot / position per input token. This must be computed from the
    # sequence length: a zeros tensor of length batch_size would place every
    # token at position 0 and give the mask and cache the wrong number of rows.
    if cache_position is None:
        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)

    if embeds_given and causal_attention_mask is not None:
        # Embedding path (e.g. text + image embeddings): the caller supplies the mask.
        causal_mask = ensure_tensor(causal_attention_mask)
    else:
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

    if embeds_given and attention_mask is not None:
        position_ids = get_position_ids_from_binary_attention_mask(attention_mask)
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)
    # Convert position_ids to a tensor if not already
    if not isinstance(position_ids, torch.Tensor):
        position_ids = torch.tensor(position_ids, dtype=torch.long, device=inputs_embeds.device)

    hidden_states = inputs_embeds
    # Scale the embeddings by sqrt(hidden_size). The normalizer is cast to the
    # hidden-state dtype, so e.g. in float16 sqrt(3072)=55.4256 becomes 55.5.
    # See https://github.com/huggingface/transformers/pull/29402
    normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
    hidden_states = hidden_states * normalizer

    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
        hidden_states = layer_outputs[0]
        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)
    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = past_key_values if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """Return the attention mask the decoder layers should use.

    With ``flash_attention_2`` the incoming ``attention_mask`` is handed back
    unchanged. It is passed through in every case, not only when there is
    padding, so that its shape can be used downstream to cut off the trailing
    zero keys/values of a static cache. The original note says this keeps the
    path compile compatible (no dynamic control flow).

    For any other attention implementation a 4-D causal mask is built by
    ``_prepare_4d_causal_attention_mask_with_cache_position``. Its key (target)
    length is chosen as follows:
      * when ``past_key_values`` is a ``HybridCache``: the cache's maximum length;
      * otherwise, when a mask is given: the last dimension of ``attention_mask``;
      * otherwise: the input sequence length.

    Args:
        attention_mask: caller-provided mask, or ``None``.
        input_tensor: input embeddings. Dim 0 gives the batch size and dim 1
            the query length; its dtype/device are used for the mask.
        cache_position: positions of the current tokens, forwarded to the
            mask builder.
        past_key_values: key/value cache, if any.
        output_attentions: not used by this method; kept for signature
            compatibility.

    Returns:
        The mask to hand to the decoder layers.
    """
    # Flash attention: pass the mask straight through (see docstring).
    if self.config._attn_implementation == "flash_attention_2":
        return attention_mask

    batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1]
    dtype = input_tensor.dtype

    if isinstance(past_key_values, HybridCache):
        target_length = past_key_values.get_max_length()
    elif attention_mask is not None:
        target_length = attention_mask.shape[-1]
    else:
        target_length = sequence_length

    # A 2-D `attention_mask` is expanded into a 4-D causal mask here.
    return _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=input_tensor.device,
        min_dtype=torch.finfo(dtype).min,
        cache_position=cache_position,
        batch_size=batch_size,
    )
class FeynModelForCausalLM(Gemma2ForCausalLM):
    """Gemma2 causal language model extended with a DaViT vision tower.

    Text-only calls run ``FeynModel`` as a plain decoder. When ``pixel_values``
    are supplied, the images are processed as follows before the decoder runs:
    they are encoded by ``self.vision_tower``, projected with
    ``self.image_projection``, layer-normed, and prepended to the text
    embeddings.
    """

    _tied_weights_keys = ["lm_head.weight"]
    config_class = FeynModelConfig

    def __init__(self, config):
        super().__init__(config)
        # `vision_config` may still be the raw dict from the serialized config.
        # Convert it only once so that an already-converted config can be reused.
        if isinstance(config.vision_config, dict):
            config.vision_config = Florence2VisionConfig.from_dict(config.vision_config)
        # Replace the decoder created by the parent constructor.
        self.model = FeynModel(config)
        self.vision_tower = DaViT.from_config(config=config.vision_config)
        self._build_image_projection_layers(config)
        # Attention mask of the most recent image+text prompt. `forward` passes it
        # to the decoder on calls made without `pixel_values` (e.g. the decoding
        # steps that `generate` drives through the parent class).
        self.__causal_attention_mask = None
        # Initialize weights and apply final processing
        self.post_init()

    ################ Vision Tower ########################
    def _build_image_projection_layers(self, config):
        """Create the modules that turn vision-tower features into LM tokens.

        Builds:
          * the projection matrix ``(dim_embed[-1], projection_dim)``;
          * its LayerNorm;
          * the learned 2-D position embedding;
          * the cosine 1-D temporal embedding described by ``config.vision_config``.

        Raises:
            NotImplementedError: if the position or temporal embedding type is
                not one of the supported ones.
        """
        image_dim_out = config.vision_config.dim_embed[-1]
        dim_projection = config.vision_config.projection_dim
        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection)
        )
        self.image_proj_norm = nn.LayerNorm(dim_projection)

        image_pos_embed_config = config.vision_config.image_pos_embed
        if image_pos_embed_config['type'] == 'learned_abs_2d':
            self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out,
                num_pos=image_pos_embed_config['max_pos_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

        self.image_feature_source = config.vision_config.image_feature_source

        # temporal embedding
        visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
        if visual_temporal_embedding_config['type'] == 'COSINE':
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim_out,
                max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds):
        """Prepend image tokens to the text embeddings.

        Args:
            image_features: image tokens whose leading dimensions are
                ``(batch_size, num_image_tokens)``.
            inputs_embeds: text embeddings indexed as ``(batch_size, text_len, ...)``,
                or ``None``.

        Returns:
            ``(embeds, attention_mask)``:
              * ``embeds`` is the image tokens followed by the text embeddings
                along dim 1;
              * ``attention_mask`` is an all-ones mask of shape
                ``(batch_size, num_image_tokens + text_len)``.
            When ``inputs_embeds`` is ``None``, only the image tokens and their
            mask are returned.

        Raises:
            ValueError: if the batch sizes differ or dim 2 of the two tensors differs.
        """
        batch_size, image_token_length = image_features.size()[:-1]
        device = image_features.device
        image_attention_mask = torch.ones(batch_size, image_token_length, device=device)

        if inputs_embeds is None:
            return image_features, image_attention_mask

        task_prefix_embeds = inputs_embeds
        # Built directly as a 2-D (batch_size, text_len) mask.
        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)

        # The batch dimension must match.
        if image_features.size(0) != task_prefix_embeds.size(0):
            raise ValueError("Batch sizes of image_features and task_prefix_embeds do not match")

        # Add a trailing singleton dimension to whichever tensor has fewer dims.
        if image_features.dim() < task_prefix_embeds.dim():
            image_features = image_features.unsqueeze(-1)
        elif task_prefix_embeds.dim() < image_features.dim():
            task_prefix_embeds = task_prefix_embeds.unsqueeze(-1)

        # The feature dimension (dim 2) must agree before concatenating on dim 1.
        if image_features.size(2) != task_prefix_embeds.size(2):
            raise ValueError("Internal dimensions of image_features and task_prefix_embeds do not match")

        inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
        attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
        return inputs_embeds, attention_mask

    def _encode_image(self, pixel_values):
        """Encode images with the vision tower and project them into LM space.

        Args:
            pixel_values: a batch ``(batch_size, C, H, W)``, or a single image
                ``(C, H, W)`` which is treated as a batch of one.

        Returns:
            Image tokens of shape ``(batch_size, num_tokens, projection_dim)``.
            They are the feature sources listed in ``self.image_feature_source``,
            concatenated along dim 1, projected by ``self.image_projection`` and
            layer-normed.

        Raises:
            ValueError: if ``pixel_values`` is neither 3-D nor 4-D, or if a
                listed feature source is unknown.
        """
        if pixel_values.dim() == 3:
            # Single image: add the batch dimension.
            pixel_values = pixel_values.unsqueeze(0)
        elif pixel_values.dim() != 4:
            raise ValueError(
                "pixel_values must have shape (C, H, W) or (batch_size, C, H, W), "
                f"got {tuple(pixel_values.shape)}"
            )
        batch_size = pixel_values.shape[0]
        T = 1  # one frame per sample
        x = self.vision_tower.forward_features_unpool(pixel_values)

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
            assert h * w == num_tokens, 'only support square feature maps for now'
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        # Candidate feature sources; `image_feature_source` selects which are used.
        x_feat_dict = {}
        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict['last_frame'] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])
        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)
        return x

    def _prepare_multimodal_inputs(self, input_ids, pixel_values, max_length):
        """Build merged image+text embeddings and the GIT-style attention mask.

        The method also has side effects:
          * it switches the decoder to ``'vlm'`` mode;
          * it caches the mask on the instance, so later ``forward`` calls
            without ``pixel_values`` (e.g. generation steps) reuse it.

        Args:
            input_ids: text token ids. Required, because the mask is built from them.
            pixel_values: images accepted by ``_encode_image``.
            max_length: forwarded to ``create_git_attention_mask``.

        Returns:
            ``(inputs_embeds, causal_attention_mask)``.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("`input_ids` must be provided when `pixel_values` is given.")
        self.model.mode = 'vlm'
        text_embeds = self.get_input_embeddings()(input_ids)
        image_features = self._encode_image(pixel_values)
        # The all-ones mask returned by the merge is replaced by the GIT mask below.
        inputs_embeds, _ = self._merge_input_ids_with_image_features(image_features, text_embeds)
        causal_attention_mask = create_git_attention_mask(
            tgt=input_ids, memory=image_features, max_length=max_length
        ).to(input_ids.device)
        self.__causal_attention_mask = causal_attention_mask
        return inputs_embeds, causal_attention_mask

    #######################################################
    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
                Labels are aligned with the *last* `labels.shape[1]` positions of the model input. Any leading
                positions without labels, such as the image tokens prepended when `pixel_values` are given, are
                skipped.
            pixel_values (`torch.Tensor`, *optional*):
                Images of shape `(batch_size, num_channels, height, width)` or `(num_channels, height, width)`.
                They are encoded by the vision tower and prepended to the text embeddings. Requires `input_ids`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained("<path-to-checkpoint>", trust_remote_code=True)
        >>> tokenizer = AutoTokenizer.from_pretrained("<path-to-checkpoint>")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""
        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train FeynModel models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is not None:
            # Image tokens are prepended to the text embeddings. The decoder then
            # receives embeddings only, and the GIT-style mask replaces the
            # caller's `attention_mask`.
            inputs_embeds, causal_attention_mask = self._prepare_multimodal_inputs(
                input_ids, pixel_values, max_length=2048
            )
            decoder_input_ids = None
            decoder_attention_mask = causal_attention_mask
        else:
            # Reuse the mask of the last multimodal prompt (None if there was none).
            decoder_input_ids = input_ids
            decoder_attention_mask = attention_mask
            causal_attention_mask = self.__causal_attention_mask

        outputs = self.model(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            causal_attention_mask=causal_attention_mask,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping
        logits = logits.float()

        loss = None
        if labels is not None:
            # Next-token prediction. Logits cover every input position (image tokens
            # first, when present), while labels cover only the trailing positions.
            # Drop the unlabeled prefix, then shift by one. For text-only input the
            # prefix is empty.
            num_prefix_tokens = logits.size(1) - labels.size(1)
            if num_prefix_tokens < 0:
                raise ValueError(
                    f"labels length ({labels.size(1)}) exceeds the model input length ({logits.size(1)})"
                )
            shift_logits = logits[:, num_prefix_tokens:-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Select the inputs for one decoding step.

        The method does the following:
          * keeps only the not-yet-processed tokens;
          * derives ``position_ids`` from ``attention_mask`` when needed;
          * uses ``inputs_embeds`` only on the first step (``cache_position[0] == 0``);
          * with a ``HybridCache`` and a 2-D mask (non-flash attention), expands the
            mask to 4-D, sized from the tensor actually fed to the model this step.

        Returns:
            dict of keyword arguments for ``forward``.
        """
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
                # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask is not None
            and attention_mask.ndim == 2
            and self.config._attn_implementation != "flash_attention_2"
        ):
            # Size the mask from what is really fed this step. `inputs_embeds` stays
            # in the kwargs on later steps, but only the new `input_ids` are used there.
            if "inputs_embeds" in model_inputs:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device
            dtype = self.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        pixel_values=None,
        max_length=None,
        do_sample=True,
        temperature=0.7,
        **kwargs
    ):
        """Generate text, optionally conditioned on images.

        With ``pixel_values``, the merged image+text embeddings are passed to the
        parent ``generate`` as ``inputs_embeds``, and the decoder is put in
        ``'vlm'`` mode. Otherwise the decoder is put in ``'llm'`` mode and
        ``input_ids`` are passed through.

        The cached multimodal mask is always cleared afterwards, even if
        generation raises.

        Raises:
            ValueError: if ``pixel_values`` is given without ``input_ids``.
        """
        try:
            if pixel_values is not None:
                inputs_embeds, _ = self._prepare_multimodal_inputs(
                    input_ids, pixel_values, max_length=max_length
                )
                return super().generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    max_length=max_length,
                    do_sample=do_sample,
                    temperature=temperature,
                    **kwargs
                )
            self.model.mode = 'llm'
            return super().generate(
                input_ids=input_ids,
                max_length=max_length,
                do_sample=do_sample,
                temperature=temperature,
                **kwargs
            )
        finally:
            # Do not let this prompt's mask leak into later, unrelated calls.
            self.__causal_attention_mask = None