|
"""Supervised fine-tuning on a summarization dataset with TRL's `SFTTrainer`,
with optional 8-/4-bit loading (bitsandbytes) and LoRA adapters (PEFT)."""

import os
|
from dataclasses import dataclass, field |
|
from typing import Optional |
|
|
|
|
import torch |
|
from accelerate import Accelerator |
|
from datasets import load_dataset |
|
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM, LoraConfig |
|
from tqdm import tqdm |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoModelForSeq2SeqLM, |
|
AutoTokenizer, |
|
BitsAndBytesConfig, |
|
|
HfArgumentParser, |
|
TrainingArguments, |
|
) |
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
|
|
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer |
|
|
|
|
|
tqdm.pandas() |
|
|
|
|
|
|
|
@dataclass |
|
class ScriptArguments: |
|
""" |
|
The name of the Casual LM model we wish to fine with SFTTrainer |
|
""" |
|
|
|
model_name: Optional[str] = field(default="EleutherAI/pythia-6.9b-deduped", metadata={"help": "the model name"}) |
|
    tokenizer_name: Optional[str] = field(default=None, metadata={"help": "the tokenizer name (defaults to the model name)"})
|
dataset_name: Optional[str] = field( |
|
default="CarperAI/openai_summarize_tldr", metadata={"help": "the dataset name"} |
|
) |
|
    train_split: Optional[str] = field(
        default="train", metadata={"help": "the dataset split to train on"}
    )
|
    eval_split: Optional[str] = field(
        default="test",
        metadata={"help": "the dataset split to evaluate on"},
    )
|
log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"}) |
|
streaming: Optional[bool] = field(default=False, metadata={"help": "whether to stream the dataset"}) |
|
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"}) |
|
|
|
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"}) |
|
lr_scheduler_type: Optional[str] = field(default="cosine") |
|
num_warmup_steps: Optional[int] = field(default=100) |
|
weight_decay: Optional[float] = field(default=0.05) |
|
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) |
|
|
|
max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"}) |
|
num_train_epochs: Optional[int] = field(default=1, metadata={"help": "the number of training epochs"}) |
|
per_device_train_batch_size: Optional[int] = field( |
|
default=16, metadata={"help": "the per device train batch size"} |
|
) |
|
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "the per device eval batch size"}) |
|
gradient_accumulation_steps: Optional[int] = field( |
|
default=16, metadata={"help": "the number of gradient accumulation steps"} |
|
) |
|
gradient_checkpointing: Optional[bool] = field( |
|
default=False, metadata={"help": "whether to use gradient checkpointing"} |
|
) |
|
seq_length: Optional[int] = field(default=560, metadata={"help": "Input sequence length"}) |
|
|
|
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "load the model in 8 bits precision"}) |
|
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"}) |
|
    use_peft: Optional[bool] = field(default=True, metadata={"help": "whether to use PEFT to train LoRA adapters"})
|
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) |
|
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) |
|
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) |
|
trust_remote_code: Optional[bool] = field(default=True, metadata={"help": "Enable `trust_remote_code`"}) |
|
bf16: Optional[bool] = field(default=True) |
|
    fp16_model: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to load the base model weights in float16"},
    )
|
    fp16: Optional[bool] = field(
        default=False,
        metadata={
            "help": "whether to train in fp16 mixed precision; roughly halves training time on supported GPUs at a small cost in precision"
        },
    )
|
    train_completions: Optional[bool] = field(
        default=False, metadata={"help": "whether to compute the loss only on completion tokens"}
    )
    packing: Optional[bool] = field(default=True, metadata={"help": "whether to pack examples into constant-length sequences"})
|
|
|
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"}) |
|
output_model_name: Optional[str] = field(default=None, metadata={"help": "the model pushed to hub"}) |
|
logging_steps: Optional[int] = field(default=10, metadata={"help": "the number of logging steps"}) |
|
eval_steps: Optional[int] = field(default=1000, metadata={"help": "the number of steps to eval at"}) |
|
save_steps: Optional[int] = field(default=1000, metadata={"help": "the number of steps to save at"}) |
|
save_strategy: Optional[str] = field(default="steps") |
|
seed: Optional[int] = field(default=0) |
|
just_eval: Optional[bool] = field(default=False) |
|
resume_from_checkpoint: Optional[str] = field(default=None) |
|
|
|
|
|
def chars_token_ratio(dataset, tokenizer, nb_examples=400): |
|
""" |
|
Estimate the average number of characters per token in the dataset. |
|
""" |
|
total_characters, total_tokens = 0, 0 |
|
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): |
|
text = prepare_sample_text(example) |
|
total_characters += len(text) |
|
if tokenizer.is_fast: |
|
total_tokens += len(tokenizer(text).tokens()) |
|
else: |
|
total_tokens += len(tokenizer.tokenize(text)) |
|
|
|
return total_characters / total_tokens |
|
|
|
|
|
def prepare_sample_text(examples): |
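    """Concatenate the prompt and the chosen summary; handles a single example or a batched dict of lists."""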
|
if isinstance(examples["chosen"], str): |
|
return examples["prompt"] + examples["chosen"] |
|
elif isinstance(examples["chosen"], list): |
|
return list(map(str.__add__, examples["prompt"], examples["chosen"])) |
|
    else:
        raise ValueError(f"unexpected type for examples['chosen']: {type(examples['chosen'])}")
|
|
|
|
|
def create_datasets(args): |
|
train_data = load_dataset( |
|
args.dataset_name, |
|
split=args.train_split, |
|
streaming=args.streaming, |
|
) |
|
|
|
if args.streaming: |
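        # Streaming datasets can only shuffle within a bounded buffer, not globally.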
|
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) |
|
|
|
valid_data = load_dataset( |
|
args.dataset_name, |
|
split=args.eval_split, |
|
) |
|
return train_data, valid_data |
|
|
|
|
|
def create_model(args): |
|
print("Loading the model") |
|
if args.load_in_8bit and args.load_in_4bit: |
|
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time") |
|
elif args.load_in_8bit or args.load_in_4bit: |
|
quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit) |
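        # Pin the quantized weights to this rank's GPU so each DDP process holds its own copy.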
|
device_map = {"": Accelerator().local_process_index} |
|
else: |
|
device_map = None |
|
quantization_config = None |
|
|
|
if args.bf16: |
|
torch_dtype = torch.bfloat16 |
|
elif args.fp16_model: |
|
torch_dtype = torch.float16 |
|
else: |
|
torch_dtype = None |
|
|
|
|
|
|
|
|
|
if "t5" in args.model_name: |
|
model_cls = AutoModelForSeq2SeqLM |
|
else: |
|
model_cls = AutoModelForCausalLM |
|
|
|
model = model_cls.from_pretrained( |
|
args.model_name, |
|
quantization_config=quantization_config, |
|
device_map=device_map, |
|
trust_remote_code=args.trust_remote_code, |
|
        torch_dtype=torch_dtype,
        token=True,
|
) |
|
model.config.torch_dtype = torch_dtype |
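    # The generation KV cache is incompatible with gradient checkpointing, so turn it off for training.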
|
model.config.use_cache = False |
|
|
|
print("Loading dataset") |
|
tokenizer = AutoTokenizer.from_pretrained(args.model_name if args.tokenizer_name is None else args.tokenizer_name) |
|
if getattr(tokenizer, "pad_token", None) is None: |
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
return model, tokenizer |
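
# Example launch (the script name is illustrative; all flags are defined in ScriptArguments above):
#   accelerate launch sft_tldr.py \
#       --model_name EleutherAI/pythia-6.9b-deduped \
#       --dataset_name CarperAI/openai_summarize_tldr \
#       --train_completions True --packing False \
#       --output_dir ./results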
|
|
|
|
|
if __name__ == "__main__": |
|
parser = HfArgumentParser(ScriptArguments) |
|
args = parser.parse_args_into_dataclasses()[0] |
|
|
|
os.makedirs(args.output_dir, exist_ok=True) |
|
|
|
model, tokenizer = create_model(args) |
|
|
|
train_dataset, eval_dataset = create_datasets(args) |
|
|
|
if args.train_completions: |
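        # Compute the loss only on tokens after the "TL;DR:" response marker.
        # TRL does not allow this collator together with packing, so pass --packing False.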
|
data_collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer, response_template="TL;DR:") |
|
else: |
|
data_collator = None |
|
|
|
training_args = TrainingArguments( |
|
output_dir=args.output_dir, |
|
per_device_train_batch_size=args.per_device_train_batch_size, |
|
per_device_eval_batch_size=args.per_device_eval_batch_size, |
|
dataloader_drop_last=True, |
|
evaluation_strategy="steps", |
|
max_steps=args.max_steps, |
|
num_train_epochs=args.num_train_epochs, |
|
eval_steps=args.eval_steps, |
|
save_steps=args.save_steps, |
|
save_strategy=args.save_strategy, |
|
logging_steps=args.logging_steps, |
|
learning_rate=args.learning_rate, |
|
lr_scheduler_type=args.lr_scheduler_type, |
|
warmup_steps=args.num_warmup_steps, |
|
gradient_accumulation_steps=args.gradient_accumulation_steps, |
|
gradient_checkpointing=args.gradient_checkpointing, |
|
bf16=args.bf16, |
|
fp16=args.fp16, |
|
weight_decay=args.weight_decay, |
|
report_to=args.log_with, |
|
optim=args.optimizer_type, |
|
remove_unused_columns=False, |
|
disable_tqdm=False, |
|
seed=args.seed, |
|
|
|
        # DDP's unused-parameter detection conflicts with reentrant gradient
        # checkpointing, so turn it off whenever checkpointing is enabled.
        ddp_find_unused_parameters=not args.gradient_checkpointing,
|
) |
|
|
|
if args.use_peft: |
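        # "all-linear" tells PEFT to wrap every linear layer except the output head
        # (supported in recent peft releases).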
|
peft_config = LoraConfig( |
|
r=args.lora_r, |
|
lora_alpha=args.lora_alpha, |
|
lora_dropout=args.lora_dropout, |
|
target_modules="all-linear", |
|
bias="none", |
|
task_type="CAUSAL_LM", |
|
) |
|
else: |
|
peft_config = None |
|
|
|
chars_per_token = chars_token_ratio(train_dataset, tokenizer) |
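    # SFTTrainer uses this ratio to estimate token counts when packing examples.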
|
print(f"The character to token ratio of the train dataset is: {chars_per_token:.2f}") |
|
|
|
print("Starting main loop") |
|
trainer = SFTTrainer( |
|
model=model, |
|
tokenizer=tokenizer, |
|
args=training_args, |
|
train_dataset=train_dataset, |
|
eval_dataset=eval_dataset, |
|
peft_config=peft_config, |
|
max_seq_length=args.seq_length, |
|
formatting_func=prepare_sample_text, |
|
packing=args.packing, |
|
chars_per_token=chars_per_token, |
|
data_collator=data_collator, |
|
) |
|
|
|
if args.use_peft: |
|
trainer.model.print_trainable_parameters() |
|
|
|
if not args.just_eval: |
|
if args.resume_from_checkpoint is not None: |
|
last_checkpoint = args.resume_from_checkpoint |
|
else: |
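            # Fall back to the newest checkpoint-* directory in output_dir; returns None
            # when there is none, in which case training starts from scratch.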
|
|
|
last_checkpoint = get_last_checkpoint(args.output_dir) |
|
|
|
print("Training...") |
|
trainer.train(resume_from_checkpoint=last_checkpoint) |
|
|
|
trainer.evaluate() |
|
|
|
print("Saving last checkpoint of the model") |
|
output_dir = os.path.join(args.output_dir, "final_model") |
|
trainer.save_model(output_dir) |
|
|
|
        if args.use_peft:
            output_dir = os.path.join(args.output_dir, "final_adapter_checkpoint")
            trainer.model.save_pretrained(output_dir)

            # Free the training model before reloading the saved adapter for merging.
            del model
            torch.cuda.empty_cache()

            if "t5" in args.model_name:
                model_cls = AutoPeftModelForSeq2SeqLM
            else:
                model_cls = AutoPeftModelForCausalLM

            # Reload the adapter checkpoint and fold the LoRA weights into the base model.
            model = model_cls.from_pretrained(
                output_dir,
                device_map="auto",
                torch_dtype=torch.bfloat16 if args.bf16 else None,
            )
            model = model.merge_and_unload()

            output_merged_dir = os.path.join(args.output_dir, "final_merged_checkpoint")
            model.save_pretrained(output_merged_dir, safe_serialization=True)

        if args.output_model_name is not None:
            model.push_to_hub(args.output_model_name)
|
|
|
else: |
|
results = trainer.evaluate() |
|
print(results) |
|
|