Spaces:
Running
on
Zero
Running
on
Zero
from typing import ClassVar | |
import torch | |
from torch import nn | |
from modeling_florence2 import Florence2ForConditionalGeneration, Florence2VisionLanguageModel | |
from configuration_florence2 import Florence2Config | |
class ColFlor(Florence2VisionLanguageModel):
    """
    ColPali-style multi-vector document retriever on a Florence-2 backbone.

    The backbone's ``encoder_last_hidden_state`` is projected to ``self.dim``
    (128) features per position, each position vector is L2-normalised, and
    masked-out positions are zeroed. This follows the late-interaction design
    of "ColPali: Efficient Document Retrieval with Vision Language Models".
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config: Florence2Config, use_cache: bool = False):
        """
        Build the Florence-2 backbone and the token-embedding projection head.

        Args:
            config: Florence-2 configuration; ``config.text_config.d_model``
                sets the input width of the projection head.
            use_cache: Accepted for signature compatibility; not used here.
        """
        super().__init__(config=config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.config.text_config.d_model, self.dim)
        # Explicitly initialise the new head: N(0, 0.02) weights, zero bias.
        # NOTE(review): post_init() below may also run the parent's weight
        # initialisation over this layer — confirm which initialisation wins.
        nn.init.normal_(self.custom_text_proj.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.custom_text_proj.bias)
        self.padding_side = "right"
        self.post_init()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Return one unit-length embedding per encoder position, zeroed where masked.

        Args:
            *args, **kwargs: Forwarded to the parent ``forward``. Two keys are
                consumed here first. ``output_hidden_states`` is dropped.
                ``full_attention_mask``, if given and not None, is removed and
                used to mask the output; otherwise ``attention_mask`` is used.
                If neither mask is available, the output is returned unmasked.

        Returns:
            The projected, L2-normalised encoder states. Presumably shaped
            (batch_size, sequence_length, self.dim) — TODO confirm against the
            backbone's ``encoder_last_hidden_state``.
        """
        # Only encoder_last_hidden_state is used, so hidden-state collection
        # is not requested from the backbone.
        kwargs.pop("output_hidden_states", None)

        # Prefer a mask covering both image and text positions; fall back to
        # the text attention_mask.
        # NOTE(review): the encoder sequence presumably includes image tokens,
        # so a text-only attention_mask may not match its length — confirm
        # callers pass full_attention_mask.
        full_attention_mask = kwargs.pop("full_attention_mask", None)
        if full_attention_mask is None:
            full_attention_mask = kwargs.get("attention_mask")

        outputs = super().forward(*args, **kwargs)
        last_hidden_states = outputs["encoder_last_hidden_state"]
        proj = self.custom_text_proj(last_hidden_states)

        # normalize() divides by max(norm, eps) (eps=1e-12), so an all-zero
        # vector stays zero instead of becoming NaN (0/0). A later mask
        # multiply could not undo a NaN, since NaN * 0 is NaN.
        proj = torch.nn.functional.normalize(proj, p=2, dim=-1)

        if full_attention_mask is not None:
            # Zero out padded/ignored positions; cast so the mask dtype
            # (int/bool) does not affect the result dtype.
            proj = proj * full_attention_mask.unsqueeze(-1).to(proj.dtype)
        return proj