#!/usr/bin/env bash
#
# Launch EasyLM LLaMA training (built-in '3b' config, Lion optimizer) with
# W&B logging and checkpoints written to a GCS bucket.
#
# Usage:
#   <this-script> [EXTRA_FLAGS...]
#
# Extra arguments are appended after the defaults below, so a flag given on
# the command line (e.g. --load_checkpoint=...) can override a default without
# editing this file. NOTE(review): relies on the flag parser letting the last
# occurrence of a repeated flag win — confirm for EasyLM's flag handling.
#
# Environment:
#   WANDB_API_KEY  Weights & Biases API key used by --logger.online=True.
#                  A value already present in the environment is kept.

set -euo pipefail

# Keep a key supplied by the caller instead of overwriting it with an empty
# string; still export it (possibly empty) so child processes see it.
export WANDB_API_KEY="${WANDB_API_KEY:-}"

# XLA / libtpu tuning flags (windowed einsum threshold, all-gather CSE,
# async all-gather / collective offload, latency-hiding scheduler).
# NOTE(review): TPU_MEGACORE=MEGACORE_DENSE is an env-var style token, not a
# --flag; confirm libtpu honours it inside LIBTPU_INIT_ARGS or export it as
# its own variable.
export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'

train_args=(
  # --- Runtime / parallelism ---
  --jax_distributed.initialize_jax_distributed=True
  # NOTE(review): presumably EasyLM's (dp, fsdp, mp) mesh, with -1 meaning
  # "all remaining devices" — confirm against EasyLM docs.
  --mesh_dim='1,-1,1'
  --dtype='bf16'

  # --- Schedule (all values in training steps) ---
  --total_steps=1000000
  --eval_freq=50000
  --log_freq=1000
  --save_model_freq=2000
  --save_milestone_freq=50000

  # --- Model / resume state ---
  # Empty values presumably mean "no override / do not resume"; pass e.g.
  # --load_checkpoint=... on the command line to resume a run.
  --load_llama_config='3b'
  --update_llama_config=''
  --load_dataset_state=''
  --load_checkpoint=''
  --tokenizer.pretrained_model_name_or_path='./'

  # --- Optimizer: Lion, 2000 warmup steps, lr 3e-5 -> end_lr 3e-6 over
  # 1,000,000 decay steps ---
  --optimizer.type='lion'
  --optimizer.lion_optimizer.weight_decay=1.0
  --optimizer.lion_optimizer.lr=3e-5
  --optimizer.lion_optimizer.end_lr=3e-6
  --optimizer.lion_optimizer.lr_warmup_steps=2000
  --optimizer.lion_optimizer.lr_decay_steps=1000000
  --optimizer.lion_optimizer.bf16_momentum=True

  # --- Training data: on-disk Hugging Face dataset, 'train' split ---
  --train_dataset.type='huggingface'
  --train_dataset.text_processor.fields='text'
  --train_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_v2_filtered'
  --train_dataset.huggingface_dataset.split='train'
  --train_dataset.huggingface_dataset.seq_length=2048
  --train_dataset.huggingface_dataset.batch_size=64

  # --- Evaluation data: same dataset, 'validation' split ---
  --eval_dataset.type='huggingface'
  --eval_dataset.text_processor.fields='text'
  --eval_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_v2_filtered'
  --eval_dataset.huggingface_dataset.split='validation'
  --eval_dataset.huggingface_dataset.seq_length=2048
  --eval_dataset.huggingface_dataset.batch_size=64

  # --- Checkpointing and logging ---
  --checkpointer.save_optimizer_state=True
  --logger.online=True
  --logger.prefix='EasyLM'
  --logger.project="open_llama_3b"
  --logger.output_dir="gs://finnish-nlp-research-us/llama-3b-checkpoint"
  --logger.wandb_dir="./"
)

# exec so the trainer replaces this shell: signals reach it directly and its
# exit status becomes the script's exit status.
exec python3 -m EasyLM.models.llama.llama_train "${train_args[@]}" "$@"