File size: 693 Bytes
fdd38ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import VisionTextDualEncoderConfig

class VTDEConfig(VisionTextDualEncoderConfig):
    """Vision-text dual-encoder config with per-tower pooling selection.

    Extends ``VisionTextDualEncoderConfig`` with two extra settings,
    ``text_pooling_mode`` and ``vision_pooling_mode``, which are stored on the
    config instance for use by the model.

    Args:
        projection_dim: Forwarded unchanged to ``VisionTextDualEncoderConfig``.
        logit_scale_init_value: Forwarded unchanged to
            ``VisionTextDualEncoderConfig``.
        text_pooling_mode: Pooling mode for the text tower; one of
            ``"mean"``, ``"max"``, ``"cls"``.
        vision_pooling_mode: Pooling mode for the vision tower; one of
            ``"mean"``, ``"max"``, ``"cls"``.
        **kwargs: Forwarded unchanged to the parent constructor.

    Raises:
        ValueError: If either pooling mode is not one of the supported values.

    References:
        https://arxiv.org/pdf/2210.09996.pdf
        https://github.com/kahnchana/clippy/blob/3c102c29c32f7c66c6e52e09b795fe9c061bbb03/src/open_clip/hf_model.py#L56
    """

    model_type = "vtde"

    # Allowed pooling modes. Kept at class level so it is a shared constant
    # rather than a per-instance attribute.
    _POOLING_MODES = ("mean", "max", "cls")

    def __init__(self, projection_dim=512, logit_scale_init_value=2.6592,
                 text_pooling_mode='mean',
                 vision_pooling_mode='max',
                 **kwargs):
        # Fail fast on an unsupported mode instead of storing it silently and
        # erroring much later when the model tries to use it.
        for option, mode in (("text_pooling_mode", text_pooling_mode),
                             ("vision_pooling_mode", vision_pooling_mode)):
            if mode not in self._POOLING_MODES:
                raise ValueError(
                    f"{option} must be one of {self._POOLING_MODES}, got {mode!r}"
                )
        self.text_pooling_mode = text_pooling_mode
        self.vision_pooling_mode = vision_pooling_mode
        # Pass by keyword so a change in the parent's positional order cannot
        # silently mis-bind these values.
        super().__init__(
            projection_dim=projection_dim,
            logit_scale_init_value=logit_scale_init_value,
            **kwargs,
        )