ColFlor-Demo / modeling_colflor.py
ahmed-masry's picture
Update modeling_colflor.py
1312eac verified
from typing import ClassVar
import torch
from torch import nn
from modeling_florence2 import Florence2ForConditionalGeneration, Florence2VisionLanguageModel
from configuration_florence2 import Florence2Config
class ColFlor(Florence2VisionLanguageModel):
    """
    ColFlor model implementation from the "ColPali: Efficient Document Retrieval
    with Vision Language Models" paper.

    Wraps ``Florence2VisionLanguageModel`` and adds a linear head that maps the
    encoder's last hidden states to ``dim``-sized vectors. ``forward`` returns
    one L2-normalized vector per sequence position, with positions excluded by
    the attention mask zeroed out.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config: Florence2Config, use_cache=False):
        """
        Build the base model and attach the projection head.

        Args:
            config: Florence-2 configuration; ``config.text_config.d_model`` is
                used as the input width of the projection head.
            use_cache: Accepted for signature compatibility; not used by this
                constructor.
        """
        super().__init__(config=config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)
        # Explicit initialization of the newly added head.
        self.custom_text_proj.weight.data.normal_(mean=0.0, std=0.02)
        self.custom_text_proj.bias.data.zero_()
        self.padding_side = "right"
        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Compute masked, L2-normalized per-position embeddings.

        Args:
            *args: Forwarded unchanged to the parent ``forward``.
            **kwargs: Forwarded to the parent ``forward`` after removing
                ``output_hidden_states`` and ``full_attention_mask``. The mask
                applied to the output is ``full_attention_mask`` when given
                (and not ``None``), otherwise ``attention_mask``.

        Returns:
            Tensor of projected embeddings; per the original author's notes the
            shape is (batch_size, sequence_length, dim).

        Raises:
            KeyError: If neither ``full_attention_mask`` nor ``attention_mask``
                is supplied as a keyword argument.
        """
        # The parent forward is called without this flag; drop it if present.
        kwargs.pop("output_hidden_states", None)

        # Prefer the full mask (author's note: covers both image and text) and
        # keep it out of the kwargs handed to the parent; otherwise reuse the
        # regular attention mask, which stays in kwargs for the parent call.
        full_attention_mask = kwargs.pop("full_attention_mask", None)
        if full_attention_mask is None:
            full_attention_mask = kwargs.get("attention_mask")
        if full_attention_mask is None:
            raise KeyError(
                "ColFlor.forward requires 'full_attention_mask' or "
                "'attention_mask' to be passed as a keyword argument"
            )

        outputs = super().forward(*args, **kwargs)
        last_hidden_states = outputs['encoder_last_hidden_state']  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization. The norm is bounded below by the smallest positive
        # normal value of the dtype so an all-zero vector yields zeros instead
        # of NaN (which the mask multiplication below would not remove).
        norm = proj.norm(dim=-1, keepdim=True)
        proj = proj / norm.clamp_min(torch.finfo(proj.dtype).tiny)

        # Zero out positions that the mask marks as excluded.
        proj = proj * full_attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
        return proj