# japanese-clip-vit-h-14-bert-deeper / modeling_custom_clip.py
# (uploaded by bsyx001, revision 02a3b66)
"""
Subclasses VisionTextDualEncoderModel to customize text pooler.
"""
from typing import Optional
import torch
from transformers import AutoModel, VisionTextDualEncoderModel
from .configuration_custom_clip import CustomCLIPConfig, get_text_model_pooler
# @add_start_docstrings(CUSTOM_CLIP_START_DOCSTRING)
class CustomCLIPModel(VisionTextDualEncoderModel):
    """Dual-encoder CLIP model with a customizable text pooler.

    Subclasses ``VisionTextDualEncoderModel``; after the parent
    initialization it replaces ``self.text_model.pooler`` with the pooler
    selected by the config (or the class-level default when no config is
    supplied).
    """

    config_class = CustomCLIPConfig

    # NOTE: get_text_model_pooler returns a pooler *class* (it is called with
    # **kwargs below to build the instance), so the annotation is
    # type[torch.nn.Module], not torch.nn.Module.
    DEFAULT_TEXT_MODEL_POOLER_TYPE: type[torch.nn.Module] = get_text_model_pooler(
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_STR
    )
    # Keyword arguments used to instantiate the default pooler.
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = (
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_KWARGS
    )

    def __init__(
        self, config: Optional[CustomCLIPConfig.__base__] = None, *args, **kwargs
    ):
        """Initialize the dual encoder, then swap in the configured text pooler.

        Args:
            config: A base (``VisionTextDualEncoderConfig``-compatible) config,
                or ``None`` to let the parent derive one from the sub-models
                passed through ``*args``/``**kwargs``.
        """
        # Upgrade a plain base config to CustomCLIPConfig; leave None as-is so
        # the parent __init__ can build a config from the sub-models instead.
        config = config if config is None else CustomCLIPConfig.from_base(config)
        super().__init__(config, *args, **kwargs)
        # With no config, fall back to the class-level default pooler type and
        # kwargs; otherwise build the pooler the config describes.
        self.text_model.pooler = (
            self.DEFAULT_TEXT_MODEL_POOLER_TYPE(
                **self.DEFAULT_TEXT_MODEL_POOLER_KWARGS
            )
            if config is None
            else get_text_model_pooler(config.text_model_pooler)(
                **config.text_model_pooler_kwargs
            )
        )
# Register the custom config -> model mapping so transformers' AutoModel
# factory resolves CustomCLIPConfig to CustomCLIPModel.
AutoModel.register(CustomCLIPConfig, CustomCLIPModel)