import inspect

import torch
import torch.nn as nn
from torch.nn import init

from .configuration_bert import FlexBertConfig

# flash-attn's fused Triton layer norm kernels are optional: fall back to the
# pure-PyTorch implementations below when the package is not installed.
try:
    from flash_attn.ops.triton.layer_norm import RMSNorm as TritonRMSNorm
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    TritonRMSNorm = None
    layer_norm_fn = None
|
|
|
|
|
class RMSNorm(nn.Module):
    """Llama2 RMSNorm implementation"""

    def __init__(self, dim: int, eps: float = 1e-5):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def reset_parameters(self):
        init.ones_(self.weight)
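# Usage sketch for the pure-PyTorch RMSNorm above (the hidden size of 768 is an
# arbitrary illustrative choice, not something the config mandates):
#
#     norm = RMSNorm(dim=768)
#     x = torch.randn(2, 128, 768)   # (batch, seq_len, hidden)
#     y = norm(x)                    # same shape; each hidden vector rescaled to unit RMS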
|
|
|
|
|
if layer_norm_fn is not None:

    class TritonLayerNorm(nn.LayerNorm):
        """nn.LayerNorm drop-in that dispatches to flash-attn's fused Triton kernel."""

        def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
            # When prenorm=True, layer_norm_fn returns both the normalized output and
            # the updated residual stream.
            return layer_norm_fn(
                x,
                self.weight,
                self.bias,
                residual=residual,
                eps=self.eps,
                prenorm=prenorm,
                residual_in_fp32=residual_in_fp32,
            )
else:
    TritonLayerNorm = None
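# Fused-path sketch (assumes flash-attn is installed, so TritonLayerNorm is defined;
# shapes and the hidden size of 768 are illustrative):
#
#     ln = TritonLayerNorm(768)
#     y = ln(x)                                              # standard usage, matches nn.LayerNorm
#     y, residual = ln(x, residual=residual, prenorm=True)   # pre-norm: also returns the residual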
|
|
|
NORM2CLS = {
    "layernorm": nn.LayerNorm,
    "triton_layernorm": TritonLayerNorm if TritonLayerNorm is not None else nn.LayerNorm,
    "rmsnorm": RMSNorm,
    "triton_rmsnorm": TritonRMSNorm if TritonRMSNorm is not None else RMSNorm,
}
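# Lookup sketch: the "triton_" entries silently fall back to the pure-PyTorch classes
# when flash-attn is unavailable (768 is an arbitrary hidden size):
#
#     norm_cls = NORM2CLS["triton_rmsnorm"]   # TritonRMSNorm, or RMSNorm as the fallback
#     norm = norm_cls(768)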
|
|
|
|
|
def get_norm_layer(config: FlexBertConfig, compiled_norm: bool = False) -> nn.Module:
    """Instantiate the normalization layer named by config.normalization.

    If compiled_norm is True, a "triton_" prefixed norm is swapped for its
    pure-PyTorch equivalent. Extra constructor arguments are read from
    config.norm_kwargs (when present) and filtered to those accepted by the
    chosen norm class.
    """
    try:
        if compiled_norm:
            # Use the non-Triton norms when compiling the model
            if config.normalization.startswith("triton_"):
                norm = config.normalization.replace("triton_", "")
            else:
                norm = config.normalization
        else:
            norm = config.normalization
        signature = inspect.signature(NORM2CLS[norm])
        if hasattr(config, "norm_kwargs"):
            norm_kwargs = {k: v for k, v in config.norm_kwargs.items() if k in signature.parameters}
        else:
            norm_kwargs = {}
        return NORM2CLS[norm](config.hidden_size, **norm_kwargs)
    except KeyError:
        raise ValueError(f"Invalid normalization layer type: {config.normalization}, must be one of {list(NORM2CLS.keys())}.")
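# Usage sketch (assumes a FlexBertConfig exposing the hidden_size and normalization
# fields that get_norm_layer reads; the values below are illustrative):
#
#     config = FlexBertConfig(hidden_size=768, normalization="triton_rmsnorm")
#     norm = get_norm_layer(config)                           # fused Triton RMSNorm if available
#     compiled = get_norm_layer(config, compiled_norm=True)   # pure-PyTorch RMSNorm (e.g. for torch.compile)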
|
|