from transformers import T5Config

# String identifiers for the supported position-encoding schemes.
POSITION_ENCODING_REL_T5_BIAS = "t5_relative_bias"
POSITION_ENCODING_REL_TRANSFORMER_XL = "transformer_xl_relative_encoding"
POSITION_ENCODING_ROTARY = "rotary"
POSITION_ENCODING_ROTARY_RERUN = "rotary_rerun"
POSITION_ENCODING_ROTARY_NEW = "new_rotary"
POSITION_ENCODING_ABS_LEARNED = "abs_learned"
POSITION_ENCODING_ABS_SINUSOID = "abs_sinusoid"
POSITION_ENCODING_ALiBi = "alibi"
POSITION_ENCODING_ALiBi_LEARNED = "alibi_learned"
POSITION_ENCODING_NONE = "none"
POSITION_ENCODING_NONE_WINDOW = "none_window"

class CustomT5Config(T5Config):
    """T5Config for a custom decoder-only T5 whose position-encoding
    scheme is selected via ``position_encoding_type``."""

    model_type = "custom_decoder_only_t5"

    def __init__(
        self,
        position_encoding_type=POSITION_ENCODING_REL_T5_BIAS,
        **kwargs,
    ):
        # Note: POSITION_ENCODING_ROTARY_RERUN is defined above but is not
        # among the values accepted here.
        if position_encoding_type not in [
            POSITION_ENCODING_ALiBi,
            POSITION_ENCODING_ALiBi_LEARNED,
            POSITION_ENCODING_ABS_LEARNED,
            POSITION_ENCODING_ABS_SINUSOID,
            POSITION_ENCODING_REL_T5_BIAS,
            POSITION_ENCODING_REL_TRANSFORMER_XL,
            POSITION_ENCODING_ROTARY,
            POSITION_ENCODING_ROTARY_NEW,
            POSITION_ENCODING_NONE,
            POSITION_ENCODING_NONE_WINDOW,
        ]:
            raise ValueError(
                f"Invalid position_encoding_type: {position_encoding_type}"
            )
        self.position_encoding_type = position_encoding_type
        super().__init__(**kwargs)
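
# Minimal usage sketch. The hyperparameter values below (d_model, num_layers,
# num_heads) are illustrative assumptions, not defaults taken from this
# module; any standard T5Config keyword argument is forwarded to the parent
# constructor through **kwargs.
if __name__ == "__main__":
    config = CustomT5Config(
        position_encoding_type=POSITION_ENCODING_ROTARY,
        d_model=512,
        num_layers=6,
        num_heads=8,
    )
    print(config.model_type)              # "custom_decoder_only_t5"
    print(config.position_encoding_type)  # "rotary"

    # An unsupported scheme fails fast at construction time:
    try:
        CustomT5Config(position_encoding_type="bogus")
    except ValueError as err:
        print(err)  # Invalid position_encoding_type: bogus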