import torch
import torch.nn as nn
from typing import Optional

from .configuration_bert import FlexBertConfig
from .normalization import get_norm_layer
from .initialization import ModuleType, init_weights


class BertAlibiEmbeddings(nn.Module):
    """Construct the embeddings for words, ignoring position.

    There are no position embeddings: position information comes from the ALiBi
    attention biases, while word and token type embeddings are kept.

    This module is modeled after the Hugging Face BERT's
    :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is
    modified as part of Mosaic BERT's ALiBi implementation. The key change is
    that position embeddings are removed. Position information instead comes
    from attention biases that scale linearly with the position distance
    between query and key tokens.

    This module ignores the `position_ids` input to the `forward` method.
    A short usage sketch follows the class definition.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        if getattr(config, "token_type_embeddings", True):
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
            self.use_token_type_embeddings = True
        else:
            self.use_token_type_embeddings = False

        self.LayerNorm = get_norm_layer(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if self.use_token_type_embeddings:
            self.register_buffer(
                "token_type_ids", torch.zeros(config.max_position_embeddings, dtype=torch.long), persistent=False
            )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError("Must specify exactly one of input_ids or inputs_embeds!")
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            assert inputs_embeds is not None
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        # position_ids is accepted for interface compatibility but ignored:
        # with ALiBi, position information is injected through attention biases.

        if self.use_token_type_embeddings and token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # The registered buffer is 1-D, so slice it to the sequence
                # length and broadcast it across the batch dimension.
                buffered_token_type_ids = self.token_type_ids[:seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.use_token_type_embeddings:
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = inputs_embeds + token_type_embeddings
        else:
            embeddings = inputs_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
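
# Usage sketch for BertAlibiEmbeddings (illustrative only). It assumes a config
# object exposing the attributes read above (vocab_size, hidden_size,
# pad_token_id, type_vocab_size, max_position_embeddings, hidden_dropout_prob,
# plus whatever `get_norm_layer` expects); the concrete config class is not
# prescribed by this module.
#
#     embeddings = BertAlibiEmbeddings(config)
#     input_ids = torch.randint(0, config.vocab_size, (2, 128))
#     hidden_states = embeddings(input_ids=input_ids)
#     assert hidden_states.shape == (2, 128, config.hidden_size)
#
# Any `position_ids` passed to `forward` are ignored; ALiBi adds position
# information later, inside the attention layers.

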
class FlexBertEmbeddingsBase(nn.Module):
    """A FlexBERT embeddings base class for type hints."""

    def __init__(self, config: FlexBertConfig):
        super().__init__()
        self.config = config

    def _init_weights(self, reset_params: bool = False):
        raise NotImplementedError("This is a base class and should not be used directly.")

    def reset_parameters(self):
        self._init_weights(reset_params=True)

    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        raise NotImplementedError("This is a base class and should not be used directly.")


class FlexBertAbsoluteEmbeddings(FlexBertEmbeddingsBase):
    """Construct the embeddings with learned absolute position embeddings."""

    def __init__(self, config: FlexBertConfig):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.norm = get_norm_layer(config) if config.embed_norm else nn.Identity()
        self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def _init_weights(self, reset_params: bool = False):
        init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)
        init_weights(self.config, self.position_embeddings, type_of_module=ModuleType.emb)

        if reset_params:
            if self.config.embed_norm:
                self.norm.reset_parameters()

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        if position_ids is None:
            position_ids = self.position_ids[:, 0 : input_ids.shape[1]]

        embeddings = self.tok_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = self.norm(embeddings + position_embeddings)
        return self.drop(embeddings)
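
# Note on the default position_ids (a reading of the code above, not extra
# behaviour): for `input_ids` of shape (batch, seq_len), the registered
# `position_ids` buffer is sliced to shape (1, seq_len), e.g.
# tensor([[0, 1, 2, ..., seq_len - 1]]), and the resulting
# (1, seq_len, hidden_size) position embeddings broadcast across the batch
# when added to the token embeddings, so every sequence receives the same
# learned absolute position embeddings before the norm and dropout are applied.

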
class FlexBertCompiledSansPositionEmbeddings(FlexBertEmbeddingsBase):
    """Construct the embeddings from token embeddings without any positional embeddings."""

    def __init__(self, config: FlexBertConfig):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        self.norm = get_norm_layer(config, compiled_norm=config.compile_model) if config.embed_norm else nn.Identity()
        self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()

    def _init_weights(self, reset_params: bool = False):
        init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)

        if reset_params:
            if self.config.embed_norm:
                self.norm.reset_parameters()

    @torch.compile(dynamic=True)
    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        return self.drop(self.norm(self.tok_embeddings(input_ids)))
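
# Design note (an interpretation of the decorator above, not additional
# behaviour): only the embedding lookup -> norm -> dropout chain is wrapped in
# `torch.compile`. `dynamic=True` asks the compiler to treat the batch and
# sequence dimensions as dynamic, which matters when input shapes vary between
# steps (e.g. unpadded, variable-length batches), since it avoids recompiling
# the graph for every new shape.

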
class FlexBertSansPositionEmbeddings(FlexBertEmbeddingsBase):
    """Construct the embeddings from token embeddings without any positional embeddings."""

    def __init__(self, config: FlexBertConfig):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        self.norm = get_norm_layer(config) if config.embed_norm else nn.Identity()
        self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()

    def _init_weights(self, reset_params: bool = False):
        init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)

        if reset_params:
            if self.config.embed_norm:
                self.norm.reset_parameters()

    def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        return self.drop(self.norm(self.tok_embeddings(input_ids)))


EBB2CLS = {
    "absolute_pos": FlexBertAbsoluteEmbeddings,
    "sans_pos": FlexBertSansPositionEmbeddings,
}


def get_embedding_layer(config: FlexBertConfig) -> FlexBertEmbeddingsBase:
    try:
        if config.compile_model and config.embedding_layer == "sans_pos":
            return FlexBertCompiledSansPositionEmbeddings(config)
        elif config.compile_model:
            raise ValueError(f"{config.compile_model=} is only supported with the sans_pos embedding layer.")
        return EBB2CLS[config.embedding_layer](config)
    except KeyError:
        raise ValueError(
            f"Invalid embedding layer type: {config.embedding_layer=}, must be one of {list(EBB2CLS.keys())}."
        )