# DPO training script (HuggingFace Hub snapshot: "Training in progress, step 500", commit 1904ee8).
from dataclasses import dataclass, field
import torch
from accelerate import PartialState
from callbacks import PerplexityCallback
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from trl import DPOTrainer, ModelConfig
from trl.trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
@dataclass
class DPOScriptArguments:
    """Command-line arguments for DPO training, parsed with HfArgumentParser.

    Dataset selection, DPO hyper-parameters, and a few debug switches. Anything
    trainer- or model-related comes from TrainingArguments / ModelConfig instead.
    """

    dataset_name: str = field(default=None, metadata={"help": "the dataset name"})
    dataset_train_split: str = field(default="train", metadata={"help": "the name of the training split of the dataset"})
    # Fixed help text: the original said "training set" here (copy-paste error).
    dataset_eval_split: str = field(default="test", metadata={"help": "the name of the evaluation split of the dataset"})
    # Optional: evaluate on a different dataset than the one trained on.
    eval_dataset_name: str = field(
        default=None, metadata={"help": "optional separate evaluation dataset name; defaults to dataset_name"}
    )
    beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    max_length: int = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: int = field(
        default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
    )
    # Fixed help text: the script actually selects the first 50 samples, not 1000.
    sanity_check: bool = field(default=False, metadata={"help": "debug mode: only use the first 50 samples of each split"})
    ignore_bias_buffers: bool = field(
        default=False,
        metadata={
            "help": "debug argument for distributed training;"
            "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
    gradient_checkpointing_use_reentrant: bool = field(
        default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
    )
if __name__ == "__main__":
    parser = HfArgumentParser((DPOScriptArguments, TrainingArguments, ModelConfig))
    args, training_args, model_config = parser.parse_args_into_dataclasses()

    if training_args.gradient_checkpointing:
        # BUG FIX: honor the gradient_checkpointing_use_reentrant flag instead of
        # hard-coding use_reentrant=False (the flag was declared but never used).
        training_args.gradient_checkpointing_kwargs = dict(
            use_reentrant=args.gradient_checkpointing_use_reentrant
        )

    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        # KV caching is incompatible with gradient checkpointing.
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        # Full fine-tuning: DPO needs a separate frozen reference copy of the policy.
        model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    else:
        # With PEFT, DPOTrainer obtains reference logits by disabling the adapters,
        # so no second model is needed.
        model_ref = None
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    if args.ignore_bias_buffers:
        # torch distributed hack: exclude boolean buffers DDP cannot handle
        # (see the issue linked in the argument's help text).
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    ################
    # Dataset
    ################
    train_dataset = load_dataset(args.dataset_name, split=args.dataset_train_split)
    # Evaluation may come from a different dataset; default to the training one.
    eval_dataset_name = args.eval_dataset_name if args.eval_dataset_name is not None else args.dataset_name
    eval_dataset = load_dataset(eval_dataset_name, split=args.dataset_eval_split)
    if args.sanity_check:
        # Tiny subsets for a quick end-to-end debug run.
        train_dataset = train_dataset.select(range(50))
        eval_dataset = eval_dataset.select(range(50))

    ################
    # Training
    ################
    trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        tokenizer=tokenizer,
        beta=args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        max_length=args.max_length,
        max_target_length=args.max_target_length,
        max_prompt_length=args.max_prompt_length,
        generate_during_eval=args.generate_during_eval,
        # Reuse the config computed above instead of calling get_peft_config again.
        peft_config=peft_config,
    )
    # Track eval perplexity on the chosen responses during training.
    callback = PerplexityCallback(
        args=training_args,
        dataset=eval_dataset,
        tokenizer=tokenizer,
        accelerator=trainer.accelerator,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        prompt_field="prompt",
        target_field="chosen",
        hub_model_id=training_args.hub_model_id,
    )
    trainer.add_callback(callback)

    # Resume from the newest checkpoint in output_dir if one exists
    # (NOTE(review): assumes output_dir already exists on resume-capable runs — confirm).
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    trainer.train(resume_from_checkpoint=last_checkpoint)
    trainer.save_model(training_args.output_dir)

    # Only the main process pushes to the Hub to avoid duplicate uploads.
    if PartialState().is_main_process:
        # model = trainer.model.merge_and_unload()
        trainer.push_to_hub(training_args.hub_model_id)
        tokenizer.push_to_hub(training_args.hub_model_id)