File size: 1,422 Bytes
02a3b66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""
Subclasses VisionTextDualEncoderModel to customize the text model's pooler.
"""

from typing import Callable, Optional

import torch
from transformers import AutoModel, VisionTextDualEncoderModel

from .configuration_custom_clip import CustomCLIPConfig, get_text_model_pooler


# @add_start_docstrings(CUSTOM_CLIP_START_DOCSTRING)
class CustomCLIPModel(VisionTextDualEncoderModel):
    """``VisionTextDualEncoderModel`` whose text model gets a configurable pooler.

    After the parent constructor has run, ``self.text_model.pooler`` is
    replaced with a module built by the factory that ``get_text_model_pooler``
    returns. The factory and its kwargs come from the ``CustomCLIPConfig``
    (``text_model_pooler`` / ``text_model_pooler_kwargs``). When no config is
    given, they come from the class-level defaults below.
    """

    config_class = CustomCLIPConfig

    # Factory for the default pooler name. ``__init__`` *calls* it with keyword
    # arguments to build the pooler, so it is annotated as a callable returning
    # a module (presumably a ``torch.nn.Module`` subclass), not as a module
    # instance.
    DEFAULT_TEXT_MODEL_POOLER_TYPE: Callable[..., torch.nn.Module] = (
        get_text_model_pooler(CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_STR)
    )
    # Keyword arguments for the default pooler. This is the same dict object
    # as ``CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_KWARGS``. This class only
    # unpacks it and never mutates it.
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = (
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_KWARGS
    )

    def __init__(
        self, config: Optional[CustomCLIPConfig.__base__] = None, *args, **kwargs
    ):
        """Build the dual encoder, then install the configured text pooler.

        Args:
            config: A config of ``CustomCLIPConfig``'s base class (or a
                ``CustomCLIPConfig``). It is converted with
                ``CustomCLIPConfig.from_base`` before use. If ``None``, it is
                passed to the parent constructor as-is and the class-level
                default pooler is installed.
            *args: Forwarded unchanged to ``VisionTextDualEncoderModel.__init__``.
            **kwargs: Forwarded unchanged to ``VisionTextDualEncoderModel.__init__``.
        """
        # Normalise a base-class config so the pooler fields read below exist.
        # ``None`` is passed through untouched.
        if config is not None:
            config = CustomCLIPConfig.from_base(config)

        # NOTE(review): an earlier comment here remarked that `super` seemed
        # unnecessary, "possibly due to implementation of
        # CustomCLIPConfig.__init__". The intent is unclear — confirm.
        super().__init__(config, *args, **kwargs)

        if config is None:
            # No custom config was given, so there are no pooler settings to
            # read. Fall back to the class defaults, looked up via ``self`` so
            # a subclass can override them.
            # NOTE(review): in this branch ``self.config`` is whatever the
            # parent built, so the chosen pooler is presumably not recorded in
            # it. Verify that save/load round-trips keep the pooler.
            pooler_factory = self.DEFAULT_TEXT_MODEL_POOLER_TYPE
            pooler_kwargs = self.DEFAULT_TEXT_MODEL_POOLER_KWARGS
        else:
            pooler_factory = get_text_model_pooler(config.text_model_pooler)
            pooler_kwargs = config.text_model_pooler_kwargs

        # Overwrites whatever ``pooler`` attribute the text model was built with.
        self.text_model.pooler = pooler_factory(**pooler_kwargs)


# Register the pair so that ``AutoModel`` maps ``CustomCLIPConfig`` to ``CustomCLIPModel``.
AutoModel.register(config_class=CustomCLIPConfig, model_class=CustomCLIPModel)