|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
ModelOutput, |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_flash_attn_2_available, |
|
logging, |
|
replace_return_docstrings, |
|
is_flash_attn_2_available, |
|
is_flash_attn_greater_or_equal_2_10, |
|
) |
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_attn_mask_utils import ( |
|
_prepare_4d_attention_mask, |
|
_prepare_4d_attention_mask_for_sdpa, |
|
_prepare_4d_causal_attention_mask, |
|
_prepare_4d_causal_attention_mask_for_sdpa, |
|
) |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutput, |
|
BaseModelOutputWithPastAndCrossAttentions, |
|
Seq2SeqLMOutput, |
|
Seq2SeqModelOutput, |
|
) |
|
|
|
from transformers.cache_utils import Cache, HybridCache |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
SequenceClassifierOutputWithPast, |
|
TokenClassifierOutput, |
|
) |
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2ForCausalLM,Gemma2DecoderLayer,Gemma2RMSNorm |
|
from .configuration_feynmodel import FeynModelConfig,Florence2VisionConfig |
|
|
|
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM |
|
import json |
|
import math |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
import logging |
|
|
|
from transformers.utils import ( |
|
ModelOutput, |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_flash_attn_2_available, |
|
logging, |
|
replace_return_docstrings, |
|
is_flash_attn_2_available, |
|
is_flash_attn_greater_or_equal_2_10, |
|
) |
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
from collections import OrderedDict |
|
from einops import rearrange |
|
from timm.models.layers import DropPath, trunc_normal_ |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
class MySequential(nn.Sequential): |
|
def forward(self, *inputs): |
|
for module in self._modules.values(): |
|
if type(inputs) == tuple: |
|
inputs = module(*inputs) |
|
else: |
|
inputs = module(inputs) |
|
return inputs |
|
|
|
|
|
class PreNorm(nn.Module):
    """Residual wrapper computing ``x + drop_path(fn(norm(x)))``.

    Args:
        norm: Module applied to the input before ``fn``; ``None`` skips it.
        fn: Wrapped module, called as ``fn(x, *args, **kwargs)``. It must
            return an ``(output, size)`` pair.
        drop_path: Optional module applied to the output of ``fn`` before the
            residual addition; ``None`` skips it.
    """

    def __init__(self, norm, fn, drop_path=None):
        super().__init__()
        self.norm = norm
        self.fn = fn
        self.drop_path = drop_path

    def forward(self, x, *args, **kwargs):
        """Return ``(x + branch(x), size)`` where ``size`` comes from ``fn``."""
        shortcut = x
        # Identity checks instead of `!= None` / truthiness: the truth value of
        # a module can depend on `__len__` (e.g. container modules).
        if self.norm is not None:
            x = self.norm(x)
        x, size = self.fn(x, *args, **kwargs)

        if self.drop_path is not None:
            x = self.drop_path(x)

        return shortcut + x, size
|
|
|
|
|
class Mlp(nn.Module):
    """Two-layer perceptron (``fc1`` -> activation -> ``fc2``).

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden layer; ``in_features`` is used
            when this is falsy.
        out_features: Size of the last output dimension; ``in_features`` is
            used when this is falsy.
        act_layer: Zero-argument factory for the activation module.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        # Sub-module names are part of the state-dict keys; keep them stable.
        layers = OrderedDict()
        layers["fc1"] = nn.Linear(in_features, hidden_features)
        layers["act"] = act_layer()
        layers["fc2"] = nn.Linear(hidden_features, out_features)
        self.net = nn.Sequential(layers)

    def forward(self, x, size):
        """Return ``(net(x), size)``; ``size`` is passed through untouched."""
        hidden = self.net(x)
        return hidden, size
|
|
|
|
|
class DepthWiseConv2d(nn.Module):
    """Depthwise 2-D convolution over a flattened ``(B, H*W, C)`` token sequence.

    Args:
        dim_in: Number of channels; also used as ``groups`` so that each
            channel is convolved with its own filter.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
        bias: Whether the convolution has a bias term.
    """

    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias,
        )
        self.dw = nn.Conv2d(dim_in, dim_in, **conv_kwargs)

    def forward(self, x, size):
        """Convolve ``x`` laid out on the ``size = (H, W)`` grid.

        Returns:
            The re-flattened ``(B, H'*W', C)`` tokens and the spatial size
            ``(H', W')`` after the convolution.
        """
        batch, num_tokens, channels = x.shape
        height, width = size
        assert num_tokens == height * width

        # (B, N, C) -> (B, C, H, W) for the convolution.
        feature_map = x.transpose(1, 2).view(batch, channels, height, width)
        feature_map = self.dw(feature_map)
        new_size = (feature_map.size(-2), feature_map.size(-1))
        # Back to token layout: (B, C, H', W') -> (B, H'*W', C).
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, new_size
|
|
|
|
|
class ConvEmbed(nn.Module):
    """Image to Patch Embedding.

    Projects the input with a strided convolution and returns the result as a
    ``(B, H'*W', embed_dim)`` token sequence together with ``(H', W')``.

    Args:
        patch_size: Kernel size of the projection convolution.
        in_chans: Number of input channels.
        embed_dim: Number of output channels.
        stride: Stride of the projection convolution.
        padding: Padding of the projection convolution.
        norm_layer: Optional factory ``norm_layer(dim)``; no norm when falsy.
        pre_norm: If True the norm is sized for ``in_chans`` and applied
            before the projection (token-sequence input only); otherwise it
            is sized for ``embed_dim`` and applied after the projection.
    """

    def __init__(
        self,
        patch_size=7,
        in_chans=3,
        embed_dim=64,
        stride=4,
        padding=2,
        norm_layer=None,
        pre_norm=True,
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding,
        )

        if norm_layer:
            # Pre-norm sees the input channels, post-norm the projected ones.
            self.norm = norm_layer(in_chans if pre_norm else embed_dim)
        else:
            self.norm = None

        self.pre_norm = pre_norm

    def forward(self, x, size):
        """Embed ``x``, given either as ``(B, H*W, C)`` tokens or a 4-D map.

        ``size`` is the ``(H, W)`` grid used to unflatten token input.
        """
        H, W = size
        is_token_sequence = len(x.size()) == 3
        if is_token_sequence:
            if self.norm and self.pre_norm:
                x = self.norm(x)
            x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        x = self.proj(x)

        _, _, out_h, out_w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm and not self.pre_norm:
            x = self.norm(x)

        return x, (out_h, out_w)
|
|
|
|
|
class ChannelAttention(nn.Module):
    """Grouped attention computed across channels instead of tokens.

    The channels are split into ``groups``; inside every group a
    ``(C/groups) x (C/groups)`` attention matrix is built from the queries
    (scaled by ``N ** -0.5``, N being the token count) and the keys, and is
    then applied to the values along the channel axis.

    Args:
        dim: Number of channels ``C``.
        groups: Number of channel groups.
        qkv_bias: Whether the fused q/k/v projection has a bias.
    """

    def __init__(self, dim, groups=8, qkv_bias=True):
        super().__init__()

        self.groups = groups
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        """Apply channel attention to ``(B, N, C)`` tokens; ``size`` is passed through."""
        batch, num_tokens, channels = x.shape
        group_dim = channels // self.groups

        fused = self.qkv(x).reshape(batch, num_tokens, 3, self.groups, group_dim)
        # Each of query/key/value: (B, groups, N, C/groups).
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        query = query * (float(num_tokens) ** -0.5)
        scores = torch.matmul(query.transpose(-1, -2), key)
        weights = scores.softmax(dim=-1)
        mixed = torch.matmul(weights, value.transpose(-1, -2)).transpose(-1, -2)
        mixed = mixed.transpose(1, 2).reshape(batch, num_tokens, channels)
        return self.proj(mixed), size
|
|
|
|
|
class ChannelBlock(nn.Module):
    """Channel-attention block: optional conv -> channel attention -> optional conv -> MLP.

    Every sub-layer is wrapped in `PreNorm`, i.e. applied as a residual branch.

    Args:
        dim: Channel dimension of the tokens.
        groups: Number of channel groups for `ChannelAttention`.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the q/k/v projection has a bias.
        drop_path_rate: `DropPath` rate; ``nn.Identity`` is used when it is 0.
        act_layer: Activation factory for the MLP.
        norm_layer: Norm factory used in front of attention and MLP.
        conv_at_attn: Add a depthwise 3x3 conv branch before the attention.
        conv_at_ffn: Add a depthwise 3x3 conv branch before the MLP.
    """

    def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
                 drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        if drop_path_rate > 0.:
            drop_path = DropPath(drop_path_rate)
        else:
            drop_path = nn.Identity()

        def conv_residual():
            # Depthwise 3x3 conv (padding 1, stride 1) without norm or drop-path.
            return PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))

        if conv_at_attn:
            self.conv1 = conv_residual()
        else:
            self.conv1 = None
        attention = PreNorm(
            norm_layer(dim),
            ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
            drop_path,
        )
        self.channel_attn = attention

        if conv_at_ffn:
            self.conv2 = conv_residual()
        else:
            self.conv2 = None
        hidden_dim = int(dim*mlp_ratio)
        # The same drop-path module instance is shared with the attention branch.
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer),
            drop_path,
        )

    def forward(self, x, size):
        """Run the block on ``(B, N, C)`` tokens laid out on the ``size`` grid."""
        if self.conv1 is not None:
            x, size = self.conv1(x, size)
        x, size = self.channel_attn(x, size)

        if self.conv2 is not None:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)

        return x, size
|
|
|
|
|
def window_partition(x, window_size: int):
    """Cut a ``(B, H, W, C)`` map into non-overlapping square windows.

    ``H`` and ``W`` are expected to be multiples of ``window_size``.

    Returns:
        Tensor of shape ``(B * H/ws * W/ws, ws, ws, C)``; windows are ordered
        by batch, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Put the two window-grid axes ahead of the two intra-window axes.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
|
|
|
|
|
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    """Reassemble windows into a ``(batch_size, H, W, C)`` map.

    Args:
        windows: Tensor of ``window_size x window_size`` windows ordered by
            batch, window row and window column.
        batch_size: Number of maps the windows belong to.
        window_size: Side length of each window.
        H: Height of the reassembled map.
        W: Width of the reassembled map.
    """
    rows = H // window_size
    cols = W // window_size
    grid = windows.view(batch_size, rows, cols, window_size, window_size, -1)
    # Interleave window-grid and intra-window axes back into (H, W).
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch_size, H, W, -1)
|
|
|
|
|
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping square windows.

    Args:
        dim: Channel dimension of the tokens.
        num_heads: Number of attention heads.
        window_size: Side length of each attention window.
        qkv_bias: Whether the fused q/k/v projection has a bias.
    """

    def __init__(self, dim, num_heads, window_size, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = float(dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        """Attend within windows of ``(B, H*W, C)`` tokens; ``size`` is ``(H, W)``.

        The map is zero-padded on the right/bottom up to a multiple of the
        window size and cropped back afterwards.
        """
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        ws = self.window_size

        x = x.view(B, H, W, C)

        # F.pad pads from the last dim backwards: (C, C, W_left, W_right, H_top, H_bottom).
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        windows = window_partition(x, ws)
        windows = windows.view(-1, ws * ws, C)

        num_windows, tokens_per_window, C = windows.shape
        head_dim = C // self.num_heads
        fused = self.qkv(windows).reshape(num_windows, tokens_per_window, 3, self.num_heads, head_dim)
        # Each of q/k/v: (num_windows, num_heads, tokens_per_window, head_dim).
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        q = q * self.scale
        attn = self.softmax(torch.matmul(q, k.transpose(-2, -1)))

        out = torch.matmul(attn, v).transpose(1, 2).reshape(num_windows, tokens_per_window, C)
        out = self.proj(out)

        # Stitch the windows back into the padded map, then drop the padding.
        out = out.view(-1, ws, ws, C)
        out = window_reverse(out, B, ws, Hp, Wp)

        if pad_r > 0 or pad_b > 0:
            out = out[:, :H, :W, :].contiguous()

        return out.view(B, H * W, C), size
|
|
|
|
|
class SpatialBlock(nn.Module):
    """Spatial block: optional conv -> window attention -> optional conv -> MLP.

    Every sub-layer is wrapped in `PreNorm`, i.e. applied as a residual branch.

    Args:
        dim: Channel dimension of the tokens.
        num_heads: Number of heads for `WindowAttention`.
        window_size: Side length of the attention windows.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the q/k/v projection has a bias.
        drop_path_rate: `DropPath` rate; ``nn.Identity`` is used when it is 0.
        act_layer: Activation factory for the MLP.
        norm_layer: Norm factory used in front of attention and MLP.
        conv_at_attn: Add a depthwise 3x3 conv branch before the attention.
        conv_at_ffn: Add a depthwise 3x3 conv branch before the MLP.
    """

    def __init__(self, dim, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        if drop_path_rate > 0.:
            drop_path = DropPath(drop_path_rate)
        else:
            drop_path = nn.Identity()

        def conv_residual():
            # Depthwise 3x3 conv (padding 1, stride 1) without norm or drop-path.
            return PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))

        if conv_at_attn:
            self.conv1 = conv_residual()
        else:
            self.conv1 = None
        attention = PreNorm(
            norm_layer(dim),
            WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
            drop_path,
        )
        self.window_attn = attention

        if conv_at_ffn:
            self.conv2 = conv_residual()
        else:
            self.conv2 = None
        hidden_dim = int(dim*mlp_ratio)
        # The same drop-path module instance is shared with the attention branch.
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer),
            drop_path,
        )

    def forward(self, x, size):
        """Run the block on ``(B, N, C)`` tokens laid out on the ``size`` grid."""
        if self.conv1 is not None:
            x, size = self.conv1(x, size)
        x, size = self.window_attn(x, size)

        if self.conv2 is not None:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)
        return x, size
|
|
|
|
|
class DaViT(nn.Module):
    """ DaViT: Dual-Attention Transformer

    Every stage is a `ConvEmbed` patch embedding followed by ``depths[i]``
    pairs of (`SpatialBlock`, `ChannelBlock`).

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head; a non-positive value
            replaces the head with ``nn.Identity``. Default: 1000.
        depths (tuple(int)): Number of spatial/channel block pairs in each stage. Default: (1, 1, 3, 1).
        patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
        patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
        patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
        patch_prenorm (tuple(bool)): If True, perform norm before convolution layer.
            Default: (False, False, False, False).
        embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
        num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (3, 6, 12, 24).
        num_groups (tuple(int)): Number of channel groups in different stages. Default: (3, 6, 12, 24).
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
        drop_path_rate (float): Maximum stochastic depth rate; per-block rates rise linearly
            from 0 to this value. Default: 0.1.
        norm_layer (nn.Module): Normalization layer used by the patch embeddings and the final
            norm (the attention blocks keep their own default). Default: nn.LayerNorm.
        enable_checkpoint (bool): If True, enable checkpointing. Default: False.
        conv_at_attn (bool): If True, perform depthwise convolution before attention layer. Default: True.
        conv_at_ffn (bool): If True, perform depthwise convolution before ffn layer. Default: True.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        # One drop-path rate per block; every depth unit holds two blocks
        # (spatial + channel), hence the factor 2.
        dpr = torch.linspace(0, drop_path_rate, sum(depths) * 2).tolist()

        depth_offset = 0
        convs = []
        blocks = []
        for i in range(self.num_stages):
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i],
            )
            convs.append(conv_embed)

            stage_blocks = []
            for j in range(depths[i]):
                spatial_block = SpatialBlock(
                    embed_dims[i],
                    num_heads[i],
                    window_size,
                    drop_path_rate=dpr[depth_offset + j * 2],
                    qkv_bias=qkv_bias,
                    mlp_ratio=mlp_ratio,
                    conv_at_attn=conv_at_attn,
                    conv_at_ffn=conv_at_ffn,
                )
                channel_block = ChannelBlock(
                    embed_dims[i],
                    num_groups[i],
                    drop_path_rate=dpr[depth_offset + j * 2 + 1],
                    qkv_bias=qkv_bias,
                    mlp_ratio=mlp_ratio,
                    conv_at_attn=conv_at_attn,
                    conv_at_ffn=conv_at_ffn,
                )
                stage_blocks.append(MySequential(OrderedDict([
                    ('spatial_block', spatial_block),
                    ('channel_block', channel_block),
                ])))
            blocks.append(MySequential(*stage_blocks))
            depth_offset += depths[i] * 2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)

        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    @property
    def dim_out(self):
        """Channel dimension of the last stage."""
        return self.embed_dims[-1]

    def _init_weights(self, m):
        """Initialise one sub-module; intended to be used with ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward_features_unpool(self, x):
        """
        forward until avg pooling
        Args:
            x (_type_): input image tensor
        Returns:
            Token sequence produced by the last stage.
        """
        if self.enable_checkpoint:
            # Imported locally: the file's top-level imports never bring
            # `torch.utils.checkpoint` into scope, so the bare `checkpoint`
            # name used here before raised NameError.
            from torch.utils.checkpoint import checkpoint as run_checkpointed

        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            if self.enable_checkpoint:
                # `use_reentrant` is passed explicitly, as PyTorch asks callers to do.
                x, input_size = run_checkpointed(block, x, input_size, use_reentrant=False)
            else:
                x, input_size = block(x, input_size)
        return x

    def forward_features(self, x):
        """Return the token-averaged, normalised features of shape ``(B, embed_dims[-1])``."""
        x = self.forward_features_unpool(x)

        # (B, N, C) -> (B, C, N) -> average over tokens -> (B, C)
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        x = self.norms(x)

        return x

    def forward(self, x):
        """Apply the backbone followed by the classification head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

    @classmethod
    def from_config(cls, config):
        """Build a model from the fields of ``config`` listed below; every other
        constructor argument keeps its default."""
        return cls(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )
|
|
|
|
|
|
|
|
|
# Name of the configuration class; not referenced elsewhere in the visible part of this file.
_CONFIG_FOR_DOC = "FeynModelConfig"

# Text passed to `add_start_docstrings` on the model class.
FEYNMODEL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`FeynModelConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Text passed to `add_start_docstrings_to_model_forward` on `forward`.
# NOTE(review): the `attention_mask` entry ends with two "head is masked" bullets that look like
# leftovers from a `head_mask` description - confirm before relying on the rendered docs.
FEYNMODEL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
|
|
|
|
|
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """Build an additive causal mask of shape ``(batch_size, 1, sequence_length, target_length)``.

    Args:
        attention_mask: ``None``, a 2-D padding mask, or an already 4-D mask.
            A 4-D mask is only cropped to its last ``sequence_length`` query
            rows and returned.
        sequence_length: Number of query positions.
        target_length: Number of key positions.
        dtype: Accepted for interface compatibility; the mask is always
            created as ``torch.float32``.
        device: Device on which the mask is created.
        min_dtype: Value written at masked positions.
        cache_position: Position index of each query token; key positions
            greater than it stay masked.
        batch_size: Size of the expanded batch dimension.

    Returns:
        A mask holding ``min_dtype`` at masked positions and 0 elsewhere.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask[:, :, -sequence_length:, :]

    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=torch.float32, device=device)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)

    # Zero out every key position that is not strictly ahead of the query position.
    key_positions = torch.arange(target_length, device=device)
    mask *= key_positions > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return mask

    # `expand` returns a view; clone before writing into it.
    mask = mask.clone()
    mask_length = attention_mask.shape[-1]
    covered = mask[:, :, :, :mask_length]
    # A zero sum marks positions that are causally visible (0) but padded (0).
    is_padding = (covered + attention_mask[:, None, None, :]) == 0
    mask[:, :, :, :mask_length] = covered.masked_fill(is_padding, min_dtype)
    return mask
|
|
|
class LearnedAbsolutePositionEmbedding2D(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Rows and columns get separate embedding tables whose widths add up to
    ``embedding_dim``; ``num_pos`` is the number of entries in each table.
    """

    def __init__(self, embedding_dim=256, num_pos=50):
        super().__init__()
        row_dim = embedding_dim // 2
        self.row_embeddings = nn.Embedding(num_pos, row_dim)
        self.column_embeddings = nn.Embedding(num_pos, embedding_dim - row_dim)

    def forward(self, pixel_values):
        """
        pixel_values: (batch_size, height, width, num_channels)
        returns: (batch_size, height, width, embedding_dim); along the last
            axis the column embedding comes first, then the row embedding.
        """
        if pixel_values.dim() != 4:
            raise ValueError('pixel_values must be a 4D tensor')
        batch_size, height, width = pixel_values.shape[:3]
        device = pixel_values.device

        col_emb = self.column_embeddings(torch.arange(width, device=device))
        row_emb = self.row_embeddings(torch.arange(height, device=device))

        # (height, width, embedding_dim): tile columns over rows and rows over columns.
        grid = torch.cat(
            [
                col_emb.unsqueeze(0).repeat(height, 1, 1),
                row_emb.unsqueeze(1).repeat(1, width, 1),
            ],
            dim=-1,
        )

        # Channels-first, tile over the batch, then back to channels-last.
        grid = grid.permute(2, 0, 1).unsqueeze(0)
        grid = grid.repeat(batch_size, 1, 1, 1)
        return grid.permute(0, 2, 3, 1)
|
|
|
class PositionalEmbeddingCosine1D(nn.Module):
    """
    This class implements a very simple positional encoding. It follows closely
    the encoder from the link below:
    https://pytorch.org/tutorials/beginner/translation_transformer.html
    Args:
        embed_dim: The dimension of the embeddings (even or odd).
        max_seq_len: The maximum length to precompute the positional encodings.
    """
    def __init__(
            self,
            embed_dim: int = 512,
            max_seq_len: int = 1024) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        factor = math.log(10000)
        denominator = torch.exp(
            -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim)

        # (max_seq_len, ceil(embed_dim / 2)): position index times frequency.
        frequencies = \
            torch.arange(0, self.max_seq_len) \
            .reshape(self.max_seq_len, 1) * denominator
        pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))

        # Even columns hold sin, odd columns hold cos. With an odd embed_dim
        # there is one cos column fewer than sin columns, so the cos table is
        # trimmed; for an even embed_dim the slice is a no-op.
        pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
        pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)[:, :self.embed_dim // 2]

        # Buffer (not a parameter): follows the module across devices, never trained.
        self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.
        Returns the positional encodings for the first T positions, shaped
        [T, D] for 2-D input and [1, T, D] for 3-D input.
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.max_seq_len
        pos_embeds = self.pos_idx_to_embed[0:len_seq, :]

        if shape_len == 3:
            # Leading singleton dimension for the batch.
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
|
|
|
|
|
class LearnedAbsolutePositionEmbedding1D(nn.Module):
    """
    Learnable absolute positional embeddings for 1D sequences.
    Args:
        embedding_dim: The dimension of the embeddings.
        num_pos: The number of learned positions, i.e. the maximum sequence length.
    """
    def __init__(
            self,
            embedding_dim: int = 512,
            num_pos: int = 1024) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(num_pos, embedding_dim)
        self.num_pos = num_pos

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.
        Returns the embeddings of the first T positions, shaped [T, D] for
        2-D input and [1, T, D] for 3-D input.
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.num_pos

        # Build the indices directly on the input's device rather than on the
        # CPU followed by a copy.
        positions = torch.arange(len_seq, device=seq_embeds.device)
        pos_embeds = self.embeddings(positions)

        if shape_len == 3:
            # Leading singleton dimension for the batch.
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
|
|
|
def create_git_attention_mask(
    tgt: torch.Tensor,
    memory: torch.Tensor,
    max_length: int
) -> torch.Tensor:
    """Build an additive attention mask over ``memory`` followed by ``tgt`` tokens.

    Rows are queries and columns are keys, memory tokens first. Memory rows
    see all memory tokens and no target tokens; target rows see all memory
    tokens and the target tokens up to and including their own position.
    Columns beyond ``num_memory + num_tgt`` (up to ``max_length``) are masked.

    Args:
        tgt: Tensor whose dim 0 is the batch size and dim 1 the target length.
        memory: Tensor whose dim 1 is the memory length.
        max_length: Total number of key columns of the returned mask.

    Returns:
        Tensor of shape ``(batch_size, 1, num_memory + num_tgt, max_length)``
        with 0 at visible and -3.4028e+38 at masked positions. It is created
        with torch's default dtype and device, not those of the inputs.
    """
    masked = float(-3.4028e+38)

    batch_size = tgt.size(0)
    num_tgt = tgt.shape[1]
    num_memory = memory.shape[1]
    total_length = num_memory + num_tgt

    # Memory queries: full view of memory, none of the target tokens.
    memory_rows = torch.cat(
        (torch.zeros((num_memory, num_memory)), torch.full((num_memory, num_tgt), masked)),
        dim=1,
    )

    # Target queries: full view of memory, causal view of the target tokens
    # (only the strictly upper triangle stays masked).
    causal = torch.triu(torch.full((num_tgt, num_tgt), masked), diagonal=1)
    tgt_rows = torch.cat((torch.zeros((num_tgt, num_memory)), causal), dim=1)

    mask = torch.cat((memory_rows, tgt_rows), dim=0)

    # Mask out the key columns between the real length and max_length.
    filler = torch.full((total_length, max_length - total_length), masked)
    mask = torch.cat((mask, filler), dim=1)

    # Add batch and head dimensions, then broadcast over the batch.
    mask = mask[None, None, :, :]
    return mask.expand(batch_size, 1, mask.size(-2), mask.size(-1))
|
|
|
def get_position_ids_from_binary_attention_mask(mask):
    """
    Build sequential position IDs matching the query length of an attention mask.

    Only the shape and device of ``mask`` are used; its values are not read.

    Args:
        mask (torch.Tensor): 4-D attention mask whose third dimension is the sequence length.

    Returns:
        torch.Tensor: ``torch.long`` tensor of shape ``(1, seq_len)`` holding
        ``0 .. seq_len - 1``, on the same device as ``mask``.
    """
    # Unpacking enforces that the mask has exactly four dimensions.
    _, _, seq_len, _ = mask.shape
    positions = torch.arange(seq_len, dtype=torch.long, device=mask.device)
    return positions[None, :]
|
|
|
def ensure_tensor(variable):
    """Return ``variable`` as a ``torch.Tensor``.

    A tensor is returned unchanged (same object); anything else is converted
    with ``torch.tensor``. A failed conversion is printed and re-raised.
    """
    if isinstance(variable, torch.Tensor):
        return variable

    try:
        return torch.tensor(variable)
    except Exception as e:
        print(f"Error converting to tensor: {e}")
        raise
|
|
|
@add_start_docstrings( |
|
"The bare Model outputting raw hidden-states without any specific head on top.", |
|
FEYNMODEL_START_DOCSTRING, |
|
) |
|
class FeynModel(Gemma2Model): |
|
""" |
|
Transformer decoder consisting of *config.num_hidden_layers* layers. |
|
Each layer is a [`FeynModelDecoderLayer`] + [`LoraLayer`] for the *proj* modules.
NB: LoraLayers are added and activated on the proj modules only if pixel_values is not None.
|
|
|
Args: |
|
config: FeynModelConfig |
|
""" |
|
|
|
def __init__(self, config: FeynModelConfig):
    """Initialise the parent model from ``config`` and set the FeynModel-specific attributes."""
    super().__init__(config)

    # Initial mode. NOTE(review): the other values `mode` can take are not visible in this part of the file.
    self.mode='llm'
    # The string literal below is a no-op statement holding the disabled, config-derived
    # computation of `image_patch_tokens`.
    '''
    self.image_patch_tokens = int(
    (config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1
    )

    if config.num_image_with_embedding is not None:
    self.image_patch_tokens *= config.num_image_with_embedding
    '''
    # Hard-coded replacement for the disabled computation above.
    self.image_patch_tokens = 577
    self.post_init()
|
|
|
def get_input_embeddings(self):
    """Return the module stored in ``self.embed_tokens``."""
    return self.embed_tokens
|
|
|
def set_input_embeddings(self, value):
    """Replace ``self.embed_tokens`` with ``value``."""
    self.embed_tokens = value
|
|
|
|
|
|
|
|
|
@add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING) |
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
causal_attention_mask: Optional[torch.Tensor] = None, |
|
**kwargs, |
|
) -> Union[Tuple, BaseModelOutputWithPast]: |
|
|
|
|
|
|
|
|
|
|
|
if cache_position is None: |
|
batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0) |
|
cache_position = torch.zeros((batch_size,), dtype=torch.long, device=input_ids.device if input_ids is not None else inputs_embeds.device) |
|
|
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if (input_ids is None) ^ (inputs_embeds is not None): |
|
raise ValueError( |
|
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" |
|
) |
|
|
|
if self.gradient_checkpointing and self.training and use_cache: |
|
logger.warning_once( |
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." |
|
) |
|
use_cache = False |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.embed_tokens(input_ids) |
|
causal_mask = self._update_causal_mask( |
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
|
) |
|
else: |
|
causal_mask = ensure_tensor(causal_attention_mask) |
|
position_ids = get_position_ids_from_binary_attention_mask(attention_mask) |
|
|
|
|
|
|
|
if cache_position is None: |
|
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) |
|
|
|
if position_ids is None : |
|
position_ids = cache_position.unsqueeze(0) |
|
|
|
|
|
|
|
|
|
if not isinstance(position_ids, torch.Tensor): |
|
|
|
position_ids = torch.tensor(position_ids, dtype=torch.long, device=inputs_embeds.device) |
|
|
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
|
|
|
|
|
|
|
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) |
|
hidden_states = hidden_states * normalizer |
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
|
|
for decoder_layer in self.layers: |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
decoder_layer.__call__, |
|
hidden_states, |
|
causal_mask, |
|
position_ids, |
|
past_key_values, |
|
output_attentions, |
|
use_cache, |
|
cache_position, |
|
) |
|
else: |
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
attention_mask=causal_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
next_cache = past_key_values if use_cache else None |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) |
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
|
|
|
|
|
|
def _update_causal_mask( |
|
self, |
|
attention_mask: torch.Tensor, |
|
input_tensor: torch.Tensor, |
|
cache_position: torch.Tensor, |
|
past_key_values: Cache, |
|
output_attentions: bool, |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.config._attn_implementation == "flash_attention_2": |
|
return attention_mask |
|
|
|
dtype, device = input_tensor.dtype, input_tensor.device |
|
min_dtype = torch.finfo(dtype).min |
|
sequence_length = input_tensor.shape[1] |
|
if isinstance(past_key_values, HybridCache): |
|
target_length = past_key_values.get_max_length() |
|
else: |
|
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] |
|
|
|
|
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( |
|
attention_mask, |
|
sequence_length=sequence_length, |
|
target_length=target_length, |
|
dtype=dtype, |
|
device=device, |
|
min_dtype=min_dtype, |
|
cache_position=cache_position, |
|
batch_size=input_tensor.shape[0], |
|
) |
|
|
|
return causal_mask |
|
|
|
|
|
|
|
class FeynModelForCausalLM(Gemma2ForCausalLM):
    """
    Causal language model combining a `FeynModel` decoder with a DaViT vision tower.

    When `pixel_values` are supplied, the image is encoded, projected and placed in
    front of the text embeddings; without them the model runs on text alone.
    """

    _tied_weights_keys = ["lm_head.weight"]
    config_class = FeynModelConfig

    def __init__(self, config):
        super().__init__(config)

        # Only convert a raw dict: `from_dict` cannot take an already-built config
        # object (e.g. when the same config instance is reused for a second model).
        if isinstance(config.vision_config, dict):
            config.vision_config = Florence2VisionConfig.from_dict(config.vision_config)

        self.model = FeynModel(config)

        self.vision_tower = DaViT.from_config(config=config.vision_config)
        self._build_image_projection_layers(config)

        # Image/text attention mask cached for the decoder; set by vision-language
        # calls and cleared when `generate` finishes.
        self.__causal_attention_mask = None

        self.post_init()

    def _build_image_projection_layers(self, config):
        """
        Create the modules that map vision-tower features into the projection space.

        Raises:
            NotImplementedError: if the image position embedding type is not
                'learned_abs_2d' or the temporal embedding type is not 'COSINE'.
        """
        image_dim_out = config.vision_config.dim_embed[-1]
        dim_projection = config.vision_config.projection_dim
        # Allocated uninitialized here; values are expected to come from elsewhere
        # (weight init or a checkpoint) -- not visible in this block.
        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection)
        )
        self.image_proj_norm = nn.LayerNorm(dim_projection)

        image_pos_embed_config = config.vision_config.image_pos_embed
        if image_pos_embed_config['type'] == 'learned_abs_2d':
            self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out,
                num_pos=image_pos_embed_config['max_pos_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

        self.image_feature_source = config.vision_config.image_feature_source

        visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
        if visual_temporal_embedding_config['type'] == 'COSINE':
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim_out,
                max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds):
        """
        Concatenate image features in front of the text embeddings along dim 1.

        Args:
            image_features: 3-D tensor; the first two dimensions are read as
                (batch, image tokens).
            inputs_embeds: text embeddings, or None to return the image part alone.

        Returns:
            A pair `(embeddings, attention_mask)`; the mask is all ones and covers
            every position of the returned embeddings.

        Raises:
            ValueError: if batch sizes or feature dimensions of the two inputs differ.
        """
        batch_size, image_token_length = image_features.size()[:-1]
        device = image_features.device
        image_attention_mask = torch.ones(batch_size, image_token_length, device=device)

        if inputs_embeds is None:
            return image_features, image_attention_mask

        task_prefix_embeds = inputs_embeds
        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)

        if image_features.size(0) != task_prefix_embeds.size(0):
            raise ValueError("Batch sizes of image_features and task_prefix_embeds do not match")

        # Bring both tensors to the same rank before comparing feature sizes.
        if image_features.dim() < task_prefix_embeds.dim():
            image_features = image_features.unsqueeze(-1)
        elif task_prefix_embeds.dim() < image_features.dim():
            task_prefix_embeds = task_prefix_embeds.unsqueeze(-1)

        if image_features.size(2) != task_prefix_embeds.size(2):
            raise ValueError("Internal dimensions of image_features and task_prefix_embeds do not match")

        inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
        attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)

        return inputs_embeds, attention_mask

    def _encode_image(self, pixel_values):
        """
        Encode images into projected, layer-normalized feature tokens.

        Args:
            pixel_values: a 4-D tensor unpacked as (batch, C, H, W); an input of any
                other rank gets a leading batch dimension added first.

        Returns:
            The features selected by `self.image_feature_source`, concatenated along
            dim 1, multiplied by `self.image_projection` and passed through
            `self.image_proj_norm`.

        Raises:
            ValueError: if `self.image_feature_source` names an unknown feature.
        """
        if len(pixel_values.shape) != 4:
            pixel_values = pixel_values.unsqueeze(0)
        batch_size, C, H, W = pixel_values.shape
        # Single frame per sample; the temporal axis below always has length 1.
        T = 1
        x = self.vision_tower.forward_features_unpool(pixel_values)

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
            assert h * w == num_tokens, 'only support square feature maps for now'
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h*w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        x_feat_dict = {}

        # Mean over the token axis: one vector per frame.
        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x

        # Mean over the frame axis: one vector per token.
        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict['last_frame'] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x

    def _prepare_multimodal_inputs(self, input_ids, pixel_values, max_length):
        """
        Build the decoder inputs for an image + text call.

        Returns:
            `(inputs_embeds, causal_attention_mask, num_image_tokens)`: the image
            features concatenated in front of the text embeddings, the mask from
            `create_git_attention_mask` moved to the device of `input_ids`, and the
            number of image tokens at the front of `inputs_embeds`.

        Raises:
            ValueError: if `input_ids` is None.
        """
        if input_ids is None:
            raise ValueError("`input_ids` must be provided together with `pixel_values`.")

        text_embeds = self.get_input_embeddings()(input_ids)
        image_features = self._encode_image(pixel_values)
        # The all-ones mask returned by the merge is not used; the git-style mask
        # built below is the one the decoder consumes.
        inputs_embeds, _ = self._merge_input_ids_with_image_features(image_features, text_embeds)
        causal_attention_mask = create_git_attention_mask(
            tgt=input_ids, memory=image_features, max_length=max_length
        )
        causal_attention_mask = causal_attention_mask.to(input_ids.device)
        return inputs_embeds, causal_attention_mask, image_features.size(1)

    def _lm_head_dtype(self):
        """
        Return the dtype of the `lm_head` weights.

        Handles a regular tensor weight, a dynamically quantized linear layer
        (weights read through `_weight_bias()`), and a weight exposed as a callable.

        Raises:
            TypeError: if the weight is neither a tensor nor a callable.
        """
        if isinstance(self.lm_head, torch.ao.nn.quantized.dynamic.Linear):
            weight, _bias = self.lm_head._weight_bias()
            return weight.dtype

        weight = getattr(self.lm_head, 'weight', None)
        if isinstance(weight, torch.Tensor):
            return weight.dtype
        if callable(weight):
            return weight().dtype
        raise TypeError(f"Type inattendu pour self.lm_head.weight : {type(weight)}")

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder model."""
        return self.model

    @add_start_docstrings_to_model_forward(FEYNMODEL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train FeynModel models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Logits belonging to image tokens are dropped before the loss; a text-only
        # call has none.
        num_image_tokens = 0
        if pixel_values is not None:
            self.model.mode = 'vlm'
            inputs_embeds, causal_attention_mask, num_image_tokens = self._prepare_multimodal_inputs(
                input_ids, pixel_values, max_length=8192
            )
            self.__causal_attention_mask = causal_attention_mask
            # The decoder receives the merged embeddings instead of token ids, and
            # the image/text mask doubles as its attention mask.
            input_ids = None
            attention_mask = causal_attention_mask

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            causal_attention_mask=self.__causal_attention_mask,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        # Soft-cap the logits to (-cap, cap) with tanh.
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        logits = logits.float()
        loss = None
        if labels is not None:
            # Next-token objective over the text part only: skip the image
            # positions, then shift logits and labels against each other.
            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
            shifted_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), shifted_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one decoding step.

        Returns:
            dict with `position_ids`, `cache_position`, `past_key_values`,
            `use_cache`, `attention_mask`, and either `inputs_embeds` (when given
            and `cache_position[0] == 0`) or `input_ids`.
        """
        # With a cache, keep only the tokens that have not been processed yet.
        if past_key_values is not None:
            if inputs_embeds is not None:
                input_ids = input_ids[:, -cache_position.shape[0]:]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # Derive positions from the mask; masked-out slots get the value 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # Embeddings are only fed on the first step; later steps use token ids.
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}

        if isinstance(past_key_values, HybridCache) and attention_mask is not None and attention_mask.ndim == 2:
            if inputs_embeds is not None and input_ids.size(1) != 0:
                batch_size, sequence_length, _ = inputs_embeds.shape
                device = inputs_embeds.device
            else:
                batch_size, sequence_length = position_ids.shape
                device = input_ids.device

            dtype = self._lm_head_dtype()
            # Use the minimum of the mask's own dtype so the fill value is always
            # representable (the float32 minimum is not, in fp16/bf16).
            if dtype.is_floating_point:
                min_dtype = torch.finfo(dtype).min
            else:
                min_dtype = torch.iinfo(dtype).min

            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def generate(
        self,
        input_ids,
        pixel_values=None,
        max_length=None,
        do_sample=True,
        temperature=0.7,
        **kwargs
    ):
        """
        Generate tokens from text, optionally conditioned on an image.

        With `pixel_values`, the image features and the embedded prompt are merged
        and generation starts from `inputs_embeds`; otherwise generation runs on
        `input_ids` alone. The cached image/text mask is cleared when generation
        ends, whether it succeeds or raises.

        Args:
            input_ids: prompt token ids; required when `pixel_values` is given.
            pixel_values: optional image tensor.
            max_length, do_sample, temperature, **kwargs: forwarded to the parent
                class's `generate`.

        Raises:
            ValueError: if `pixel_values` is given without `input_ids`.
        """
        try:
            if pixel_values is not None:
                inputs_embeds, causal_attention_mask, _ = self._prepare_multimodal_inputs(
                    input_ids, pixel_values, max_length=max_length
                )
                self.__causal_attention_mask = causal_attention_mask
                self.model.mode = 'vlm'
                result = super().generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    max_length=max_length,
                    do_sample=do_sample,
                    temperature=temperature,
                    **kwargs
                )
            else:
                self.model.mode = 'llm'
                result = super().generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    do_sample=do_sample,
                    temperature=temperature,
                    **kwargs
                )
        finally:
            # Never let a mask from this call leak into the next one.
            self.__causal_attention_mask = None

        return result
|
|
|
|