""" Subclasses VisionTextDualEncoderModel to customize text pooler. """ from typing import Optional import torch from transformers import AutoModel, VisionTextDualEncoderModel from .configuration_custom_clip import CustomCLIPConfig, get_text_model_pooler # @add_start_docstrings(CUSTOM_CLIP_START_DOCSTRING) class CustomCLIPModel(VisionTextDualEncoderModel): config_class = CustomCLIPConfig DEFAULT_TEXT_MODEL_POOLER_TYPE: torch.nn.Module = get_text_model_pooler( CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_STR ) DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = ( CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_KWARGS ) def __init__( self, config: Optional[CustomCLIPConfig.__base__] = None, *args, **kwargs ): config = config if config is None else CustomCLIPConfig.from_base(config) super().__init__( config, # surprisingly, `super` is unnecessary, possibly due to implementation of CustomCLIPConfig.__init__? *args, **kwargs, ) self.text_model.pooler = ( (self.DEFAULT_TEXT_MODEL_POOLER_TYPE)( **self.DEFAULT_TEXT_MODEL_POOLER_KWARGS ) if config is None else get_text_model_pooler(config.text_model_pooler)( **config.text_model_pooler_kwargs ) ) AutoModel.register(CustomCLIPConfig, CustomCLIPModel)