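# Supervised fine-tuning (SFT) with TRL's SFTTrainer, with optional PEFT/LoRA
# and quantization. Example invocation (the model and dataset names are
# illustrative placeholders, not the ones used for this run):
#   python sft.py --model_name_or_path EleutherAI/pythia-410m \
#       --dataset_name my-org/prompt-chosen-dataset --task_type hh \
#       --output_dir ./sft-output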
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import PartialState
from datasets import load_dataset
from tqdm.rich import tqdm
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import ModelConfig, SFTTrainer
from trl.trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config

tqdm.pandas()


def hh_combine(examples):
    """Concatenate the prompt and chosen response into a single training text."""
    if isinstance(examples["chosen"], str):
        return examples["prompt"] + examples["chosen"]
    elif isinstance(examples["chosen"], list):
        return list(map(str.__add__, examples["prompt"], examples["chosen"]))
    else:
        raise TypeError(f"unexpected type for examples['chosen']: {type(examples['chosen'])}")
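
# e.g. hh_combine({"prompt": ["Q: hi\n"], "chosen": ["A: hello"]}) -> ["Q: hi\nA: hello"],
# matching the batched formatting_func interface that SFTTrainer expects.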


@dataclass
class ScriptArguments:
    task_type: str = field(default="hh")
    dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
    dataset_train_split: str = field(default="train", metadata={"help": "the name of the training split of the dataset"})
    dataset_eval_split: str = field(default="test", metadata={"help": "the name of the evaluation split of the dataset"})
    output_model_name: str = field(default="", metadata={"help": "model name to upload"})
    max_seq_length: int = field(default=512, metadata={"help": "the maximum sequence length for the SFT Trainer"})
    packing: bool = field(default=False, metadata={"help": "whether to apply data packing during training"})
    config: Optional[str] = field(default=None, metadata={"help": "path to the optional config file"})
    gradient_checkpointing_use_reentrant: bool = field(
        default=False, metadata={"help": "whether to apply `use_reentrant` for gradient_checkpointing"}
    )
    sanity_check: bool = field(default=False, metadata={"help": "truncate datasets to 100 examples for debugging"})
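
# Each ScriptArguments field above becomes a CLI flag via HfArgumentParser,
# e.g. --task_type tldr --dataset_name <hub-dataset> --sanity_check True.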


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
    args, training_args, model_config = parser.parse_args_into_dataclasses()

    ################
    # Model & Tokenizer
    ################
    # "auto" and None pass through unchanged; any other value names a torch dtype
    # attribute, e.g. "bfloat16" -> torch.bfloat16
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        # the KV cache is useless during training and conflicts with gradient checkpointing
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
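    # model_kwargs is forwarded to from_pretrained() inside SFTTrainer (via
    # model_init_kwargs below), since the model is passed by name, not as an object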
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
    tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
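    # NOTE: "<|padding|>" already exists in GPT-NeoX/Pythia vocabularies; for a
    # tokenizer where it is new, the model embeddings would need resizing with
    # model.resize_token_embeddings(len(tokenizer)) after the model is created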

    ################
    # Dataset
    ################
    datasets = load_dataset(args.dataset_name)
    if args.sanity_check:
        # debug mode: shrink every split to 100 examples and skip the Hub upload
        for key in datasets:
            datasets[key] = datasets[key].select(range(100))
        training_args.push_to_hub = False
    train_dataset = datasets[args.dataset_train_split]
    eval_dataset = datasets[args.dataset_eval_split]
    # train_dataset = train_dataset.map(lambda ex: {"text": ex['prompt'] + ex['chosen']})
    # eval_dataset = eval_dataset.map(lambda ex: {"text": ex['prompt'] + ex['chosen']})
    # TL;DR data already has a single text column; HH-style data is combined
    # from prompt + chosen on the fly by hh_combine
    if args.task_type == "tldr":
        formatting_func = None
        dataset_text_field = "query_reference_response"
    elif args.task_type == "hh":
        formatting_func = hh_combine
        dataset_text_field = None
    else:
        raise ValueError(f"unknown task_type: {args.task_type}")

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model_config.model_name_or_path,
        model_init_kwargs=model_kwargs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        max_seq_length=args.max_seq_length,
        tokenizer=tokenizer,
        packing=args.packing,
        formatting_func=formatting_func,
        dataset_text_field=dataset_text_field,
        peft_config=get_peft_config(model_config),
    )
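    # with --packing True, TRL concatenates formatted examples into constant-length
    # blocks of max_seq_length tokens instead of padding each example separately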
    trainer.train()
    trainer.save_model(training_args.output_dir)

    # when training with PEFT, merge the adapter into the base model so the
    # uploaded checkpoint is a standalone model
    if PartialState().is_main_process and model_config.use_peft:
        model = trainer.model.merge_and_unload()
        model.push_to_hub(args.output_model_name)
        tokenizer.push_to_hub(args.output_model_name)