# Source: FlexBert / activation.py, commit ce9aa51 ("Using dict as input") by NohTow
# Copyright 2024 **AUTHORS_TODO**
# License: Apache-2.0
# Copyright 2020 The HuggingFace Team.
# License: Apache-2.0
from collections import OrderedDict
from typing import Union
import torch.nn as nn
from .configuration_bert import FlexBertConfig
class ClassInstantier(OrderedDict):
    """Ordered mapping whose item access returns a freshly built object.

    Each stored value is either a callable (typically a class), invoked with
    no arguments, or a ``(callable, kwargs)`` tuple, invoked with ``**kwargs``.
    Only ``[]`` access is overridden here.
    """

    def __getitem__(self, key):
        """Look up ``key`` and return the result of calling its stored factory.

        Raises:
            KeyError: If ``key`` is not present in the mapping.
        """
        entry = super().__getitem__(key)
        if isinstance(entry, tuple):
            # Stored as (factory, keyword-arguments).
            factory, options = entry
        else:
            factory, options = entry, {}
        return factory(**options)
# Maps an activation name to the ``torch.nn`` class implementing it.
# A value is either the class itself (built with no constructor arguments by
# ClassInstantier) or a ``(class, kwargs)`` tuple whose kwargs are forwarded
# to the constructor.
ACT2CLS = {
    "celu": nn.CELU,
    "elu": nn.ELU,
    "gelu": nn.GELU,
    # Same class as "gelu", constructed with approximate="tanh".
    "gelu_tanh": (nn.GELU, {"approximate": "tanh"}),
    "hardtanh": nn.Hardtanh,
    "hardsigmoid": nn.Hardsigmoid,
    "hardshrink": nn.Hardshrink,
    "hardswish": nn.Hardswish,
    "leaky_relu": nn.LeakyReLU,
    "logsigmoid": nn.LogSigmoid,
    "mish": nn.Mish,
    "prelu": nn.PReLU,
    "relu": nn.ReLU,
    "relu6": nn.ReLU6,
    "rrelu": nn.RReLU,
    "selu": nn.SELU,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    # NOTE(review): registered without a `dim` argument; PyTorch may warn about
    # an implicit dimension when this module is applied -- confirm intent.
    "softmin": nn.Softmin,
    "softplus": nn.Softplus,
    "softshrink": nn.Softshrink,
    "softsign": nn.Softsign,
    # "swish" maps to the same class as "silu".
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
    "tanhshrink": nn.Tanhshrink,
    # NOTE(review): nn.Threshold appears to require `threshold` and `value`
    # constructor arguments; with no kwargs registered here, building this
    # entry presumably raises TypeError -- confirm and supply kwargs or drop.
    "threshold": nn.Threshold,
}
# Registry built from ACT2CLS: indexing it (ACT2FN[name]) constructs and
# returns a new instance of the registered class on every lookup.
ACT2FN = ClassInstantier(ACT2CLS)
def get_act_fn(config: Union[FlexBertConfig, str]) -> nn.Module:
    """Build the activation module named by ``config``.

    Args:
        config: Either the activation name itself, or a config object whose
            ``hidden_act`` attribute holds the name.

    Returns:
        The object produced by ``ACT2FN[name]``; ``ACT2FN`` instantiates the
        registered class on each lookup, so every call yields a new module.

    Raises:
        ValueError: If the name is not a key of ``ACT2FN``. The original
            ``KeyError`` is attached as ``__cause__``.
    """
    is_name = isinstance(config, str)
    act_name = config if is_name else config.hidden_act
    try:
        # Only the registry lookup is guarded, so unrelated errors (for
        # example a missing ``hidden_act`` attribute) surface unchanged.
        return ACT2FN[act_name]
    except KeyError as err:
        # Mirror how the caller specified the name: bare for a string,
        # ``config.hidden_act='...'`` for a config object.
        described = config if is_name else f"config.hidden_act={act_name!r}"
        # list(ACT2FN) prints the plain key list instead of the
        # ``odict_keys([...])`` view repr.
        raise ValueError(
            f"Invalid activation function type: {described}, must be one of {list(ACT2FN)}."
        ) from err