
Distributed Finetuning with Trainer

#16
by darragh - opened

Hi,
I am working on a Docker instance with 4 x 40GB A100 cards. On a single 40GB card I cannot fit even a single sample through the model for finetuning, so I am trying to finetune with sharding to split the model layers across the cards.
I have my own script using the Trainer (shown below) and execute it with python -m torch.distributed.launch --nproc_per_node 4.

When running the script I can see that the model is split across the 4 GPUs with the fsdp setting below. However, per_device_train_batch_size loads 4 samples (1 per card) in each step, so I get an OOM. I am wondering whether it is possible to load only one sample in total per step, instead of one sample per card.
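
For reference, a rough sketch of the batch arithmetic under these settings (assuming one process per GPU, which is what --nproc_per_node 4 launches; the variable names are just for illustration):

num_processes = 4                # one process per A100
per_device_train_batch_size = 1
gradient_accumulation_steps = 256

# each process loads its own per-device batch, so 4 samples are in memory per step
samples_per_step = per_device_train_batch_size * num_processes                 # 4

# samples contributing to each optimizer update
effective_batch = samples_per_step * gradient_accumulation_steps               # 1024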

Let me know if it would be better to post this in the PyTorch forums.

Thanks!

import os
import platform

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=args.outdir,
    overwrite_output_dir=True,
    save_total_limit=1,
    do_train=True,
    do_eval=False,
    do_predict=True,
    num_train_epochs=args.epochs,          # total number of training epochs
    per_device_train_batch_size=1,         # batch size per device during training
    per_device_eval_batch_size=1,          # batch size for evaluation
    gradient_accumulation_steps=256,
    warmup_ratio=0.1,                      # fraction of total steps used for learning rate warmup
    weight_decay=0.01,                     # strength of weight decay
    logging_dir=args.logdir,               # directory for storing logs
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none",
    learning_rate=args.learning_rate,
    seed=99,
    local_rank=int(os.environ["LOCAL_RANK"]),  # local_rank expects an int, not the raw env string
    dataloader_num_workers=16,
    gradient_checkpointing=True,
    lr_scheduler_type="cosine",
    fsdp="shard_grad_op",
    fp16=platform.system() != "Darwin",    # use fp16 everywhere except macOS
)
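
For completeness, a minimal sketch of how these arguments might be wired into the Trainer (the model and dataset objects are assumptions; they are not shown in the original script):

from transformers import Trainer

trainer = Trainer(
    model=model,                  # hypothetical: the model loaded earlier in the script
    args=training_args,
    train_dataset=train_dataset,  # hypothetical dataset objects prepared elsewhere
    eval_dataset=eval_dataset,
)
trainer.train()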
BigScience Workshop org

Yeah, it's probably more suitable to open an issue for this on https://github.com/huggingface/transformers, as it's not specific to the model.

christopher changed discussion status to closed
