from typing import Any, Dict, Optional, Union

from transformers import PretrainedConfig
from transformers.models.bert import BertConfig
from transformers.models.esm import EsmConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`]. It holds two
    sub-configurations: an [`EsmConfig`] for the protein side and a [`BertConfig`] for the text side.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`EsmConfig`], *optional*):
            Dictionary of configuration options (or an already-built [`EsmConfig`]) used to initialize
            [`EsmForProteinRepresentation`]. When omitted, an [`EsmConfig`] with default values is used.
        text_config (`dict` or [`BertConfig`], *optional*):
            Dictionary of configuration options (or an already-built [`BertConfig`]) used to initialize
            [`BertForPubMed`]. When omitted, a [`BertConfig`] with default values is used.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config: Optional[Union[Dict[str, Any], EsmConfig]] = None,
        text_config: Optional[Union[Dict[str, Any], BertConfig]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, EsmConfig):
            # Accept a config object as well as a dict; round-trip through to_dict() so this
            # composite config owns an independent sub-config rather than sharing the caller's.
            protein_config = protein_config.to_dict()

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `BertConfig` with default values.")
        elif isinstance(text_config, BertConfig):
            text_config = text_config.to_dict()

        self.protein_config = EsmConfig(**protein_config)
        self.text_config = BertConfig(**text_config)

    @classmethod
    def from_protein_text_configs(cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration and a text model
        configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein encoder.
            text_config ([`BertConfig`]):
                Configuration of the text encoder.
            kwargs (*optional*):
                Extra keyword arguments forwarded to the [`ProtSTConfig`] constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object
        """
        return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)