import os
from dataclasses import dataclass, field

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from trl import DPOTrainer, ModelConfig
from trl.trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config

from callbacks import PerplexityCallback
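
# Fine-tunes a causal LM with TRL's DPOTrainer: parse CLI args, load the model
# and tokenizer, prepare the preference datasets, train (with a perplexity
# callback), and push the result to the Hub.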
|
|
@dataclass
class DPOScriptArguments:
    dataset_name: str = field(default=None, metadata={"help": "the dataset name"})
    dataset_train_split: str = field(default="train", metadata={"help": "the name of the training split of the dataset"})
    dataset_eval_split: str = field(default="test", metadata={"help": "the name of the evaluation split of the dataset"})
    eval_dataset_name: str = field(
        default=None, metadata={"help": "the evaluation dataset name; falls back to `dataset_name` when unset"}
    )
    beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    max_length: int = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: int = field(
        default=128, metadata={"help": "Only used for encoder-decoder models. Max target length of each sample"}
    )
    sanity_check: bool = field(default=False, metadata={"help": "only train and evaluate on 50 samples"})
    ignore_bias_buffers: bool = field(
        default=False,
        metadata={
            "help": "debug argument for distributed training; "
            "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
    gradient_checkpointing_use_reentrant: bool = field(
        default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
    )
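
# A minimal launch sketch (assumptions: this file is saved as dpo.py, and the
# dataset is any preference dataset with "prompt"/"chosen"/"rejected" columns):
#
#   accelerate launch dpo.py \
#       --model_name_or_path gpt2 \
#       --dataset_name <your-dpo-dataset> \
#       --output_dir dpo_output \
#       --per_device_train_batch_size 4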
|
|
if __name__ == "__main__":
    parser = HfArgumentParser((DPOScriptArguments, TrainingArguments, ModelConfig))
    args, training_args, model_config = parser.parse_args_into_dataclasses()

    if training_args.gradient_checkpointing:
        # Honor the script flag rather than hardcoding use_reentrant, so
        # `--gradient_checkpointing_use_reentrant` actually has an effect.
        training_args.gradient_checkpointing_kwargs = dict(use_reentrant=args.gradient_checkpointing_use_reentrant)
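
    # Note: reentrant activation checkpointing is known to misbehave under DDP
    # (e.g. errors about variables being marked ready twice when parameters go
    # unused), so the non-reentrant variant is the safer default here.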
|
    # Resolve the dtype string from the CLI (e.g. "bfloat16") into a torch
    # dtype; "auto" and None are passed straight through to `from_pretrained`.
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        # The KV cache is useless during training and incompatible with
        # gradient checkpointing.
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        # Full fine-tuning: DPO needs a frozen reference model for the KL term.
        model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    else:
        # With PEFT, the trainer recovers reference log-probs by disabling the
        # adapter, so no second copy of the model is needed.
        model_ref = None
|
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
    if tokenizer.pad_token_id is None:
        # Models such as GPT-2 ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if args.ignore_bias_buffers:
        # torch distributed hack: DDP chokes on boolean buffers (e.g. attention
        # bias masks), so exclude them from parameter/buffer broadcast.
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]
|
    train_dataset = load_dataset(args.dataset_name, split=args.dataset_train_split)
    # Fall back to the training dataset's name when no separate eval dataset is given.
    eval_dataset_name = args.eval_dataset_name if args.eval_dataset_name is not None else args.dataset_name
    eval_dataset = load_dataset(eval_dataset_name, split=args.dataset_eval_split)

    if args.sanity_check:
        # Quick smoke test on a handful of samples (see the `sanity_check` flag).
        train_dataset = train_dataset.select(range(50))
        eval_dataset = eval_dataset.select(range(50))
|
    trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        tokenizer=tokenizer,
        beta=args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        max_length=args.max_length,
        max_target_length=args.max_target_length,
        max_prompt_length=args.max_prompt_length,
        generate_during_eval=args.generate_during_eval,
        peft_config=peft_config,
    )
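
    # `generate_during_eval=True` samples generations at evaluation time and
    # logs them; TRL expects an experiment tracker such as Weights & Biases to
    # be available when this flag is enabled.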
|
    callback = PerplexityCallback(
        args=training_args,
        dataset=eval_dataset,
        tokenizer=tokenizer,
        accelerator=trainer.accelerator,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        prompt_field="prompt",
        target_field="chosen",
        hub_model_id=training_args.hub_model_id,
    )
    trainer.add_callback(callback)
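
    # Assumption: the local `callbacks.PerplexityCallback` tracks perplexity of
    # the target ("chosen") responses on the eval set as training progresses,
    # using the prompt/target fields configured above.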
|
    # Resume from the most recent checkpoint in `output_dir`, if any;
    # `get_last_checkpoint` assumes the directory exists, so guard the call.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    trainer.train(resume_from_checkpoint=last_checkpoint)

    trainer.save_model(training_args.output_dir)

    # Push only from the main process to avoid concurrent uploads. Note that
    # `Trainer.push_to_hub` takes a commit message as its first positional
    # argument and reads the repo id from `args.hub_model_id`.
    if PartialState().is_main_process:
        trainer.push_to_hub()
        tokenizer.push_to_hub(training_args.hub_model_id)
|
|