File size: 861 Bytes
7f52a15 ee739e3 7f52a15 ee739e3 7f52a15 ee739e3 89919e7 7f52a15 ee739e3 7f52a15 89919e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger obtained from transformers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
class RITAConfig(PretrainedConfig):
    """Configuration container for RITA models (``model_type = "rita"``).

    Stores the given hyperparameters as attributes and forwards
    ``eos_token_id`` plus any extra keyword arguments to
    ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Model (hidden) width.
        num_layers: Number of layers (presumably transformer blocks;
            confirm against the model code).
        max_seq_len: Maximum sequence length.
        num_heads: Number of attention heads.
        dropout: Dropout probability.
        ff_ratio: Multiplier applied to ``d_model`` to derive
            ``d_feedforward``. It is also stored on the config so the
            derived width survives a save/load round trip.
        eos_token_id: End-of-sequence token id, passed to
            ``PretrainedConfig``.
        initializer_range: Stored as ``initializer_range``. Presumably
            the std used for weight initialisation (TODO: confirm in the
            model code).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "rita"

    def __init__(
        self,
        vocab_size=26,
        d_model=768,
        num_layers=12,
        max_seq_len=1024,
        num_heads=12,
        dropout=0.0,
        ff_ratio=4,
        eos_token_id=2,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        # Keep the ratio itself on the config so it is serialized. Otherwise
        # a saved config reloads with the default ff_ratio, and the line
        # below recomputes (and overwrites) a different d_feedforward.
        self.ff_ratio = ff_ratio
        self.d_feedforward = d_model * ff_ratio
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.eos_token_id = eos_token_id
        # Honour the caller's value; this was previously hard-coded to 0.02,
        # which silently ignored the argument.
        self.initializer_range = initializer_range
|