# pythia410m-sft-tldr/code/rl_training_value_model.py
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
import torch
from accelerate import DistributedDataParallelKwargs
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline
from trl import PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
from trl.models.modeling_value_model import AutoModelForCausalLMWithValueModel
# import copy
# from torch_ema import ExponentialMovingAverage
# from transformers import pipeline
# tqdm.pandas()
@dataclass
class ScriptArguments:
"""
    Arguments for fine-tuning a Causal LM with PPO.
    """
    # NOTE: gpt2 models use Conv1D instead of Linear layers, which are not yet supported in 8-bit mode;
    # models like gpt-neo* are more suitable.
model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
    gold_reward_model_name: Optional[str] = field(default="", metadata={"help": "the gold reward model name"})
# tokenizer_name: Optional[str] = field(default=None, metadata={"help": "the tokenizer name"})
dataset_name: Optional[str] = field(
default="CarperAI/openai_summarize_tldr", metadata={"help": "the dataset name"}
)
    train_split: Optional[str] = field(
        default="train", metadata={"help": "the dataset split to train on"}
    )
log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"})
gradient_accumulation_steps: Optional[int] = field(
default=4, metadata={"help": "the number of gradient accumulation steps"}
)
adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
reward_baseline: Optional[float] = field(
default=0.0,
metadata={"help": "a baseline value that is subtracted from the reward"},
)
batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"})
    eval_freq: Optional[int] = field(default=None, metadata={"help": "evaluate with the gold reward model every n steps"})
    save_freq: Optional[int] = field(default=None, metadata={"help": "save the model every n steps"})
    output_dir: Optional[str] = field(default="runs/", metadata={"help": "the output directory for checkpoints"})
seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
    steps: Optional[int] = field(default=20000, metadata={"help": "the total number of training steps"})
init_kl_coef: Optional[float] = field(
default=0.05,
metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
)
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
# multi_adapter_value: Optional[bool] = field(default=False)
# separate_reward_model: Optional[str] = field(default=None, metadata={"help": "the reward model name"})
# Generation
    output_min_length: Optional[int] = field(default=24, metadata={"help": "minimum response length for generation"})
    output_max_length: Optional[int] = field(default=48, metadata={"help": "maximum response length for generation"})
    input_max_length: Optional[int] = field(default=512, metadata={"help": "maximum prompt length after tokenization"})
# Quantization
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
bf16: Optional[bool] = field(
default=False,
metadata={
"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
},
)
fp16: Optional[bool] = field(
default=False,
metadata={
"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
},
)
    fp16_model: Optional[bool] = field(
        default=False,
        metadata={"help": "load the model weights in float16"},
    )
# LoRA
    use_lora: Optional[bool] = field(
        default=True,
        metadata={"help": "whether to use LoRA adapters"},
    )
lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
lora_all_linear: Optional[bool] = field(default=False, metadata={"help": "lora adapter on all linear layers"})
# # EMA stuff
# ema_decay: Optional[float] = field(default=0.995, metadata={"help": "the ema decay rate"})
# reset_freq: Optional[int] = field(default=None, metadata={"help": "reset every n epochs"})
def create_and_prepare_model(args):
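    """Load the causal LM policy with an attached value model and its tokenizer, using the requested dtype."""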
if args.bf16:
torch_dtype = torch.bfloat16
elif args.fp16_model:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
model = AutoModelForCausalLMWithValueModel.from_pretrained(
args.model_name,
args.reward_model_name,
torch_dtype=torch_dtype,
)
# if script_args.ignore_bias_buffers:
# torch distributed hack
model.config.torch_dtype = torch_dtype
model.config.use_cache = True
# model.config.use_cache = not args.gradient_checkpointing
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
if getattr(model.config, "pad_token_id", None) is None:
model.config.pad_token_id = model.config.eos_token_id
return model, tokenizer
def create_and_prepare_dataset(args, tokenizer, split, num_proc=2):
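    """Tokenize the `query` (prompt) column of the given split, dropping all other columns."""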
dataset = load_dataset(args.dataset_name, split=split)
dataset = dataset.rename_column("prompt", "query")
original_columns = dataset.column_names
original_columns.remove("query")
dataset = dataset.map(
lambda examples: tokenizer(examples["query"], truncation=True, max_length=args.input_max_length),
batched=True,
num_proc=num_proc,
remove_columns=original_columns,
)
dataset.set_format("torch")
return dataset
def collator(data):
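    """Collate a list of examples into a dict of lists, the batch format PPOTrainer expects."""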
return dict((key, [d[key] for d in data]) for key in data[0])
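# Parse command-line arguments into ScriptArguments and build the PPO config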
parser = HfArgumentParser(ScriptArguments)
script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
config = PPOConfig(
steps=script_args.steps,
model_name=script_args.model_name,
learning_rate=script_args.learning_rate,
log_with=script_args.log_with,
batch_size=script_args.batch_size,
mini_batch_size=script_args.mini_batch_size,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
optimize_cuda_cache=True,
early_stopping=script_args.early_stopping,
target_kl=script_args.target_kl,
ppo_epochs=script_args.ppo_epochs,
seed=script_args.seed,
init_kl_coef=script_args.init_kl_coef,
adap_kl_ctrl=script_args.adap_kl_ctrl,
accelerator_kwargs={"kwargs_handlers": [DistributedDataParallelKwargs(find_unused_parameters=False)]},
)
# set seed before initializing value head for deterministic eval
set_seed(config.seed)
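# Load the policy + value model, its tokenizer, and the tokenized training prompts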
model, tokenizer = create_and_prepare_model(script_args)
train_dataset = create_and_prepare_dataset(script_args, tokenizer, script_args.train_split)
# eval_dataset = create_and_prepare_dataset(script_args, tokenizer, script_args.eval_split)
# We then build the PPOTrainer with the model, tokenizer, dataset, and collator (no separate reference model is passed)
ppo_trainer = PPOTrainer(
config,
model,
ref_model=None,
tokenizer=tokenizer,
dataset=train_dataset,
data_collator=collator,
)
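# Keep the policy in eval mode for generation (disables dropout)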
model.eval()
# if script_args.separate_reward_model:
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
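# Pipeline for the (proxy) reward model that scores each generated summary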
reward_pipe = pipeline(
    "sentiment-analysis",
    model=script_args.reward_model_name,
    device=device,
    # device_map={"": Accelerator().local_process_index},
    tokenizer=tokenizer,
    return_token_type_ids=False,
)
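# Optional gold reward model pipeline, only used for periodic evaluation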
if script_args.eval_freq is not None:
    gold_reward_pipe = pipeline(
        "sentiment-analysis",
        model=script_args.gold_reward_model_name,
        device=device,
        # device_map={"": Accelerator().local_process_index},
        # model_kwargs={"load_in_8bit": True},
        tokenizer=tokenizer,
        return_token_type_ids=False,
    )
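# Reward pipeline arguments: return raw scores (no activation applied) and truncate long inputs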
sent_kwargs = {
"top_k": None,
"function_to_apply": "none",
"batch_size": 16,
"truncation": True,
}
# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
}
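# Sample a target response length uniformly between output_min_length and output_max_length for each generation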
output_length_sampler = LengthSampler(script_args.output_min_length, script_args.output_max_length)
for epoch, batch in tqdm(
enumerate(ppo_trainer.dataloader),
total=config.total_ppo_epochs,
disable=not ppo_trainer.accelerator.is_local_main_process,
):
if epoch >= config.total_ppo_epochs:
break
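    # Generate responses for this batch of prompts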
question_tensors = batch["input_ids"]
query_response_tensors = ppo_trainer.generate(
question_tensors,
return_prompt=True,
length_sampler=output_length_sampler,
**generation_kwargs,
)
response_tensors = [tensor[len(question) :] for tensor, question in zip(query_response_tensors, question_tensors)]
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
    # Build query + response texts and re-tokenize them for the reward model
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
reward_inputs = tokenizer(
texts, padding=True, truncation=True, return_tensors="pt", return_token_type_ids=False
).to(ppo_trainer.accelerator.device)
    # Sanity check: re-tokenizing query + response should reproduce the generated tokens
    for i, tensor in enumerate(query_response_tensors):
        if not torch.equal(tensor, reward_inputs["input_ids"][i][: reward_inputs["attention_mask"][i].sum()]):
            # TODO: investigate the tokenization mismatch
            import pdb
            pdb.set_trace()
# if script_args.reward_model_name:
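    # Score the full query + response texts with the reward model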
pipe_outputs = reward_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[0]["score"]) for output in pipe_outputs]
# else:
# raw_rewards = ppo_trainer.compute_reward_model_score(**reward_inputs)
# rewards = [(raw_rewards[i] - script_args.reward_baseline) for i in range(len(raw_rewards))]
# Run PPO step
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, batch, rewards)
# if ema is not None:
# ema.update()
#
# if script_args.reset_freq and epoch and epoch % script_args.reset_freq == 0:
# ema.copy_to()
# ema.load_state_dict(initial_state_dict)
# ppo_trainer.accelerator.print("elastic reset")
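    # Periodically score the same batch with the gold reward model (main process only)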
    if script_args.eval_freq and epoch % script_args.eval_freq == 0:
        if ppo_trainer.accelerator.is_main_process:
            pipe_outputs = gold_reward_pipe(texts, **sent_kwargs)
            # stack the per-sample scores into one tensor so mean/std/numpy work on the batch
            gold_rewards = torch.stack([torch.tensor(output[0]["score"]) for output in pipe_outputs])
            logs = {}
            logs["env/gold_reward_mean"] = torch.mean(gold_rewards).cpu().numpy().item()
            logs["env/gold_reward_std"] = torch.std(gold_rewards).cpu().numpy().item()
            logs["env/gold_reward_dist"] = gold_rewards.cpu().numpy()
            ppo_trainer.accelerator.log(logs)
            print(logs)
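    # Periodically checkpoint the policy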
if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")