|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import DistributedDataParallelKwargs
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline

from trl import PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
from trl.models.modeling_value_model import AutoModelForCausalLMWithValueModel
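
# Fine-tunes a causal LM with PPO against a learned (proxy) reward model, optionally
# scoring rollouts with a separate "gold" reward model for periodic evaluation.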
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
class ScriptArguments: |
|
""" |
|
The name of the Casual LM model we wish to fine with PPO |
|
""" |
|
|
|
|
|
|
|
model_name: Optional[str] = field(default="", metadata={"help": "the model name"}) |
|
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"}) |
|
    gold_reward_model_name: Optional[str] = field(default="", metadata={"help": "the gold reward model name"})
|
|
|
dataset_name: Optional[str] = field( |
|
default="CarperAI/openai_summarize_tldr", metadata={"help": "the dataset name"} |
|
) |
|
    train_split: Optional[str] = field(
        default="train", metadata={"help": "the dataset split to train on"}
    )
|
log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"}) |
|
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"}) |
|
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) |
|
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) |
|
ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"}) |
|
gradient_accumulation_steps: Optional[int] = field( |
|
default=4, metadata={"help": "the number of gradient accumulation steps"} |
|
) |
|
adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"}) |
|
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"}) |
|
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"}) |
|
reward_baseline: Optional[float] = field( |
|
default=0.0, |
|
metadata={"help": "a baseline value that is subtracted from the reward"}, |
|
) |
|
batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"}) |
|
    eval_freq: Optional[int] = field(default=None, metadata={"help": "evaluate with the gold reward model every n steps"})
    save_freq: Optional[int] = field(default=None, metadata={"help": "save a checkpoint every n steps"})
    output_dir: Optional[str] = field(default="runs/", metadata={"help": "output directory for checkpoints"})
|
seed: Optional[int] = field(default=0, metadata={"help": "the seed"}) |
|
    steps: Optional[int] = field(default=20000, metadata={"help": "total number of training steps"})
|
init_kl_coef: Optional[float] = field( |
|
default=0.05, |
|
metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"}, |
|
) |
|
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"}) |
|
|
|
|
|
|
|
|
|
    output_min_length: Optional[int] = field(default=24, metadata={"help": "minimum length of generated responses"})
    output_max_length: Optional[int] = field(default=48, metadata={"help": "maximum length of generated responses"})
    input_max_length: Optional[int] = field(default=512, metadata={"help": "maximum prompt length after tokenization"})
|
|
|
|
|
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8-bit precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4-bit precision"})
|
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to use bf16; roughly halves training time at slightly reduced precision (requires a supported GPU)"},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to use fp16; roughly halves training time at slightly reduced precision (requires a supported GPU)"},
    )
    fp16_model: Optional[bool] = field(
        default=False,
        metadata={"help": "load the model weights in float16"},
    )
|
|
|
|
|
use_lora: Optional[bool] = field( |
|
default=True, |
|
) |
|
lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"}) |
|
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) |
|
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) |
|
lora_all_linear: Optional[bool] = field(default=False, metadata={"help": "lora adapter on all linear layers"}) |
|
|
|
|
|
|
|
|
|
|
|
def create_and_prepare_model(args): |
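    """Load the PPO policy model and its tokenizer.

    AutoModelForCausalLMWithValueModel is a repo-local class that takes both the policy
    checkpoint and the reward-model checkpoint, presumably so the value function can be
    initialized from the reward model rather than from a fresh value head.
    """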
|
if args.bf16: |
|
torch_dtype = torch.bfloat16 |
|
elif args.fp16_model: |
|
torch_dtype = torch.float16 |
|
else: |
|
torch_dtype = torch.float32 |
|
|
|
model = AutoModelForCausalLMWithValueModel.from_pretrained( |
|
args.model_name, |
|
args.reward_model_name, |
|
torch_dtype=torch_dtype, |
|
) |
|
|
|
|
|
|
|
|
|
model.config.torch_dtype = torch_dtype |
|
model.config.use_cache = True |
|
|
|
|
|
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
if getattr(tokenizer, "pad_token", None) is None: |
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
if getattr(model.config, "pad_token_id", None) is None: |
|
model.config.pad_token_id = model.config.eos_token_id |
|
|
|
return model, tokenizer |
|
|
|
|
|
def create_and_prepare_dataset(args, tokenizer, split, num_proc=2): |
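    """Load the given dataset split and tokenize its prompts.

    The 'prompt' column is renamed to 'query'; after tokenization only 'query',
    'input_ids' and 'attention_mask' remain, formatted as torch tensors.
    """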
|
dataset = load_dataset(args.dataset_name, split=split) |
|
|
|
dataset = dataset.rename_column("prompt", "query") |
|
original_columns = dataset.column_names |
|
original_columns.remove("query") |
|
|
|
dataset = dataset.map( |
|
lambda examples: tokenizer(examples["query"], truncation=True, max_length=args.input_max_length), |
|
batched=True, |
|
num_proc=num_proc, |
|
remove_columns=original_columns, |
|
) |
|
|
|
dataset.set_format("torch") |
|
return dataset |
|
|
|
|
|
def collator(data): |
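    """Collate a list of sample dicts into a dict of lists, keeping per-sample tensors unbatched for the rollout loop."""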
|
    return {key: [d[key] for d in data] for key in data[0]}
|
|
|
|
|
parser = HfArgumentParser(ScriptArguments) |
|
script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] |
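
# PPO configuration assembled from the CLI arguments; DDP is configured with
# find_unused_parameters=False via the accelerator kwargs.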
|
config = PPOConfig( |
|
steps=script_args.steps, |
|
model_name=script_args.model_name, |
|
learning_rate=script_args.learning_rate, |
|
log_with=script_args.log_with, |
|
batch_size=script_args.batch_size, |
|
mini_batch_size=script_args.mini_batch_size, |
|
gradient_accumulation_steps=script_args.gradient_accumulation_steps, |
|
optimize_cuda_cache=True, |
|
early_stopping=script_args.early_stopping, |
|
target_kl=script_args.target_kl, |
|
ppo_epochs=script_args.ppo_epochs, |
|
seed=script_args.seed, |
|
init_kl_coef=script_args.init_kl_coef, |
|
adap_kl_ctrl=script_args.adap_kl_ctrl, |
|
accelerator_kwargs={"kwargs_handlers": [DistributedDataParallelKwargs(find_unused_parameters=False)]}, |
|
) |
|
|
|
|
|
set_seed(config.seed) |
|
|
|
model, tokenizer = create_and_prepare_model(script_args) |
|
train_dataset = create_and_prepare_dataset(script_args, tokenizer, script_args.train_split) |
|
|
|
|
|
|
|
ppo_trainer = PPOTrainer( |
|
config, |
|
model, |
|
ref_model=None, |
|
tokenizer=tokenizer, |
|
dataset=train_dataset, |
|
data_collator=collator, |
|
) |
|
|
|
# Rollouts are generated in eval mode (dropout disabled).
model.eval()
|
|
|
|
|
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # use an explicit device index for the pipelines in single-process runs

# Proxy reward model, scored through the text-classification ("sentiment-analysis") pipeline.
reward_pipe = pipeline(
    "sentiment-analysis",
    model=script_args.reward_model_name,
    device=device,
    tokenizer=tokenizer,
    return_token_type_ids=False,
)
|
if script_args.eval_freq is not None:
    # Optional "gold" reward model, used only for periodic evaluation logging.
    gold_reward_pipe = pipeline(
        "sentiment-analysis",
        model=script_args.gold_reward_model_name,
        device=device,
        tokenizer=tokenizer,
        return_token_type_ids=False,
    )
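
# Shared kwargs for the reward pipelines: return raw scores for every label
# (no softmax), truncating over-long inputs.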
|
sent_kwargs = { |
|
"top_k": None, |
|
"function_to_apply": "none", |
|
"batch_size": 16, |
|
"truncation": True, |
|
} |
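
# Rollout sampling settings: plain sampling with no top-k / top-p truncation and no
# forced minimum length; the number of new tokens is drawn from the LengthSampler below.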
|
|
|
|
|
|
|
generation_kwargs = { |
|
"min_length": -1, |
|
"top_k": 0.0, |
|
"top_p": 1.0, |
|
"do_sample": True, |
|
"pad_token_id": tokenizer.pad_token_id, |
|
"eos_token_id": tokenizer.eos_token_id, |
|
} |
|
output_length_sampler = LengthSampler(script_args.output_min_length, script_args.output_max_length) |
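
# PPO training loop: each iteration generates responses for one batch of prompts, scores
# them with the proxy reward model, and takes a PPO optimization step; gold-reward logging
# and checkpointing happen at the configured frequencies.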
|
|
|
for epoch, batch in tqdm( |
|
enumerate(ppo_trainer.dataloader), |
|
total=config.total_ppo_epochs, |
|
disable=not ppo_trainer.accelerator.is_local_main_process, |
|
): |
|
if epoch >= config.total_ppo_epochs: |
|
break |
|
|
|
question_tensors = batch["input_ids"] |
|
|
|
query_response_tensors = ppo_trainer.generate( |
|
question_tensors, |
|
return_prompt=True, |
|
length_sampler=output_length_sampler, |
|
**generation_kwargs, |
|
) |
|
response_tensors = [tensor[len(question) :] for tensor, question in zip(query_response_tensors, question_tensors)] |
|
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) |
|
|
|
|
|
texts = [q + r for q, r in zip(batch["query"], batch["response"])] |
|
reward_inputs = tokenizer( |
|
texts, padding=True, truncation=True, return_tensors="pt", return_token_type_ids=False |
|
).to(ppo_trainer.accelerator.device) |
|
|
|
|
|
|
|
    # Sanity check: re-tokenizing query + response should reproduce the generated token ids,
    # so the reward model scores exactly the sequence the policy produced; drop into the
    # debugger if they diverge.
    for i, tensor in enumerate(query_response_tensors):
        if not torch.equal(tensor, reward_inputs["input_ids"][i][: reward_inputs["attention_mask"][i].sum()]):
            import pdb

            pdb.set_trace()
|
|
|
    # Convert the proxy reward-model scores into per-sample reward tensors, subtracting the optional baseline.
    pipe_outputs = reward_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
|
|
|
|
|
|
|
|
|
|
|
stats = ppo_trainer.step(question_tensors, response_tensors, rewards) |
|
ppo_trainer.log_stats(stats, batch, rewards) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    if script_args.eval_freq and epoch % script_args.eval_freq == 0:
        # Periodically score the same rollouts with the gold reward model (main process only).
        if ppo_trainer.accelerator.is_main_process:
            pipe_outputs = gold_reward_pipe(texts, **sent_kwargs)
            gold_rewards = torch.tensor([output[0]["score"] for output in pipe_outputs])
            logs = {}
            logs["env/gold_reward_mean"] = gold_rewards.mean().item()
            logs["env/gold_reward_std"] = gold_rewards.std().item()
            logs["env/gold_reward_dist"] = gold_rewards.cpu().numpy()
            ppo_trainer.accelerator.log(logs)
            print(logs)
|
|
|
if script_args.save_freq and epoch and epoch % script_args.save_freq == 0: |
|
ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}") |
|
|