RITA_s / rita_configuration.py
DanielHesslow's picture
add model
7f52a15
raw
history blame
911 Bytes
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger, created through the `transformers.utils.logging` helper.
logger = logging.get_logger(__name__)
class RITAConfig(PretrainedConfig):
    """Configuration object holding the hyperparameters of a RITA model.

    All constructor arguments are stored as attributes of the same name,
    except ``ff_ratio``, which is only used to derive ``d_feedforward``
    (``d_model * ff_ratio``) and is not stored itself.

    Args:
        vocab_size: Stored as ``self.vocab_size``.
        d_model: Stored as ``self.d_model``; also used to derive
            ``self.d_feedforward``.
        num_layers: Stored as ``self.num_layers``.
        max_seq_len: Stored as ``self.max_seq_len``.
        num_heads: Stored as ``self.num_heads``.
        dropout: Stored as ``self.dropout``.
        ff_ratio: Multiplier applied to ``d_model`` to compute
            ``self.d_feedforward``.
        bos_token_id: Forwarded to ``PretrainedConfig.__init__`` and stored
            as ``self.bos_token_id``.
        eos_token_id: Forwarded to ``PretrainedConfig.__init__`` and stored
            as ``self.eos_token_id``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # NOTE(review): the identifier is "codegen" although the class is named
    # RITAConfig -- presumably left over from a template; confirm before
    # changing, since saved configs may already carry this value.
    model_type = "codegen"

    def __init__(
        self,
        vocab_size=128,
        d_model=768,
        num_layers=12,
        max_seq_len=1024,
        num_heads=12,
        dropout=0.,
        ff_ratio=4,
        # NOTE(review): both defaults are 50256, which is larger than the
        # default vocab_size of 128 -- verify against the tokenizer.
        bos_token_id=50256,  # TODO
        eos_token_id=50256,  # TODO
        **kwargs,
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        # Feed-forward width is derived from the model width, not configured
        # independently.
        self.d_feedforward = d_model * ff_ratio
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        # No trailing comma here: a stray comma would turn the value into a
        # one-element tuple instead of the plain token id.
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id