|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from .configuration_bert import FlexBertConfig |
|
from .activation import get_act_fn |
|
from .normalization import get_norm_layer |
|
from .initialization import ModuleType, init_weights |
|
|
|
|
|
class BertResidualGLU(nn.Module):
    """Applies the FFN at the end of each Mosaic BERT layer.

    Compared to the default BERT architecture, this block replaces
    :class:`~transformers.models.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has similar
    functionality, but introduces Gated Linear Units.

    Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep the parameter count
    consistent with that of a standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For
    example, a Mosaic BERT constructed with `config.intermediate_size=2048` has the same parameter footprint
    as its Hugging Face BERT counterpart constructed with `config.intermediate_size=3072`.
    However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the
    increased parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT
    built from the same `config`.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.config = config
        # A single projection produces both GLU halves: [gated | non_gated], each intermediate_size wide.
        self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
        self.act = get_act_fn(config.hidden_act)
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = get_norm_layer(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The hidden states from the attention layer, with the feature
                dimension last, e.g. unpadded ``[nnz, dim]`` or padded ``[batch, seq, dim]``.

        Returns:
            torch.Tensor: Same shape as ``hidden_states``; the GLU output added to the input and then
            passed through the normalization layer (post-norm residual).
        """
        residual_connection = hidden_states

        hidden_states = self.gated_layers(hidden_states)
        # Split along the last (feature) axis so any number of leading dims is supported;
        # slicing dim 1 would cut the sequence axis for 3-D inputs.
        split_at = self.config.intermediate_size
        gated = hidden_states[..., :split_at]
        non_gated = hidden_states[..., split_at:]
        hidden_states = self.act(gated) * non_gated
        hidden_states = self.dropout(hidden_states)

        hidden_states = self.wo(hidden_states)

        # Residual connection followed by normalization (post-norm).
        hidden_states = self.layernorm(hidden_states + residual_connection)
        return hidden_states
|
|
|
|
|
class FlexBertMLPBase(nn.Module):
    """Common parent of the FlexBERT feed-forward blocks, used mainly for type hints.

    It only stores the config and the layer index. Both ``_init_weights`` and ``forward``
    raise ``NotImplementedError`` here and are meant to be overridden by subclasses.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        # Kept for subclasses, e.g. to scale weight init by depth.
        self.config, self.layer_id = config, layer_id

    def _init_weights(self, reset_params: bool = False):
        message = "This is a base class and should not be used directly."
        raise NotImplementedError(message)

    def reset_parameters(self):
        """Re-initialize weights by delegating to ``_init_weights(reset_params=True)``."""
        self._init_weights(reset_params=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        message = "This is a base class and should not be used directly."
        raise NotImplementedError(message)
|
|
|
|
|
class FlexBertMLP(FlexBertMLPBase):
    """Plain (non-gated) MLP applied at the end of each FlexBERT layer.

    Computes ``Wo(dropout(act(Wi(x))))``. This covers the two dense projections that
    :class:`~transformers.models.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.models.bert.modeling_bert.BertOutput` perform in standard BERT;
    no residual connection or normalization is applied inside this module.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        # Submodules are registered in the order Wi, act, drop, Wo.
        self.Wi = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_in_bias)
        self.act = get_act_fn(config.hidden_act)
        if config.mlp_dropout_prob > 0.0:
            self.drop = nn.Dropout(config.mlp_dropout_prob)
        else:
            # Skip the dropout op entirely when the probability is zero.
            self.drop = nn.Identity()
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # `reset_params` is accepted for interface compatibility but not used here.
        # The input projection is initialized without a layer index; the output
        # projection receives this block's layer_id.
        specs = (
            (self.Wi, self.config.hidden_size, None, ModuleType.in_module),
            (self.Wo, self.config.intermediate_size, self.layer_id, ModuleType.out_module),
        )
        for module, dim, idx, kind in specs:
            init_weights(
                self.config,
                module,
                layer_dim=dim,
                layer_id=idx,
                type_of_module=kind,
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer [nnz, dim].
        """
        projected = self.Wi(hidden_states)
        activated = self.drop(self.act(projected))
        return self.Wo(activated)
|
|
|
|
|
class FlexBertGLU(FlexBertMLPBase):
    """Gated Linear Unit (GLU) MLP applied at the end of each FlexBERT layer.

    ``Wi`` projects to twice the intermediate width. The output is split in half along the
    last axis: the first half goes through the activation and is multiplied element-wise by
    the second half (the gate). The product then goes through dropout and ``Wo``.

    Compared to standard BERT, this covers the dense projections of
    :class:`~transformers.models.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.models.bert.modeling_bert.BertOutput`. No residual connection or
    normalization is applied inside this module.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        # Cast once and reuse so Wi's per-half width and Wo's input width always agree
        # (previously only Wi was cast, so a non-int intermediate_size broke Wo).
        intermediate_size = int(config.intermediate_size)
        self.Wi = nn.Linear(config.hidden_size, intermediate_size * 2, bias=config.mlp_in_bias)
        self.act = get_act_fn(config.hidden_act)
        self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
        self.Wo = nn.Linear(intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # `reset_params` is accepted for interface compatibility but not used here.
        init_weights(
            self.config,
            self.Wi,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.intermediate_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the GLU MLP.

        Args:
            hidden_states (torch.Tensor): Hidden states with the feature dimension last.

        Returns:
            torch.Tensor: Tensor with the same leading dims and ``hidden_size`` features.
        """
        # `activated` avoids shadowing the builtin `input`.
        activated, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.drop(self.act(activated) * gate))
|
|
|
|
|
class FlexBertParallelGLU(FlexBertMLPBase):
    """GLU MLP whose input projection is computed elsewhere, in parallel with attention.

    This module owns no input projection. ``forward`` receives the already-projected tensor
    ``intermediate_ff``, splits it in half along the last axis, and applies ``act(first) * second``.
    It then applies dropout and the output projection ``Wo``.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.act = get_act_fn(config.hidden_act)
        p_drop = config.mlp_dropout_prob
        if p_drop > 0.0:
            self.drop = nn.Dropout(p_drop)
        else:
            self.drop = nn.Identity()
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # Only Wo lives here; the input projection is initialized by its owner.
        # `reset_params` is accepted for interface compatibility but not used.
        cfg = self.config
        init_weights(
            cfg,
            self.Wo,
            layer_dim=cfg.intermediate_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor:
        # Assumes the last axis holds [activation half | gate half] — TODO confirm with the producer.
        halves = intermediate_ff.chunk(2, dim=-1)
        activated = self.act(halves[0])
        gated = activated * halves[1]
        return self.Wo(self.drop(gated))
|
|
|
|
|
# Registry mapping the `mlp_layer` / `initial_mlp_layer` config strings to MLP block classes.
MLP2CLS = dict(
    mlp=FlexBertMLP,
    glu=FlexBertGLU,
    parallel_glu=FlexBertParallelGLU,
)
|
|
|
|
|
def get_mlp_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertMLPBase:
    """Instantiate the MLP block to use for a given layer.

    If ``config.initial_mlp_layer`` is set (not None) and ``layer_id`` is an index below
    ``config.num_initial_layers``, the class named by ``initial_mlp_layer`` is used.
    Otherwise the class named by ``config.mlp_layer`` is used. When ``layer_id`` is None
    (the default), the layer cannot be placed among the initial layers, so ``mlp_layer`` is used.

    Args:
        config: Model configuration.
        layer_id: Index of the layer being built, or None.

    Returns:
        An instance of the class registered in ``MLP2CLS`` under the selected name.

    Raises:
        ValueError: If the selected name is not a key of ``MLP2CLS``.
    """
    initial_mlp_layer = getattr(config, "initial_mlp_layer", None)
    # Check the cheap/None cases first so that `layer_id=None` no longer raises TypeError
    # and `num_initial_layers` is only read when an initial MLP type is configured.
    use_initial = (
        initial_mlp_layer is not None
        and layer_id is not None
        and layer_id < config.num_initial_layers
    )
    mlp_layer = initial_mlp_layer if use_initial else config.mlp_layer

    # Only the registry lookup is guarded: KeyErrors raised inside a layer's constructor
    # must not be misreported as an invalid layer type.
    try:
        mlp_cls = MLP2CLS[mlp_layer]
    except KeyError as e:
        if use_initial:
            raise ValueError(
                f"Invalid MLP layer type: {config.initial_mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}"
            ) from e
        raise ValueError(f"Invalid MLP layer type: {config.mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}") from e
    return mlp_cls(config, layer_id=layer_id)
|
|