Full-parameter finetuning, OOM

#71
by Saicy - opened

I'm using 8×A100 GPUs to do full-parameter finetuning of the Mistral-7B model with DeepSpeed ZeRO-3.
When I train on about 5k samples it works. However, when I add more data under the same config (about 17k samples), it reports OOM.

I didn't change the batch size, learning rate, or anything else. Why does it report OOM?
Please help!
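
For context, a ZeRO-3 run like this is normally driven by a JSON config handed to the HF trainer. The sketch below is only an illustration of that kind of setup; the values, the ds_config name, the train_sft.py filename, and the launch flags are assumptions, not the actual settings used here:

# Minimal sketch of a DeepSpeed ZeRO-3 config as a Python dict (illustrative values only)
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "stage3_gather_16bit_weights_on_model_save": True
    },
    "bf16": {"enabled": "auto"},
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}

# Seq2SeqTrainingArguments accepts either a dict or a path to a JSON file:
# training_args = Seq2SeqTrainingArguments(..., deepspeed=ds_config)
#
# Launched across 8 GPUs with the DeepSpeed launcher, e.g.:
#   deepspeed --num_gpus 8 train_sft.py --stage sft --deepspeed ds_config.json ...
# ("train_sft.py" is a placeholder for the script shown below.)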

Maybe show your code?


from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo

def main():
    from huggingface_hub import login
    login(token="xxxxxxxxx")
    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()

    if general_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args)
    elif general_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args)
    elif general_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args)
    elif general_args.stage == "ppo":
        run_ppo(model_args, data_args, training_args, finetuning_args)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()

Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py

from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import deepspeed
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer

def run_sft(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: Seq2SeqTrainingArguments,
    finetuning_args: FinetuningArguments,
    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
):

    from huggingface_hub import login
    login(token="xxxxxxxxx")  # token redacted
    dataset = get_dataset(model_args, data_args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")

    # deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
        training_args.generation_max_length is not None else data_args.max_target_length
    training_args.generation_num_beams = data_args.eval_num_beams if \
        data_args.eval_num_beams is not None else training_args.generation_num_beams

    # Initialize our Trainer
    trainer = Seq2SeqPeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_new_tokens": data_args.max_target_length + 1,
        "temperature": 0.3,
        "logits_processor": get_logits_processor()
    }

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results)

if __name__ == "__main__":
    main()
