import math
from typing import ClassVar, List, Optional

import torch
from torch import nn

from .modeling_internvl_chat import InternVLChatModel, InternVLChatConfig


class ColInternVL2(InternVLChatModel):
    """
    ColInternVL2 model implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.

    It wraps InternVLChatModel and adds a linear projection (`custom_text_proj`) that maps every token's last
    hidden state to a low-dimensional, L2-normalized embedding used for late-interaction retrieval
    (see the scoring sketch at the end of this file).
    """

    # main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config: InternVLChatConfig):
        super().__init__(config=config)
        self.dim = 128  # dimension of the per-token embeddings used for late interaction
        self.custom_text_proj = nn.Linear(self.language_model.model.config.hidden_size, self.dim)  # , bias=False)
        self.padding_side = "left"
        self.img_context_token_id = 151648  # hard-coded id of the image-context placeholder token
        # self.post_init()
        self.init_linear()

    def init_linear(self):
        print(self.language_model.model.embed_tokens.weight)
        # Re-initialize the projection layer with the same uniform scheme nn.Linear uses by default.
        stdv = 1.0 / math.sqrt(self.custom_text_proj.weight.size(1))
        self.custom_text_proj.weight.data.uniform_(-stdv, stdv)
        if self.custom_text_proj.bias is not None:
            self.custom_text_proj.bias.data.uniform_(-stdv, stdv)

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        statistics: Optional[torch.LongTensor] = None,
        loss_weight: Optional[List] = None,
        loss_reduction_all_gather: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Token embeddings from the language model; cloned so the image features can be written in place.
        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
        B, N, C = input_embeds.shape

        if pixel_values is not None:
            pixel_values = pixel_values.type(self.vision_model.embeddings.patch_embedding.weight.dtype)
            vit_embeds = self.extract_feature(pixel_values)
            # image_flags = image_flags.squeeze(-1)
            # vit_embeds = vit_embeds[image_flags == 1]
            vit_batch_size = pixel_values.shape[0]

            input_embeds = input_embeds.reshape(B * N, C)
            if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
                print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
                if statistics is not None:
                    num_samples, num_padding_tokens, num_padding_images = statistics.tolist()
                    self.num_samples += num_samples
                    print(f'total_samples={self.num_samples}, {num_samples=}, {num_padding_tokens=}, {num_padding_images=}')

            # Scatter the ViT patch embeddings into the positions of the image-context placeholder tokens.
            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            try:
                input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
                ignore_flag = False
            except Exception as e:
                vit_embeds = vit_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                      f'vit_embeds.shape={vit_embeds.shape}')
                n_token = selected.sum()
                input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
                ignore_flag = True

            input_embeds = input_embeds.reshape(B, N, C)

        # Run the language-model trunk directly (no LM head); only the last hidden states are needed.
        outputs = self.language_model.model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        last_hidden_states = outputs[0].type(self.custom_text_proj.weight.dtype)

        # Project each token into the low-dimensional multi-vector space.
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)
        # L2 normalization
        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
        # Zero out the embeddings of padding tokens.
        proj = proj * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
        return proj

    @property
    def get_patch_size(self) -> int:
        return self.vision_model.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        return self.visual.config.spatial_merge_size
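

# ---------------------------------------------------------------------------
# Illustrative usage sketch -- not part of the original model code.
# `forward` returns one L2-normalized `self.dim`-dimensional vector per token
# (padding positions are zeroed). ColPali-style retrieval scores a query
# against a document page with late interaction (MaxSim): each query-token
# embedding is matched to its most similar document-token embedding and the
# maxima are summed over the query tokens. The helper below assumes batched,
# already-padded embedding tensors; the `processor` mentioned in the example
# is a placeholder and is not defined in this file.
# ---------------------------------------------------------------------------
def maxsim_scores(query_embs: torch.Tensor, doc_embs: torch.Tensor) -> torch.Tensor:
    """Late-interaction (MaxSim) scores.

    query_embs: (num_queries, query_len, dim) multi-vector query embeddings.
    doc_embs:   (num_docs, doc_len, dim) multi-vector document embeddings.
    Returns a (num_queries, num_docs) matrix of similarity scores.
    """
    # Pairwise token-level dot products: (num_queries, num_docs, query_len, doc_len).
    sim = torch.einsum("qnd,pmd->qpnm", query_embs, doc_embs)
    # Best-matching document token per query token, summed over query tokens.
    return sim.max(dim=3).values.sum(dim=2)


# Example (placeholder names):
#     model = ColInternVL2.from_pretrained("<checkpoint>").eval()
#     with torch.no_grad():
#         q_embs = model(**processor.process_queries(["how does late interaction work?"]))
#         d_embs = model(**processor.process_images(page_images))
#     scores = maxsim_scores(q_embs, d_embs)  # (num_queries, num_docs)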