# japanese-clip-vit-h-14-bert-base / configuration_custom_clip.py
# Author: bsyx001 -- commit "Upload model" (8b26dd7, verified)
from copy import deepcopy
from typing import Optional, Type

import torch
from transformers import AutoConfig, VisionTextDualEncoderConfig
from transformers.utils import logging
# Module-level logger from transformers' logging utilities; used by
# CustomCLIPConfig.from_base to warn about class/attribute changes.
logger = logging.get_logger(__name__)
class CustomCLIPPooler(torch.nn.Module):
    """Parameter-free pooler that selects index 0 along dimension 1.

    NOTE(review): presumably ``hidden_states`` is (batch, seq_len, hidden),
    so the result is the first-token embedding of shape (batch, hidden);
    confirm against the text encoder that feeds this pooler.
    """

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The explicit trailing ``:`` keeps the requirement that the input
        # has at least three dimensions (a 2-D tensor raises IndexError).
        return hidden_states[:, 0, :]
def get_text_model_pooler(text_model_pooler: str) -> Type[torch.nn.Module]:
    """Resolve a text-model pooler class from its name.

    Args:
        text_model_pooler: Name of the pooler. Only ``"CustomCLIPPooler"``
            is currently recognized.

    Returns:
        The pooler class itself (not an instance), hence the
        ``Type[...]`` annotation.

    Raises:
        ValueError: If ``text_model_pooler`` is not a recognized name.
    """
    if text_model_pooler == "CustomCLIPPooler":
        return CustomCLIPPooler
    raise ValueError(f"Unrecognized text model pooler type {text_model_pooler!r}.")
def is_valid_text_model_pooler(
    text_model_pooler: str, suppress_error: bool = False
) -> bool:
    """Check whether ``text_model_pooler`` names a known pooler.

    Validity is decided by attempting ``get_text_model_pooler``.

    Args:
        text_model_pooler: Pooler name to validate.
        suppress_error: When True, an unknown name yields ``False``;
            when False (the default), the ``ValueError`` propagates.

    Returns:
        True if the name resolves; False if it does not and
        ``suppress_error`` is set.

    Raises:
        ValueError: If the name is unknown and ``suppress_error`` is False.
    """
    try:
        get_text_model_pooler(text_model_pooler)
    except ValueError:
        if suppress_error:
            return False
        raise
    return True
class CustomCLIPConfig(VisionTextDualEncoderConfig):
    """``VisionTextDualEncoderConfig`` extended with a text-model pooler choice.

    Adds two attributes on top of the base config:

    * ``text_model_pooler``: name of the pooler. It is validated with
      ``is_valid_text_model_pooler`` (currently only ``"CustomCLIPPooler"``).
    * ``text_model_pooler_kwargs``: dict of extra pooler options.
      NOTE(review): nothing in this file reads it; presumably the model
      class forwards it to the pooler. Confirm.
    """

    model_type = "custom-clip-model"

    DEFAULT_TEXT_MODEL_POOLER_STR: str = "CustomCLIPPooler"
    # Shared, mutable class-level default. Instances must always receive a
    # *copy*: otherwise mutating one config's kwargs would silently change
    # this default and every other config that aliases it.
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = {}

    def __init__(
        self,
        *args,
        text_model_pooler: Optional[str] = None,
        text_model_pooler_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Build the config.

        Args:
            *args, **kwargs: Forwarded unchanged to
                ``VisionTextDualEncoderConfig.__init__``.
            text_model_pooler: Pooler name. ``None`` selects
                ``DEFAULT_TEXT_MODEL_POOLER_STR``.
            text_model_pooler_kwargs: Pooler options. ``None`` selects a
                fresh copy of ``DEFAULT_TEXT_MODEL_POOLER_KWARGS``.

        Raises:
            ValueError: If the resolved pooler name is not recognized.
        """
        super().__init__(*args, **kwargs)
        self.text_model_pooler = (
            self.DEFAULT_TEXT_MODEL_POOLER_STR
            if text_model_pooler is None
            else text_model_pooler
        )
        is_valid_text_model_pooler(self.text_model_pooler, suppress_error=False)
        self.text_model_pooler_kwargs = (
            deepcopy(self.DEFAULT_TEXT_MODEL_POOLER_KWARGS)
            if text_model_pooler_kwargs is None
            else text_model_pooler_kwargs
        )

    @classmethod
    def from_base(cls, obj: VisionTextDualEncoderConfig) -> "CustomCLIPConfig":
        """Convert a ``VisionTextDualEncoderConfig`` into this class.

        If ``obj`` is already an instance of ``cls``, it is returned as-is
        (not copied). Otherwise a deep copy of ``obj`` has its class swapped
        to ``cls`` and is given the default pooler settings; ``obj`` itself
        is left untouched. Each change is logged as a warning.

        Raises:
            TypeError: If ``obj`` is not a ``VisionTextDualEncoderConfig``.
        """
        if isinstance(obj, cls):
            return obj
        base = VisionTextDualEncoderConfig
        if not isinstance(obj, base):
            raise TypeError(f"obj must be of type {cls!r} or {base!r}.")
        obj = deepcopy(obj)
        logger.warning(f"Changing config class from {obj.__class__!r} to {cls!r}.")
        # Reassigning __class__ bypasses __init__, so the pooler attributes
        # that __init__ would normally set have to be added explicitly.
        obj.__class__ = cls

        def setattr_with_warning(target, name, value):
            logger.warning(f"Setting {name!r} to {value!r}.")
            setattr(target, name, value)

        setattr_with_warning(
            obj, "text_model_pooler", cls.DEFAULT_TEXT_MODEL_POOLER_STR
        )
        # Copy the default so the converted config does not alias the
        # class-level dict.
        setattr_with_warning(
            obj,
            "text_model_pooler_kwargs",
            deepcopy(cls.DEFAULT_TEXT_MODEL_POOLER_KWARGS),
        )
        return obj
# Register CustomCLIPConfig with transformers' AutoConfig under its
# model_type ("custom-clip-model"). This runs as an import-time side effect.
AutoConfig.register(CustomCLIPConfig.model_type, CustomCLIPConfig)