from copy import deepcopy
from typing import Dict, Optional, Type

import torch
from transformers import AutoConfig, VisionTextDualEncoderConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class CustomCLIPPooler(torch.nn.Module):
    """Pool hidden states by selecting index 0 along dimension 1."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states[:, 0, :]``.

        NOTE(review): presumably ``hidden_states`` is (batch, seq_len, hidden)
        from the text encoder, so the result is the first token's vector with
        shape (batch, hidden) -- confirm against the caller.
        """
        first_token_tensor = hidden_states[:, 0, :]
        return first_token_tensor


# Registry of pooler name -> pooler class. ``CustomCLIPConfig.text_model_pooler``
# is validated against these keys; register new poolers here.
_TEXT_MODEL_POOLERS: Dict[str, Type[torch.nn.Module]] = {
    "CustomCLIPPooler": CustomCLIPPooler,
}


def get_text_model_pooler(text_model_pooler: str) -> Type[torch.nn.Module]:
    """Look up a text-model pooler class by its registered name.

    The *class* is returned, not an instance; instantiating it is up to the
    caller.

    Args:
        text_model_pooler: Registered pooler name, e.g. ``"CustomCLIPPooler"``.

    Returns:
        The pooler class (a ``torch.nn.Module`` subclass).

    Raises:
        ValueError: If ``text_model_pooler`` is not a registered name.
    """
    try:
        return _TEXT_MODEL_POOLERS[text_model_pooler]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs; those were never equal to a
        # registered name, so they are rejected with ValueError like any other
        # unknown name (callers such as is_valid_text_model_pooler catch it).
        raise ValueError(
            f"Unrecognized text model pooler type {text_model_pooler!r}."
        ) from None


def is_valid_text_model_pooler(
    text_model_pooler: str, suppress_error: bool = False
) -> bool:
    """Check whether ``text_model_pooler`` names a registered pooler.

    Args:
        text_model_pooler: Pooler name to validate.
        suppress_error: When False (the default), an unknown name raises
            instead of returning False.

    Returns:
        True if the name is registered; False if it is not and
        ``suppress_error`` is True.

    Raises:
        ValueError: If the name is not registered and ``suppress_error`` is
            False.
    """
    try:
        get_text_model_pooler(text_model_pooler)
    except ValueError:
        if not suppress_error:
            raise
        return False
    return True


class CustomCLIPConfig(VisionTextDualEncoderConfig):
    """``VisionTextDualEncoderConfig`` extended with a text-model pooler choice.

    Adds two attributes on top of the base config:

    * ``text_model_pooler`` -- a registered pooler name, validated on init.
    * ``text_model_pooler_kwargs`` -- a dict stored with the name; presumably
      forwarded to the pooler class's constructor by the model that consumes
      this config (not shown in this file -- confirm).
    """

    model_type = "custom-clip-model"

    DEFAULT_TEXT_MODEL_POOLER_STR: str = "CustomCLIPPooler"
    # Template only: instances receive a deep copy so that mutating one
    # config's kwargs can never leak into this default or into other configs.
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = {}

    def __init__(
        self,
        *args,
        text_model_pooler: Optional[str] = None,
        text_model_pooler_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Build the config; unrecognised arguments go to the base class.

        Args:
            *args: Forwarded to ``VisionTextDualEncoderConfig.__init__``.
            text_model_pooler: Registered pooler name; ``None`` selects
                ``DEFAULT_TEXT_MODEL_POOLER_STR``.
            text_model_pooler_kwargs: Pooler kwargs; ``None`` selects a fresh
                copy of ``DEFAULT_TEXT_MODEL_POOLER_KWARGS``.
            **kwargs: Forwarded to ``VisionTextDualEncoderConfig.__init__``.

        Raises:
            ValueError: If ``text_model_pooler`` is not a registered name.
        """
        super().__init__(*args, **kwargs)
        self.text_model_pooler = (
            self.DEFAULT_TEXT_MODEL_POOLER_STR
            if text_model_pooler is None
            else text_model_pooler
        )
        # Raises ValueError for an unknown pooler name.
        is_valid_text_model_pooler(self.text_model_pooler, suppress_error=False)
        self.text_model_pooler_kwargs = (
            deepcopy(self.DEFAULT_TEXT_MODEL_POOLER_KWARGS)
            if text_model_pooler_kwargs is None
            else text_model_pooler_kwargs
        )

    @classmethod
    def from_base(cls, obj: VisionTextDualEncoderConfig) -> "CustomCLIPConfig":
        """Convert a plain ``VisionTextDualEncoderConfig`` into this class.

        If ``obj`` is already an instance of ``cls`` it is returned as-is (no
        copy). Otherwise ``obj`` is deep-copied, the copy's class is switched
        to ``cls``, and both pooler attributes are reset to the class
        defaults. Each change is logged as a warning. The input object itself
        is never modified.

        Args:
            obj: The config to convert.

        Returns:
            A ``cls`` instance.

        Raises:
            TypeError: If ``obj`` is neither a ``cls`` nor a
                ``VisionTextDualEncoderConfig`` instance.
        """
        if isinstance(obj, cls):
            return obj

        base = VisionTextDualEncoderConfig
        if not isinstance(obj, base):
            raise TypeError(f"obj must be of type {cls!r} or {base!r}.")

        obj = deepcopy(obj)
        logger.warning(f"Changing config class from {obj.__class__!r} to {cls!r}.")
        # Reassigning __class__ bypasses cls.__init__, so the attributes that
        # __init__ would have set are assigned explicitly below.
        obj.__class__ = cls

        def setattr_with_warning(target, name, value):
            logger.warning(f"Setting {name!r} to {value!r}.")
            setattr(target, name, value)

        setattr_with_warning(
            obj, "text_model_pooler", cls.DEFAULT_TEXT_MODEL_POOLER_STR
        )
        setattr_with_warning(
            obj,
            "text_model_pooler_kwargs",
            deepcopy(cls.DEFAULT_TEXT_MODEL_POOLER_KWARGS),
        )
        return obj


# Import-time side effect: make the config resolvable by AutoConfig under
# ``CustomCLIPConfig.model_type``.
AutoConfig.register(CustomCLIPConfig.model_type, CustomCLIPConfig)