|
import os |
|
import shutil |
|
from dataclasses import dataclass, field |
|
from typing import Optional |
|
|
|
import torch |
|
from accelerate import Accelerator |
|
from datasets import load_dataset |
|
from peft import AutoPeftModelForSequenceClassification, PeftConfig |
|
from transformers import ( |
|
AutoModelForSequenceClassification, |
|
AutoTokenizer, |
|
BitsAndBytesConfig, |
|
DataCollatorWithPadding, |
|
HfArgumentParser, |
|
Trainer, |
|
TrainingArguments, |
|
) |
|
|
|
|
|
def _constant_disk_usage(x):
    """Stand-in for ``shutil.disk_usage``: ignore ``x`` and report 1/1/1.

    Returns the same named tuple type the real function uses, with every
    field set to 1.
    """
    return shutil._ntuple_diskusage(1, 1, 1)


# Globally replace the real disk-usage probe for the rest of the process.
# NOTE(review): presumably a workaround for a free-space check in one of the
# libraries imported above on this filesystem -- confirm before removing.
shutil.disk_usage = _constant_disk_usage
|
|
|
|
|
@dataclass
class ScriptArguments:
    """Command-line options for scoring a dataset with a reward model.

    Parsed by ``HfArgumentParser``; each field's ``help`` metadata is the text
    shown for the corresponding command-line flag.
    """

    output_dir: Optional[str] = field(
        default="/home/toolkit/huggingface/openai_summarize_tldr_reward",
        metadata={"help": "output folder"},
    )
    model_name: Optional[str] = field(
        default="mnoukhov/pythia410m-tldr-sft-rm-adapter",
        metadata={"help": "the model name"},
    )
    new_column_name: Optional[str] = field(
        default="reward_baseline",
        metadata={"help": "name of the dataset column that receives the predicted rewards"},
    )
    dataset_name: Optional[str] = field(
        default="mnoukhov/openai_summarize_comparisons_tldrprompt",
        metadata={"help": "the dataset name"},
    )
    # Used as the tokenizer's truncation length, not as a generation limit.
    max_length: Optional[int] = field(
        default=560,
        metadata={"help": "maximum number of tokens per example; longer inputs are truncated"},
    )
    train_split: Optional[str] = field(
        default="train[:20]",
        metadata={"help": "the dataset split to load as 'train'"},
    )
    eval_split: Optional[str] = field(
        default=None,
        metadata={"help": "the dataset split to load as 'valid'"},
    )
    load_in_8bit: Optional[bool] = field(
        default=False,
        metadata={"help": "load the model in 8 bits precision"},
    )
    load_in_4bit: Optional[bool] = field(
        default=False,
        metadata={"help": "load the model in 4 bits precision"},
    )
    batch_size: Optional[int] = field(
        default=4,
        metadata={"help": "per-device batch size used for prediction"},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "load the model in bfloat16 and run the trainer with bf16"},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "run the trainer with fp16"},
    )
    fp16_model: Optional[bool] = field(
        default=False,
        metadata={"help": "load the model weights in float16"},
    )
|
|
|
|
|
def create_and_prepare_model(args):
    """Load the sequence-classification model and its tokenizer.

    Args:
        args: object exposing ``load_in_8bit``, ``load_in_4bit``, ``bf16``,
            ``fp16_model`` and ``model_name``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if both 8-bit and 4-bit loading are requested.
    """
    if args.load_in_8bit and args.load_in_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")

    # Defaults for the unquantized case; overridden when quantization is on.
    quantization_config = None
    device_map = None
    if args.load_in_8bit or args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
        # Place the whole model on this process's local device index.
        device_map = {"": Accelerator().local_process_index}

    # bf16 takes precedence over fp16_model; None leaves the dtype unspecified.
    torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16_model else None)

    # Model names containing "adapter" are loaded through PEFT, and the
    # tokenizer then comes from the adapter's base model.
    uses_adapter = "adapter" in args.model_name
    if uses_adapter:
        model_cls = AutoPeftModelForSequenceClassification
        peft_config = PeftConfig.from_pretrained(args.model_name)
        tokenizer_name = peft_config.base_model_name_or_path
    else:
        model_cls = AutoModelForSequenceClassification
        tokenizer_name = args.model_name

    load_kwargs = {
        "quantization_config": quantization_config,
        "device_map": device_map,
        "num_labels": 1,  # single output per sequence
        "torch_dtype": torch_dtype,
    }
    model = model_cls.from_pretrained(args.model_name, **load_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Fall back to the EOS token wherever no padding token is configured.
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token
    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = model.config.eos_token_id

    return model, tokenizer
|
|
|
|
|
def create_and_prepare_dataset(args, tokenizer, split, num_proc=2):
    """Load a dataset and tokenize the concatenation ``prompt + label``.

    Args:
        args: object exposing ``dataset_name`` and ``max_length``.
        tokenizer: callable applied to the combined texts.
        split: forwarded unchanged to ``load_dataset(split=...)``. May be a
            single split string or a mapping of names to split strings.
        num_proc: number of worker processes passed to ``map``.

    Returns:
        The tokenized dataset (same container shape as ``load_dataset``
        returned), with the original columns removed and its format set to
        ``"torch"``.
    """
    dataset = load_dataset(args.dataset_name, split=split)

    def combine_and_tokenize(examples):
        labels = examples["label"]
        if isinstance(labels, str):
            # Single (non-batched) example.
            texts = examples["prompt"] + labels
        else:
            texts = [prompt + label for prompt, label in zip(examples["prompt"], labels)]

        # No padding here; truncation caps each example at max_length tokens.
        return tokenizer(texts, truncation=True, padding=False, max_length=args.max_length)

    # ``column_names`` is a list for a single dataset and a per-split mapping
    # for a dict of datasets. The original code required a "train" entry,
    # which failed for a string ``split`` or a mapping without that key.
    column_names = dataset.column_names
    if isinstance(column_names, dict):
        if "train" in column_names:
            original_columns = column_names["train"]
        else:
            # Every split is mapped with the same ``remove_columns``, so use
            # the first split's columns when there is no "train" split.
            original_columns = next(iter(column_names.values()))
    else:
        original_columns = column_names

    dataset = dataset.map(
        combine_and_tokenize,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )

    dataset.set_format("torch")
    return dataset
|
|
|
|
|
def strip_prompt(examples):
    """Strip leading/trailing whitespace from every entry of ``examples["prompt"]``.

    The batch is modified in place and the same object is returned.
    """
    cleaned_prompts = [text.strip() for text in examples["prompt"]]
    examples["prompt"] = cleaned_prompts
    return examples
|
|
|
|
|
# Parse the command-line options into a ScriptArguments instance.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

model, tokenizer = create_and_prepare_model(script_args)

# The Trainer is used for prediction only, so just the eval batch size and
# the mixed-precision flags are configured.
training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_eval_batch_size=script_args.batch_size,
    bf16=script_args.bf16,
    fp16=script_args.fp16,
)

# With fp16, pad batches to a multiple of 8; otherwise pass None and let the
# Trainer pick its own collator.
if script_args.fp16:
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
else:
    data_collator = None

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Keep only the splits that were actually requested. ``eval_split`` defaults
# to None, and a None entry in the mapping handed to ``load_dataset`` is not
# a valid split specification.
data_splits = {
    name: split
    for name, split in (("train", script_args.train_split), ("valid", script_args.eval_split))
    if split is not None
}
if not data_splits:
    raise ValueError("At least one of train_split / eval_split must be set")

# Tokenized copy used for prediction.
original_datasets = create_and_prepare_dataset(script_args, tokenizer, split=data_splits)

# Untokenized copy that receives the new reward column and is uploaded.
augmented_dataset = load_dataset(script_args.dataset_name, split=data_splits)
augmented_dataset = augmented_dataset.map(strip_prompt, batched=True)

for key, dataset in original_datasets.items():
    preds = trainer.predict(dataset)
    # First element of the prediction output, flattened to one value per row.
    reward_preds = preds[0].flatten()

    if trainer.accelerator.is_local_main_process:
        augmented_dataset[key] = augmented_dataset[key].add_column(script_args.new_column_name, reward_preds)

# Wait for every process to finish predicting before the main one uploads.
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
    # The last path component of output_dir is used as the hub repository name.
    augmented_dataset.push_to_hub(os.path.basename(script_args.output_dir))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|