from transformers import VisionTextDualEncoderConfig


class VTDEConfig(VisionTextDualEncoderConfig):
    model_type = "vtde"

    def __init__(self, projection_dim=512, logit_scale_init_value=2.6592,
                 text_pooling_mode='mean',
                 vision_pooling_mode='max',
                 **kwargs):
        """
        Dual-encoder config with a configurable pooling strategy per tower.

        text_pooling_mode / vision_pooling_mode: one of ['mean', 'max', 'cls'].

        References:
            https://arxiv.org/pdf/2210.09996.pdf
            https://github.com/kahnchana/clippy/blob/3c102c29c32f7c66c6e52e09b795fe9c061bbb03/src/open_clip/hf_model.py#L56
        """
        # Record the pooling strategies before delegating the remaining
        # arguments (including vision_config/text_config) to the parent config.
        self.text_pooling_mode = text_pooling_mode
        self.vision_pooling_mode = vision_pooling_mode
        super().__init__(projection_dim=projection_dim,
                         logit_scale_init_value=logit_scale_init_value,
                         **kwargs)
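

# Usage sketch (illustrative): builds a VTDEConfig from two single-tower
# configs via the `from_vision_text_configs` classmethod inherited from
# VisionTextDualEncoderConfig. The checkpoint names below are placeholder
# assumptions, not something this file prescribes; substitute the vision and
# text backbones you actually pair.
if __name__ == "__main__":
    from transformers import AutoConfig

    vision_config = AutoConfig.from_pretrained("google/vit-base-patch16-224")
    text_config = AutoConfig.from_pretrained("roberta-base")

    config = VTDEConfig.from_vision_text_configs(
        vision_config=vision_config,
        text_config=text_config,
        projection_dim=512,
        text_pooling_mode="mean",
        vision_pooling_mode="max",
    )
    # The custom pooling fields sit on the config alongside the usual
    # dual-encoder settings.
    print(config.model_type, config.text_pooling_mode, config.vision_pooling_mode)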