|
|
|
from typing import Literal |
|
from transformers import AutoConfig |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.models.auto import CONFIG_MAPPING |
|
from transformers.models.mistral import MistralConfig |
|
|
|
# ``model_type`` identifiers for the configuration classes defined below; the
# same strings are passed to ``AutoConfig.register`` at the bottom of this
# module.
NVEMBED_TYPE = "nvembed"  # NVEmbedConfig
LATENT_ATTENTION_TYPE = "latent_attention"  # LatentAttentionConfig
BIDIR_MISTRAL_TYPE = "bidir_mistral"  # BidirectionalMistralConfig
|
|
|
class NVEmbedConfig(PretrainedConfig):
    """Configuration for the top-level NV-Embed model.

    Bundles a latent-attention sub-config, an optional text backbone
    sub-config, and a handful of padding / EOS / instruction-masking options.
    """

    model_type = NVEMBED_TYPE
    is_composition = False

    def __init__(
        self,
        hidden_size=4096,
        latent_attention_config=None,
        text_config=None,
        padding_side: Literal["right", "left"] = "right",
        add_pad_token: bool = True,
        is_mask_instruction: bool = True,
        add_eos: bool = True,
        mask_type: str = "b",
        **kwargs,
    ):
        """Build the config, turning dict sub-configs into config objects.

        Args:
            hidden_size: Stored as ``self.hidden_size``.
            latent_attention_config: A config object (stored as-is), a dict
                (instantiated through ``CONFIG_MAPPING`` using its
                ``"model_type"`` key, defaulting to ``"latent_attention"``),
                or ``None`` (a default ``"latent_attention"`` config is built).
            text_config: A config object (stored as-is), a dict (instantiated
                through ``CONFIG_MAPPING``; ``"model_type"`` defaults to
                ``"llama"``), or ``None`` (stored as ``None``).
            padding_side: Stored as ``self.padding_side``.
            add_pad_token: Stored as ``self.add_pad_token``.
            is_mask_instruction: Stored as ``self.is_mask_instruction``.
            add_eos: Stored as ``self.add_eos``.
            mask_type: Stored as ``self.mask_type``; accepted values are
                defined by the modeling code (TODO: document them here).
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Dict arguments are copied before the ``"model_type"`` default is
        filled in, so the caller's dicts are never mutated.
        """
        if isinstance(latent_attention_config, dict):
            # Work on a copy: filling in the default must not leak back into
            # the dict the caller passed in.
            latent_attention_config = dict(latent_attention_config)
            latent_attention_config.setdefault("model_type", LATENT_ATTENTION_TYPE)
            latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](
                **latent_attention_config
            )
        elif latent_attention_config is None:
            # Resolvable because this module registers LATENT_ATTENTION_TYPE
            # with AutoConfig at import time.
            latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()

        self.latent_attention_config = latent_attention_config

        if isinstance(text_config, dict):
            text_config = dict(text_config)
            text_config.setdefault("model_type", "llama")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        # A None text_config is intentionally left as None.

        self.hidden_size = hidden_size
        self.text_config = text_config
        self.padding_side = padding_side
        self.is_mask_instruction = is_mask_instruction
        self.add_pad_token = add_pad_token
        self.add_eos = add_eos
        self.mask_type = mask_type

        super().__init__(**kwargs)
|
|
|
|
|
class LatentAttentionConfig(PretrainedConfig):
    """Configuration for NV-Embed's latent-attention module.

    Only stores hyperparameters; how they are consumed is defined by the
    modeling code.
    """

    model_type = LATENT_ATTENTION_TYPE
    is_composition = False
    _name_or_path = "latent_attention"

    def __init__(
        self,
        num_latents_value: int = 512,
        num_cross_heads: int = 8,
        output_normalize: bool = True,
        hidden_dim: int = 4096,
        latent_dim: int = 4096,
        cross_dim_head: int = 4096,
        **kwargs,
    ):
        """Store the latent-attention hyperparameters.

        Args:
            num_latents_value: Stored as ``self.num_latents_value``.
            num_cross_heads: Stored as ``self.num_cross_heads``.
            output_normalize: Stored as ``self.output_normalize``.
            hidden_dim: Stored as ``self.hidden_dim``.
            latent_dim: Stored as ``self.latent_dim``.
            cross_dim_head: Stored as ``self.cross_dim_head``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__`` so the base
                class initializes its own attributes and any extra keys.
        """
        self.num_latents_value = num_latents_value
        self.num_cross_heads = num_cross_heads
        self.output_normalize = output_normalize
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.cross_dim_head = cross_dim_head

        # PretrainedConfig.__init__ sets ``_name_or_path`` from the
        # ``name_or_path`` kwarg (default ""), which would shadow the
        # class-level default above; pass that default through unless the
        # caller supplied one.
        kwargs.setdefault("name_or_path", self._name_or_path)
        super().__init__(**kwargs)
|
|
|
|
|
class BidirectionalMistralConfig(MistralConfig):
    """``MistralConfig`` registered under its own ``model_type``.

    Every field and default is inherited from ``MistralConfig``; this subclass
    only overrides the model type and the inference-time ignore list.
    """

    keys_to_ignore_at_inference = ["past_key_values"]
    model_type = BIDIR_MISTRAL_TYPE
|
|
|
# Make each custom model type resolvable through AutoConfig / CONFIG_MAPPING.
for _model_type, _config_cls in (
    (NVEMBED_TYPE, NVEmbedConfig),
    (LATENT_ATTENTION_TYPE, LatentAttentionConfig),
    (BIDIR_MISTRAL_TYPE, BidirectionalMistralConfig),
):
    AutoConfig.register(_model_type, _config_cls)

# Mark each config class for the auto-class mechanism
# (``register_for_auto_class`` with its default argument).
for _config_cls in (NVEmbedConfig, LatentAttentionConfig, BidirectionalMistralConfig):
    _config_cls.register_for_auto_class()

# Keep the module namespace free of the temporary loop names.
del _model_type, _config_cls
|
|