pere's picture
Commit from model create scripts
e8c8bae
from __gin__ import dynamic_registration
import __main__ as train_script
import seqio
from t5.data import mixtures
from t5x import adafactor
from t5x.examples.t5 import network
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import trainer
from t5x import utils
import tasks
# Macros:
# ==============================================================================
BATCH_SIZE = 128
DROPOUT_RATE = 0.1
EVAL_PERIOD = 1000
EVAL_STEPS = 20
EVALUATOR_NUM_EXAMPLES = None
EVALUATOR_USE_MEMORY_CACHE = True
INITIAL_CHECKPOINT_PATH = \
'gs://nb-t5x-us-central2/norwegian_NCC_plus_English_pluss200k_balanced_bokmaal_nynorsk_t5x_base/checkpoint_1700000'
JSON_WRITE_N_RESULTS = None
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None
MIXTURE_OR_TASK_MODULE = None
MIXTURE_OR_TASK_NAME = 'translate_long'
MODEL = @models.EncoderDecoderModel()
MODEL_DIR = 'gs://nb-t5x-us-central2/finetuned/nynorsk_balanced_base_long_v1'
OPTIMIZER = @adafactor.Adafactor()
RANDOM_SEED = 0
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
TRAIN_STEPS = 1705000
USE_CACHED_TASKS = False
USE_HARDWARE_RNG = False
VOCABULARY = @seqio.SentencePieceVocabulary()
Z_LOSS = 0.0001
# Parameters for adafactor.Adafactor:
# ==============================================================================
adafactor.Adafactor.decay_rate = 0.8
adafactor.Adafactor.logical_factor_rules = \
@adafactor.standard_logical_factor_rules()
adafactor.Adafactor.step_offset = 0
# Parameters for utils.CheckpointConfig:
# ==============================================================================
utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
# Parameters for utils.create_learning_rate_scheduler:
# ==============================================================================
utils.create_learning_rate_scheduler.base_learning_rate = 0.001
utils.create_learning_rate_scheduler.factors = 'constant'
utils.create_learning_rate_scheduler.warmup_steps = 1000
# Parameters for infer_eval/utils.DatasetConfig:
# ==============================================================================
infer_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
infer_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
infer_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
infer_eval/utils.DatasetConfig.pack = False
infer_eval/utils.DatasetConfig.seed = 42
infer_eval/utils.DatasetConfig.shuffle = False
infer_eval/utils.DatasetConfig.split = 'validation'
infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
infer_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
# Parameters for train/utils.DatasetConfig:
# ==============================================================================
train/utils.DatasetConfig.batch_size = %BATCH_SIZE
train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
train/utils.DatasetConfig.pack = True
train/utils.DatasetConfig.seed = None
train/utils.DatasetConfig.shuffle = True
train/utils.DatasetConfig.split = 'train'
train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
# Parameters for train_eval/utils.DatasetConfig:
# ==============================================================================
train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
train_eval/utils.DatasetConfig.pack = True
train_eval/utils.DatasetConfig.seed = 42
train_eval/utils.DatasetConfig.shuffle = False
train_eval/utils.DatasetConfig.split = 'validation'
train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
# Parameters for models.EncoderDecoderModel:
# ==============================================================================
models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
models.EncoderDecoderModel.module = @network.Transformer()
models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
models.EncoderDecoderModel.z_loss = %Z_LOSS
# Parameters for seqio.Evaluator:
# ==============================================================================
seqio.Evaluator.logger_cls = \
[@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
seqio.Evaluator.num_examples = %EVALUATOR_NUM_EXAMPLES
seqio.Evaluator.use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
# Parameters for seqio.JSONLogger:
# ==============================================================================
seqio.JSONLogger.write_n_results = %JSON_WRITE_N_RESULTS
# Parameters for partitioning.PjitPartitioner:
# ==============================================================================
partitioning.PjitPartitioner.logical_axis_rules = \
@partitioning.standard_logical_axis_rules()
partitioning.PjitPartitioner.model_parallel_submesh = None
partitioning.PjitPartitioner.num_partitions = 1
# Parameters for utils.RestoreCheckpointConfig:
# ==============================================================================
utils.RestoreCheckpointConfig.dtype = 'float32'
utils.RestoreCheckpointConfig.mode = 'specific'
utils.RestoreCheckpointConfig.path = %INITIAL_CHECKPOINT_PATH
# Parameters for utils.SaveCheckpointConfig:
# ==============================================================================
utils.SaveCheckpointConfig.dtype = 'float32'
utils.SaveCheckpointConfig.keep = None
utils.SaveCheckpointConfig.period = 1000
utils.SaveCheckpointConfig.save_dataset = False
# Parameters for seqio.SentencePieceVocabulary:
# ==============================================================================
seqio.SentencePieceVocabulary.sentencepiece_model_file = \
'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model'
# Parameters for network.T5Config:
# ==============================================================================
network.T5Config.dropout_rate = %DROPOUT_RATE
network.T5Config.dtype = 'bfloat16'
network.T5Config.emb_dim = 768
network.T5Config.head_dim = 64
network.T5Config.logits_via_embedding = False
network.T5Config.mlp_activations = ('gelu', 'linear')
network.T5Config.mlp_dim = 2048
network.T5Config.num_decoder_layers = 12
network.T5Config.num_encoder_layers = 12
network.T5Config.num_heads = 12
network.T5Config.vocab_size = 250112
# Parameters for train_script.train:
# ==============================================================================
train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
train_script.train.eval_period = %EVAL_PERIOD
train_script.train.eval_steps = %EVAL_STEPS
train_script.train.infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
train_script.train.inference_evaluator_cls = @seqio.Evaluator
train_script.train.model = %MODEL
train_script.train.model_dir = %MODEL_DIR
train_script.train.partitioner = @partitioning.PjitPartitioner()
train_script.train.random_seed = %RANDOM_SEED
train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
train_script.train.total_steps = %TRAIN_STEPS
train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
train_script.train.trainer_cls = @trainer.Trainer
train_script.train.use_hardware_rng = %USE_HARDWARE_RNG
# Parameters for trainer.Trainer:
# ==============================================================================
trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
trainer.Trainer.num_microbatches = None
# Parameters for network.Transformer:
# ==============================================================================
network.Transformer.config = @network.T5Config()