from typing import Callable

import torch
from torch import Tensor, nn
from torch.nn import functional as F


def ensure_tuple(val: int | tuple[int, ...], n: int = 2) -> tuple[int, ...]:
    """Normalize an int or tuple argument to a tuple of length ``n``."""
    if isinstance(val, int):
        return (val,) * n
    elif len(val) != n:
        raise ValueError(f"Expected a tuple of {n} values, but got {len(val)}: {val}")
    return val
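
# Illustrative examples of the behavior above (not part of the module logic):
#   ensure_tuple(3)        -> (3, 3)
#   ensure_tuple((2, 4))   -> (2, 4)
#   ensure_tuple((2,), 3)  -> raises ValueError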


def use_fused_attn() -> bool:
    """Return True if this PyTorch build exposes fused scaled-dot-product attention."""
    if hasattr(F, "scaled_dot_product_attention"):
        return True
    return False
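
# Note: F.scaled_dot_product_attention is available in PyTorch >= 2.0, so this
# check effectively gates fused attention on the installed torch version.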


class QuickGELU(nn.Module):
    """
    Applies a GELU approximation that is fast but somewhat inaccurate.
    See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        return input * torch.sigmoid(1.702 * input)
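
# Illustrative usage (QuickGELU behaves like any other activation module):
#   act = QuickGELU()
#   out = act(torch.randn(2, 8))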


def get_act_layer(name: str) -> Callable[[], nn.Module]:
    """Look up an activation layer class by name."""
    match name:
        case "gelu":
            return nn.GELU
        case "quick_gelu":
            return QuickGELU
        case _:
            raise ValueError(f"Activation layer {name} not supported.")
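
# Illustrative usage sketch:
#   act_layer = get_act_layer("quick_gelu")  # returns the class, not an instance
#   act = act_layer()                        # instantiate like any nn.Module
#   out = act(torch.randn(2, 8))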