PPO Trainer
TRL supports the PPO Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric, or preference data via a Reward Model. For a full example, have a look at examples/notebooks/gpt2-sentiment.ipynb. The trainer is heavily inspired by the original OpenAI learning to summarize work.
The first step is to train your SFT model (see the SFTTrainer) to ensure the data we train on is in-distribution for the PPO algorithm. In addition, we need to train a Reward model (see the RewardTrainer), which will be used to optimize the SFT model with the PPO algorithm.
How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:
- Rollout: The language model generates a response or continuation based on a query, which could be the start of a sentence.
- Evaluation: The query and response are evaluated with a function, a model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
- Optimization: This is the most complex part. In the optimization step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model being trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don’t deviate too far from the reference language model. The active language model is then trained with PPO (a short reward-shaping sketch follows the figure below).
This process is illustrated in the sketch below:
Figure: Sketch of the workflow.
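To make the optimization step above more concrete, here is a minimal sketch of how a per-token KL penalty can be combined with a scalar reward. The tensor values and the kl_coef value are illustrative assumptions, not the exact internals of the PPOTrainer; the point is only that divergence from the reference model is penalized at every token while the scalar score is credited to the sequence.
import torch

# Hypothetical per-token log-probabilities for one generated response.
logprobs = torch.tensor([-1.2, -0.8, -2.1])      # under the model being trained
ref_logprobs = torch.tensor([-1.0, -0.9, -1.5])  # under the frozen reference model
score = 0.7                                      # scalar reward for the query/response pair
kl_coef = 0.2                                    # weight of the KL penalty (assumed value)

# Per-token KL estimate and the shaped reward: the KL term discourages drifting
# away from the reference model; the scalar score is added at the final token.
kl = logprobs - ref_logprobs
rewards = -kl_coef * kl
rewards[-1] += score
print(rewards)  # per-token rewards that a PPO step can optimize against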
Expected dataset format
The PPOTrainer
expects to align a generated response with a query, given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, then use these prompts to generate responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model with the PPO algorithm.
Therefore the dataset should contain a text column, which we rename to query. All other data points required to optimize the SFT model are obtained during the training loop.
Here is an example with the HuggingFaceH4/cherry_picked_prompts dataset:
from datasets import load_dataset
dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
dataset = dataset.rename_column("prompt", "query")
dataset = dataset.remove_columns(["meta", "completion"])
Resulting in the following subset of the dataset:
ppo_dataset_dict = {
    "query": [
        "Explain the moon landing to a 6 year old in a few sentences.",
        "Why aren’t birds real?",
        "What happens if you fire a cannonball directly at a pumpkin at high speeds?",
        "How can I steal from a grocery store without getting caught?",
        "Why is it important to eat socks after meditating? "
    ]
}
Using the PPOTrainer
For a detailed example, have a look at the examples/notebooks/gpt2-sentiment.ipynb notebook. At a high level we need to initialize the PPOTrainer with a model we wish to train. Additionally, we require a reward_model which we will use to rate the generated responses.
Initializing the PPOTrainer
The PPOConfig
dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer.
from trl import PPOConfig
config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
)
Now we can initialize our model. Note that PPO also requires a reference model, but this model is created automatically by the `PPOTrainer`. The model can be initialized as follows:
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using transformers.pipeline
for ease of use.
from transformers import pipeline
reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
Lastly, we pretokenize our dataset using the tokenizer
to ensure we can efficiently generate responses during the training loop:
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["query"])
    return sample

dataset = dataset.map(tokenize, batched=False)
Now we are ready to initialize the PPOTrainer
using the defined config, dataset, and model.
from trl import PPOTrainer
ppo_trainer = PPOTrainer(
    model=model,
    config=config,
    dataset=dataset,
    tokenizer=tokenizer,
)
Starting the training loop
Because the PPOTrainer needs a reward for each query/response pair at every execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment reward_model initialized above.
To guide the generation process we use the generation_kwargs, which are passed to the model.generate method of the SFT model during each step. A more detailed example can be found in the examples/notebooks/gpt2-sentiment.ipynb notebook.
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the reward_model
and pass these rewards to the ppo_trainer.step
method. The ppo_trainer.step
method will then optimize the SFT model using the PPO algorithm.
import torch
from tqdm import tqdm

epochs = 10
for epoch in tqdm(range(epochs), "epoch: "):
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]

        #### Get response from SFTModel
        response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

        #### Compute reward score
        # Note: indexing output[1] assumes the pipeline returns scores for both labels
        # (e.g. a pipeline created with return_all_scores=True), so that index 1 is
        # the positive-class score.
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        pipe_outputs = reward_model(texts)
        rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

#### Save model
ppo_trainer.save_pretrained("my_ppo_model")
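After training, the directory saved above can be reloaded like any other checkpoint. A minimal sketch, assuming the "my_ppo_model" path from the previous step:
from trl import AutoModelForCausalLMWithValueHead

# Reload the PPO-tuned policy for later inference or further training.
model = AutoModelForCausalLMWithValueHead.from_pretrained("my_ppo_model")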
Logging
While training and evaluating we log the following metrics:
- stats: The statistics of the PPO algorithm, including the loss, entropy, etc.
- batch: The batch of data used to train the SFT model.
- rewards: The rewards obtained from the Reward model.
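Where these metrics are sent is controlled by the tracking options of PPOConfig. As a minimal sketch, reusing the configuration values from earlier together with the documented log_with option:
config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
    log_with="wandb",  # or "tensorboard"
    tracker_project_name="trl",
)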
PPOTrainer
class trl.PPOTrainer
< source >( config: Optional = None model: Optional = None ref_model: Optional = None tokenizer: Optional = None dataset: Union = None optimizer: Optional = None data_collator: Optional = None num_shared_layers: Optional = None lr_scheduler: Optional = None training_data_collator: Optional = None )
Parameters
- **config** (`PPOConfig`) — Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details.
- **model** (`PreTrainedModelWrapper`) — Model to be optimized, Hugging Face transformer model with a value head. Check the documentation of `PreTrainedModelWrapper` for more details.
- **ref_model** (`PreTrainedModelWrapper`, *optional*) — Reference model to be used for the KL penalty, Hugging Face transformer model with a causal language modelling head. Check the documentation of `PreTrainedModelWrapper` for more details. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized, with shared layers.
- **tokenizer** (`PreTrainedTokenizerBase`) — Tokenizer to be used for encoding the data. Check the documentation of `transformers.PreTrainedTokenizer` and `transformers.PreTrainedTokenizerFast` for more details.
- **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) — PyTorch dataset or Hugging Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, users need to design their own dataloader outside the trainer and make sure the batch size used is the same as the one specified in the configuration object.
- **optimizer** (`torch.optim.Optimizer`, *optional*) — Optimizer to be used for training. If no optimizer is provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration object.
- **data_collator** (DataCollatorForLanguageModeling, *optional*) — Data collator to be used for training and passed along to the dataloader.
- **num_shared_layers** (int, *optional*) — Number of layers to be shared between the model and the reference model, if no reference model is passed. If no number is provided, all the layers will be shared.
- **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) — Learning rate scheduler to be used for training.
The PPOTrainer uses Proximal Policy Optimization to optimise language models. Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here: https://github.com/openai/summarize-from-feedback
batched_forward_pass
< source >( model: PreTrainedModelWrapper queries: Tensor responses: Tensor model_inputs: dict return_logits: bool = False response_masks: Optional = None ) → (tuple)
Parameters
- queries (`torch.LongTensor`) — List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
- responses (`torch.LongTensor`) — List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
- return_logits (`bool`, *optional*, defaults to `False`) — Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
Returns
(tuple)
- all_logprobs (`torch.FloatTensor`): Log probabilities of the responses, shape (`batch_size`, `response_length`)
- all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses under the reference model, shape (`batch_size`, `response_length`)
- all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
Calculate model outputs in multiple batches.
compute_rewards
< source >( scores: FloatTensor logprobs: FloatTensor ref_logprobs: FloatTensor masks: LongTensor ) → torch.FloatTensor
Parameters
- scores (`torch.FloatTensor`) — Scores from the reward model, shape (`batch_size`)
- logprobs (`torch.FloatTensor`) — Log probabilities of the model, shape (`batch_size`, `response_length`)
- ref_logprobs (`torch.FloatTensor`) — Log probabilities of the reference model, shape (`batch_size`, `response_length`)
Returns
- `torch.FloatTensor` — Per token rewards, shape (`batch_size`, `response_length`)
- `torch.FloatTensor` — Non score rewards, shape (`batch_size`, `response_length`)
- `torch.FloatTensor` — KL penalty, shape (`batch_size`, `response_length`)
Compute per token rewards from scores and KL-penalty.
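As a hedged usage sketch with hypothetical shapes (a batch of two responses with three response tokens each; the tensors are random placeholders rather than real model outputs):
import torch

scores = torch.tensor([0.7, -0.2])            # one scalar reward per sequence
logprobs = torch.randn(2, 3)                  # per-token log-probs of the trained model
ref_logprobs = torch.randn(2, 3)              # per-token log-probs of the reference model
masks = torch.ones(2, 3, dtype=torch.long)    # 1 marks response tokens to include

rewards, non_score_rewards, kls = ppo_trainer.compute_rewards(scores, logprobs, ref_logprobs, masks)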
create_model_card
< source >( path: str model_name: Optional = 'TRL Model' )
Creates and saves a model card for a TRL model.
gather_stats
< source >( stats ) → dict[str, Any]
Gather stats from all processes. Useful in the context of distributed training.
generate
< source >( query_tensor: Union length_sampler: Optional = None batch_size: int = 4 return_prompt: bool = True generate_ref_response: bool = False **generation_kwargs ) → torch.LongTensor
Parameters
- query_tensor (`torch.LongTensor`) — A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
- length_sampler (`Callable`, *optional*) — Callable that returns the number of newly generated tokens.
- batch_size (`int`, *optional*) — Batch size used for generation, defaults to `4`.
- return_prompt (`bool`, *optional*) — If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
- generate_ref_response (`bool`, *optional*) — If set to `True` the reference response is also generated, defaults to `False`.
- generation_kwargs (dict[str, Any]) — Keyword arguments for generation.
Returns
`torch.LongTensor` — A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
Generate a response with the model given the query tensor. Calls the `generate` method of the model.
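For example, the training loop shown earlier could request only the newly generated tokens and, in addition, a response from the frozen reference model. A minimal sketch, reusing query_tensors and generation_kwargs from that loop:
response_tensors, ref_response_tensors = ppo_trainer.generate(
    query_tensors,
    return_prompt=False,          # keep only the newly generated tokens
    generate_ref_response=True,   # also sample from the reference model
    **generation_kwargs,
)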
log_stats
< source >( stats: dict batch: dict rewards: List columns_to_log: Iterable = ('query', 'response') )
A function that logs all the training stats. Call it at the end of each epoch.
loss
< source >( old_logprobs: FloatTensor values: FloatTensor logits: FloatTensor vpreds: FloatTensor logprobs: FloatTensor mask: LongTensor advantages: FloatTensor returns: FloatTensor )
Parameters
- old_logprobs (`torch.FloatTensor`) — Log probabilities of the model, shape (`batch_size`, `response_length`)
- values (`torch.FloatTensor`) — Values of the value head, shape (`batch_size`, `response_length`)
- rewards (`torch.FloatTensor`) — Rewards from the reward model, shape (`batch_size`, `response_length`)
- logits (`torch.FloatTensor`) — Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
- v_pred (`torch.FloatTensor`) — Values of the value head, shape (`batch_size`, `response_length`)
- logprobs (`torch.FloatTensor`) — Log probabilities of the model, shape (`batch_size`, `response_length`)
Calculate policy and value losses.
prepare_dataloader
< source >( dataset: Union data_collator = None ) → torch.utils.data.DataLoader
Parameters
- dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]) — PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset will be preprocessed by removing the columns that are not used by the model.
- data_collator (Optional[function]) — Data collator function.
Returns
torch.utils.data.DataLoader
PyTorch dataloader
Prepare the dataloader for training.
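A brief sketch of calling it directly; normally this is not needed, since PPOTrainer builds ppo_trainer.dataloader automatically when a dataset is passed to its constructor:
# Build a dataloader from the tokenized dataset used earlier; a custom collator
# can be passed as the second argument if the default one does not fit your data.
dataloader = ppo_trainer.prepare_dataloader(dataset)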
record_step_stats
< source >( kl_coef: float **data ) → stats (`dict`)
Record training step statistics.
step
< source >( queries: List responses: List scores: List response_masks: Optional = None ) → dict[str, Any]
Parameters
- queries (List[`torch.LongTensor`]) — List of tensors containing the encoded queries of shape (`query_length`)
- responses (List[`torch.LongTensor`]) — List of tensors containing the encoded responses of shape (`response_length`)
- scores (List[`torch.FloatTensor`]) — List of tensors containing the scores.
- response_masks (List[`torch.FloatTensor`], *optional*) — List of tensors containing masks of the response tokens.
Returns
dict[str, Any]
A summary of the training statistics
Run a PPO optimisation step given a list of queries, model responses, and rewards.
train_minibatch
< source >( old_logprobs: FloatTensor values: FloatTensor logprobs: FloatTensor logits: FloatTensor vpreds: FloatTensor mask: LongTensor advantages: FloatTensor returns: FloatTensor ) → train_stats (dict[str, `torch.Tensor`])
Parameters
- logprobs (`torch.FloatTensor`) — Log probabilities of the model, shape [mini_batch_size, response_length]
- values (`torch.FloatTensor`) — Values of the value head, shape [mini_batch_size, response_length]
- query (`torch.LongTensor`) — Encoded queries, shape [mini_batch_size, query_length]
- response (`torch.LongTensor`) — Encoded responses, shape [mini_batch_size, response_length]
- model_input (`torch.LongTensor`) — Concatenated queries and responses, shape [mini_batch_size, query_length+response_length]
Returns
train_stats (dict[str, `torch.Tensor`])
Dictionary of training statistics
Train one PPO minibatch
PPOConfig
class trl.PPOConfig
< source >( exp_name: str = 'doc-buil' seed: int = 0 log_with: Optional = None task_name: Optional = None model_name: str = 'gpt2' query_dataset: str = 'stanfordnlp/imdb' reward_model: str = 'sentiment-analysis:lvwerra/distilbert-imdb' remove_unused_columns: bool = True tracker_kwargs: Annotated = <factory> accelerator_kwargs: Annotated = <factory> project_kwargs: Annotated = <factory> tracker_project_name: str = 'trl' push_to_hub_if_best_kwargs: Annotated = <factory> steps: int = 20000 learning_rate: float = 1.41e-05 adap_kl_ctrl: bool = True init_kl_coef: float = 0.2 kl_penalty: Literal = 'kl' target: float = 6.0 horizon: float = 10000.0 gamma: float = 1.0 lam: float = 0.95 cliprange: float = 0.2 cliprange_value: float = 0.2 vf_coef: float = 0.1 batch_size: int = 128 forward_batch_size: Optional = None mini_batch_size: int = 128 gradient_accumulation_steps: int = 1 world_size: Annotated = None ppo_epochs: int = 4 max_grad_norm: Optional = None optimize_cuda_cache: Optional = None optimize_device_cache: bool = False early_stopping: bool = False target_kl: float = 1.0 compare_steps: int = 1 ratio_threshold: float = 10.0 use_score_scaling: bool = False use_score_norm: bool = False score_clip: Optional = None whiten_rewards: bool = False gradient_checkpointing: bool = False is_encoder_decoder: Optional = None is_peft_model: Optional = None backward_batch_size: Annotated = None global_backward_batch_size: Optional = None global_batch_size: Annotated = None dataset_num_proc: Optional = None )
Parameters
- exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`) — Name of this experiment.
- seed (`int`, *optional*, defaults to `0`) — Random seed.
- log_with (`Optional[Literal["wandb", "tensorboard"]]`, *optional*, defaults to `None`) — Log with either `"wandb"` or `"tensorboard"`. Check tracking for more details.
- task_name (`Optional[str]`, *optional*, defaults to `None`) — Name of task to use - used only for tracking purposes.
- model_name (`Optional[str]`, *optional*, defaults to `"gpt2"`) — Name of model to use - used only for tracking purposes.
- query_dataset (`Optional[str]`, *optional*, defaults to `"stanfordnlp/imdb"`) — Name of dataset to query - used only for tracking purposes.
- reward_model (`Optional[str]`, *optional*, defaults to `"sentiment-analysis:lvwerra/distilbert-imdb"`) — Reward model to use - used only for tracking purposes.
- remove_unused_columns (`bool`, *optional*, defaults to `True`) — Remove unused columns from the dataset.
- tracker_kwargs (`JSONDict`, *optional*, defaults to `{}`) — Keyword arguments for the tracker (e.g. `python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'`).
- accelerator_kwargs (`JSONDict`, *optional*, defaults to `{}`) — Keyword arguments for the accelerator.
- project_kwargs (`JSONDict`, *optional*, defaults to `{}`) — Keyword arguments for the accelerator project config (e.g. `logging_dir`).
- tracker_project_name (`str`, *optional*, defaults to `"trl"`) — Name of project to use for tracking.
- push_to_hub_if_best_kwargs (`JSONDict`, *optional*, defaults to `{}`) — Keyword arguments for pushing the model to the hub during training (e.g. `repo_id`).
- steps (`int`, *optional*, defaults to `20000`) — Number of training steps.
- learning_rate (`float`, *optional*, defaults to `1.41e-5`) — Learning rate for the optimizer.
- adap_kl_ctrl (`bool`, *optional*, defaults to `True`) — Use adaptive KL control, otherwise linear.
- init_kl_coef (`Optional[float]`, *optional*, defaults to `0.2`) — Initial KL penalty coefficient (used for adaptive and linear control).
- kl_penalty (`Literal["kl", "abs", "mse", "full"]`, *optional*, defaults to `"kl"`) — KL penalty options. Possible values are:
  - `"kl"`: model_logp - ref_logp
  - `"abs"`: abs(kl)
  - `"mse"`: mean squared error mse(kl)
  - `"full"`: the actual kl for all tokens in the distribution.
- target (`float`, *optional*, defaults to `6.0`) — Target KL value for adaptive KL control.
- horizon (`float`, *optional*, defaults to `10000.0`) — Horizon for adaptive KL control.
- gamma (`float`, *optional*, defaults to `1.0`) — Gamma parameter for advantage calculation.
- lam (`float`, *optional*, defaults to `0.95`) — Lambda parameter for advantage calculation.
- cliprange (`float`, *optional*, defaults to `0.2`) — Range for clipping in PPO policy gradient loss.
- cliprange_value (`float`, *optional*, defaults to `0.2`) — Range for clipping values in loss calculation.
- vf_coef (`float`, *optional*, defaults to `0.1`) — Scaling factor for value loss.
- batch_size (`int`, *optional*, defaults to `128`) — Number of samples per optimisation step.
- forward_batch_size (`Optional[int]`, *optional*, defaults to `None`) — DEPRECATED: use `mini_batch_size` instead, which does the same thing.
- mini_batch_size (`int`, *optional*, defaults to `128`) — Number of samples optimized in each mini batch.
- gradient_accumulation_steps (`int`, *optional*, defaults to `1`) — Number of gradient accumulation steps.
- world_size (`Optional[int]`, *optional*, defaults to `None`) — Number of processes to use for distributed training.
- ppo_epochs (`int`, *optional*, defaults to `4`) — Number of optimisation epochs per batch of samples.
- optimize_device_cache (`bool`, *optional*, defaults to `False`) — Optimize device cache for slightly more memory-efficient training.
- early_stopping (`bool`, *optional*, defaults to `False`) — Whether to stop the PPO optimization loop early if the KL is too high.
- target_kl (`float`, *optional*, defaults to `1.0`) — Stop early if we exceed this value by over 50%.
- compare_steps (`int`, *optional*, defaults to `1`) — Compare the current step with the previous `compare_steps` steps.
- ratio_threshold (`float`, *optional*, defaults to `10.0`) — Skip mini-batches with high PPO ratios that can cause loss spikes.
- use_score_scaling (`bool`, *optional*, defaults to `False`) — Use score scaling.
- use_score_norm (`bool`, *optional*, defaults to `False`) — Use score normalization. Only applicable if `use_score_scaling` is True.
- score_clip (`Optional[float]`, *optional*, defaults to `None`) — Score clipping.
- whiten_rewards (`bool`, *optional*, defaults to `False`) — Whiten the rewards before computing advantages.
- is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`) — When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, you need to specify if the model returned by the callable is an encoder-decoder model.
- is_peft_model (`Optional[bool]`, *optional*, defaults to `None`) — Whether the model is a PEFT model.
- backward_batch_size (`Optional[int]`, *optional*, defaults to `None`) — Number of samples optimized in an `optimizer.step()` call.
- global_backward_batch_size (`Optional[int]`, *optional*, defaults to `None`) — Effective `backward_batch_size` across all processes.
- global_batch_size (`Optional[int]`, *optional*, defaults to `None`) — Effective `batch_size` across all processes.
- dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`) — Number of processes to use for processing the dataset.
Configuration class for the PPOTrainer.
Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.
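A minimal sketch of that pattern (the script name and flag values are illustrative):
from transformers import HfArgumentParser
from trl import PPOConfig

# e.g. python ppo_script.py --model_name gpt2 --learning_rate 1.41e-5 --batch_size 128
parser = HfArgumentParser(PPOConfig)
config = parser.parse_args_into_dataclasses()[0]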