XPO Trainer
Overview
Exploratory Preference Optimization (XPO) was proposed in the paper Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, Corby Rosset, Ahmed Awadallah, and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus that allows the method to explore outside the support of the initial model and the human feedback data.
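The sketch below makes "DPO loss plus an exploration bonus" concrete. It computes a per-example loss from sequence-level log-probabilities. It assumes the sigmoid DPO loss (the default `loss_type`). Following the paper's formulation, it also assumes the exploration term is `alpha` times the policy's log-probability of a completion sampled from the reference model. The function `xpo_example_loss` and its argument names are hypothetical and are not TRL's API. The `beta` and `alpha` values are illustrative only.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def xpo_example_loss(
    policy_chosen: float,
    ref_chosen: float,
    policy_rejected: float,
    ref_rejected: float,
    policy_on_ref_sample: float,
    beta: float = 0.1,
    alpha: float = 1e-5,
) -> tuple[float, float, float]:
    """Hypothetical per-example XPO loss: sigmoid DPO loss + alpha * exploration term.

    All inputs are sequence-level log-probabilities (sums over completion tokens):
      policy_* / ref_*     : log pi(y|x) and log pi_ref(y|x) for chosen/rejected completions
      policy_on_ref_sample : log pi(y_ref|x) for a completion y_ref sampled from the
                             reference model (assumption, following the paper)
    Returns (dpo_part, xpo_part, total).
    """
    # Chosen log-ratio minus rejected log-ratio (policy vs. reference)
    logits = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    dpo_part = -log_sigmoid(beta * logits)
    # Intuitively, minimizing alpha * log pi(y_ref|x) moves probability mass away from
    # reference samples, which encourages the policy to explore.
    xpo_part = alpha * policy_on_ref_sample
    return dpo_part, xpo_part, dpo_part + xpo_part


# Example: equal log-ratios give a DPO part of log(2); the bonus adds alpha * (-15.0).
dpo_part, xpo_part, total = xpo_example_loss(-10.0, -10.0, -12.0, -12.0, policy_on_ref_sample=-15.0)
print(dpo_part, xpo_part, total)
```

When the chosen completion's log-ratio exceeds the rejected one's, the DPO part drops below log(2). The exploration part is kept small by `alpha`.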
The abstract from the paper is the following:
Reinforcement learning from human feedback (RLHF) has emerged as a central tool for language model alignment. We consider online exploration in RLHF, which exploits interactive access to human or AI feedback by deliberately encouraging the model to produce diverse, maximally informative responses. By allowing RLHF to confidently stray from the pre-trained model, online exploration offers the possibility of novel, potentially super-human capabilities, but its full potential as a paradigm for language model training has yet to be realized, owing to computational and statistical bottlenecks in directly adapting existing reinforcement learning techniques. We propose a new algorithm for online exploration in RLHF, Exploratory Preference Optimization (XPO), which is simple and practical — a one-line change to (online) Direct Preference Optimization (DPO; Rafailov et al., 2023) — yet enjoys the strongest known provable guarantees and promising empirical performance. XPO augments the DPO objective with a novel and principled exploration bonus, empowering the algorithm to explore outside the support of the initial model and human feedback data. In theory, we show that XPO is provably sample-efficient and converges to a near-optimal language model policy under natural exploration conditions, irrespective of whether the initial model has good coverage. Our analysis, which builds on the observation that DPO implicitly performs a form of Q*-approximation (or, Bellman error minimization), combines previously disparate techniques from language modeling and theoretical reinforcement learning in a serendipitous fashion through the perspective of KL-regularized Markov decision processes. Empirically, we find that XPO is more sample-efficient than non-exploratory DPO variants in a preliminary evaluation.
This post-training method was contributed by Kashif Rasul, Quentin Gallouédec and Lewis Tunstall.
Quick start
This example demonstrates how to train a model using the XPO method. We use the Qwen2 0.5B Instruct model as the base model and PairRMJudge as a judge. We use the prompts from the UltraFeedback dataset, available on the Hugging Face Hub as trl-lib/ultrafeedback-prompt.
Below is the script to train the model:
# train_xpo.py
from datasets import load_dataset
from trl import PairRMJudge, XPOConfig, XPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO", logging_steps=10)
trainer = XPOTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
Execute the script using the following command:
accelerate launch train_xpo.py
Distributed across 8 GPUs, the training takes approximately 1 hour.
To see how the trained model performs, you can use the TRL Chat CLI.
$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
<quentin_gallouedec>:
What is the best programming language?
<trl-lib/Qwen2-0.5B-XPO>:
The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript.
Expected dataset type
XPO requires a prompt-only dataset. The XPOTrainer supports both conversational and standard dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
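For reference, a prompt-only example holds a single `prompt` field. The sketch below shows one row in each format as plain Python dictionaries. The `is_conversational_prompt` helper is hypothetical and is not part of TRL. It only shows how the two formats differ: a standard prompt is a string, and a conversational prompt is a list of role/content messages.

```python
# One prompt-only row in each format (illustrative content).
standard_example = {"prompt": "The sky is"}

conversational_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}]
}


def is_conversational_prompt(example: dict) -> bool:
    """Hypothetical helper: True if the prompt is a non-empty list of role/content messages."""
    prompt = example.get("prompt")
    return (
        isinstance(prompt, list)
        and bool(prompt)
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in prompt)
    )


for row in (standard_example, conversational_example):
    kind = "conversational" if is_conversational_prompt(row) else "standard"
    print(kind, row["prompt"])
```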
Usage tips
Use a reward model
Instead of a judge, you can choose to use a reward model (see Reward Bench for a leaderboard of public models you can use). Below is a code example showing how to replace a judge with the trl-lib/Qwen2-0.5B-Reward model:
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
trainer = XPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
)
Make sure that the SFT model and reward model use the same chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
Encourage EOS token generation
When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the max_new_tokens argument of XPOConfig. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the missing_eos_penalty argument of XPOConfig:
training_args = XPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
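A minimal sketch of the idea follows. It assumes the penalty is subtracted from the reward-model score of every completion that does not contain an EOS token. The function name `apply_missing_eos_penalty` is hypothetical, and the scores are made-up numbers.

```python
from typing import List, Optional


def apply_missing_eos_penalty(
    scores: List[float], contains_eos: List[bool], penalty: Optional[float]
) -> List[float]:
    """Hypothetical sketch: lower the score of each completion that lacks an EOS token.

    scores:       reward-model scores, one per completion
    contains_eos: whether the matching completion contains an EOS token
    penalty:      value of missing_eos_penalty; None disables the penalty
    """
    if penalty is None:
        return list(scores)
    return [s if has_eos else s - penalty for s, has_eos in zip(scores, contains_eos)]


scores = [2.5, 3.0, 1.0]
contains_eos = [True, False, True]
print(apply_missing_eos_penalty(scores, contains_eos, penalty=1.0))
```

Because a truncated completion's score is lowered, it is more likely to end up as the rejected completion. This pushes the model toward finishing within `max_new_tokens`.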
Logging completions
To better understand your model’s behavior during training, you can log sample completions periodically using the LogCompletionsCallback.
from trl import LogCompletionsCallback
trainer = XPOTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
This callback logs the model’s generated completions directly to Weights & Biases.
Example script
We provide an example script to train a model using the XPO method. The script is available in examples/scripts/xpo.py.
To test the XPO script with the Qwen2.5 0.5B model on the UltraFeedback dataset, run the following command:
python examples/scripts/xpo.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --judge pair_rm \
    --dataset_name trl-lib/ultrafeedback-prompt \
    --learning_rate 5.0e-7 \
    --logging_steps 25 \
    --output_dir Qwen2.5-0.5B-XPO-PairRM \
    --warmup_ratio 0.1 \
    --push_to_hub
Logged metrics
The logged metrics are as follows:
- loss/xpo: The mean XPO part of the full loss.
- loss/dpo: The mean DPO part of the full loss.
- objective/kl: The mean KL divergence between the model and the reference model, measured on the model and reference completions.
- objective/entropy: The mean entropy of the model, measured on the model and reference completions.
- objective/model_scores: The mean scores (according to the reward model) of the model completions.
- objective/ref_scores: The mean scores (according to the reward model) of the reference completions.
- objective/scores_margin: The mean score margin (according to the external reward model) between the chosen and rejected completions.
- rewards/chosen: The mean reward (according to XPO's DPO implicit reward model) of the chosen completions.
- rewards/rejected: The mean reward (according to XPO's DPO implicit reward model) of the rejected completions.
- rewards/accuracies: The accuracies of XPO's implicit reward model.
- rewards/margins: The mean reward margin (according to XPO's DPO implicit reward model) between the chosen and rejected completions.
- logps/chosen: The mean log probabilities of the chosen completions.
- logps/rejected: The mean log probabilities of the rejected completions.
- val/model_contain_eos_token: How often the model's output contains the EOS token.
- val/ref_contain_eos_token: How often the reference model's output contains the EOS token.
- alpha: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to XPOConfig.
- beta: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to XPOConfig.
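The rewards/* metrics can be illustrated with a short sketch. It assumes the standard DPO implicit reward, beta * (log pi(y|x) - log pi_ref(y|x)), computed on sequence-level log-probabilities. For the list-valued alpha and beta, the sketch also assumes that the value is looked up by the current epoch and that the last entry is reused once the list runs out. `scheduled` and `reward_metrics` are hypothetical helpers, not TRL functions, and the log-probabilities are made-up numbers.

```python
from statistics import mean
from typing import Dict, List


def scheduled(values: List[float], epoch: float) -> float:
    """Hypothetical: pick alpha/beta for the current epoch, reusing the last value afterwards."""
    idx = int(epoch)
    return values[idx] if idx < len(values) else values[-1]


def reward_metrics(
    pol_chosen: List[float],
    ref_chosen: List[float],
    pol_rejected: List[float],
    ref_rejected: List[float],
    beta: float,
) -> Dict[str, float]:
    """Sketch of rewards/* metrics from sequence-level policy and reference log-probabilities."""
    # DPO implicit reward: beta * log-ratio between policy and reference
    chosen = [beta * (p - r) for p, r in zip(pol_chosen, ref_chosen)]
    rejected = [beta * (p - r) for p, r in zip(pol_rejected, ref_rejected)]
    return {
        "rewards/chosen": mean(chosen),
        "rewards/rejected": mean(rejected),
        # Fraction of pairs where the implicit reward prefers the chosen completion
        "rewards/accuracies": mean(1.0 if c > r else 0.0 for c, r in zip(chosen, rejected)),
        "rewards/margins": mean(c - r for c, r in zip(chosen, rejected)),
    }


beta = scheduled([0.1, 0.05], epoch=0)
print(reward_metrics([-10.0, -8.0], [-11.0, -8.5], [-12.0, -9.0], [-11.5, -8.0], beta=beta))
```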
XPOTrainer
class trl.XPOTrainer(model: Union = None, ref_model: Union = None, reward_model: Optional = None, judge: Optional = None, args: Optional = None, data_collator: Optional = None, train_dataset: Union = None, eval_dataset: Union = None, processing_class: Union = None, peft_config: Optional = None, compute_metrics: Optional = None, callbacks: Optional = None, optimizers: Tuple = (None, None), preprocess_logits_for_metrics: Optional = None)
Parameters
- model (transformers.PreTrainedModel): The model to train, preferably an AutoModelForCausalLM.
- ref_model (PreTrainedModelWrapper): Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
- reward_model (transformers.PreTrainedModel): The reward model to score completions with, preferably an AutoModelForSequenceClassification.
- judge (BasePairwiseJudge): The judge to use for pairwise comparison of model completions.
- args (XPOConfig): The XPO config arguments to use for training.
- data_collator (transformers.DataCollator): The data collator to use for training. If None is specified, the default data collator (DPODataCollatorWithPadding) will be used, which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
- train_dataset (datasets.Dataset): The dataset to use for training.
- eval_dataset (datasets.Dataset): The dataset to use for evaluation.
- processing_class (PreTrainedTokenizerBase or BaseImageProcessor or FeatureExtractionMixin or ProcessorMixin, optional): Processing class used to process the data. If provided, it will be used to automatically process the inputs for the model, and it will be saved along with the model to make it easier to rerun an interrupted training or reuse the fine-tuned model.
- peft_config (Dict): The PEFT config to use for training.
- compute_metrics (Callable[[EvalPrediction], Dict], optional): The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
- callbacks (List[transformers.TrainerCallback]): The callbacks to use for training.
- optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]): The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): The function to use to preprocess the logits before computing the metrics.
Initialize XPOTrainer as a subclass of OnlineDPOTrainer.
XPOConfig
class trl.XPOConfig(output_dir: str, overwrite_output_dir: bool = False, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, eval_strategy: Union = 'no', prediction_loss_only: bool = False, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 8, per_gpu_train_batch_size: Optional = None, per_gpu_eval_batch_size: Optional = None, gradient_accumulation_steps: int = 1, eval_accumulation_steps: Optional = None, eval_delay: Optional = 0, torch_empty_cache_steps: Optional = None, learning_rate: float = 5e-07, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, max_grad_norm: float = 1.0, num_train_epochs: float = 3.0, max_steps: int = -1, lr_scheduler_type: Union = 'linear', lr_scheduler_kwargs: Union = <factory>, warmup_ratio: float = 0.0, warmup_steps: int = 0, log_level: Optional = 'passive', log_level_replica: Optional = 'warning', log_on_each_node: bool = True, logging_dir: Optional = None, logging_strategy: Union = 'steps', logging_first_step: bool = False, logging_steps: float = 500, logging_nan_inf_filter: bool = True, save_strategy: Union = 'steps', save_steps: float = 500, save_total_limit: Optional = None, save_safetensors: Optional = True, save_on_each_node: bool = False, save_only_model: bool = False, restore_callback_states_from_checkpoint: bool = False, no_cuda: bool = False, use_cpu: bool = False, use_mps_device: bool = False, seed: int = 42, data_seed: Optional = None, jit_mode_eval: bool = False, use_ipex: bool = False, bf16: bool = False, fp16: bool = False, fp16_opt_level: str = 'O1', half_precision_backend: str = 'auto', bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: Optional = None, local_rank: int = -1, ddp_backend: Optional = None, tpu_num_cores: Optional = None, tpu_metrics_debug: bool = False, debug: Union = '', dataloader_drop_last: bool = False, eval_steps: Optional = None, dataloader_num_workers: int = 0, dataloader_prefetch_factor: Optional = None, past_index: int = -1, run_name: Optional = None, disable_tqdm: Optional = None, remove_unused_columns: Optional = True, label_names: Optional = None, load_best_model_at_end: Optional = False, metric_for_best_model: Optional = None, greater_is_better: Optional = None, ignore_data_skip: bool = False, fsdp: Union = '', fsdp_min_num_params: int = 0, fsdp_config: Union = None, fsdp_transformer_layer_cls_to_wrap: Optional = None, accelerator_config: Union = None, deepspeed: Union = None, label_smoothing_factor: float = 0.0, optim: Union = 'adamw_torch', optim_args: Optional = None, adafactor: bool = False, group_by_length: bool = False, length_column_name: Optional = 'length', report_to: Union = None, ddp_find_unused_parameters: Optional = None, ddp_bucket_cap_mb: Optional = None, ddp_broadcast_buffers: Optional = None, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, skip_memory_metrics: bool = True, use_legacy_prediction_loop: bool = False, push_to_hub: bool = False, resume_from_checkpoint: Optional = None, hub_model_id: Optional = None, hub_strategy: Union = 'every_save', hub_token: Optional = None, hub_private_repo: bool = False, hub_always_push: bool = False, gradient_checkpointing: bool = False, gradient_checkpointing_kwargs: Union = None, include_inputs_for_metrics: bool = False, include_for_metrics: List = <factory>, eval_do_concat_batches: bool = True, fp16_backend: str = 'auto', evaluation_strategy: Union = None, push_to_hub_model_id: Optional = None, push_to_hub_organization: Optional = None, push_to_hub_token: Optional = None, mp_parameters: str = '', auto_find_batch_size: bool = False, full_determinism: bool = False, torchdynamo: Optional = None, ray_scope: Optional = 'last', ddp_timeout: Optional = 1800, torch_compile: bool = False, torch_compile_backend: Optional = None, torch_compile_mode: Optional = None, dispatch_batches: Optional = None, split_batches: Optional = None, include_tokens_per_second: Optional = False, include_num_input_tokens_seen: Optional = False, neftune_noise_alpha: Optional = None, optim_target_modules: Union = None, batch_eval_metrics: bool = False, eval_on_start: bool = False, use_liger_kernel: Optional = False, eval_use_gather_object: Optional = False, average_tokens_across_devices: Optional = False, reward_model_path: Optional = None, judge: Optional = None, max_new_tokens: int = 64, temperature: float = 0.9, missing_eos_penalty: Optional = None, beta: List = <factory>, loss_type: Literal = 'sigmoid', dataset_num_proc: Optional = None, disable_dropout: bool = True, alpha: List = <factory>)
Configuration class for the XPOTrainer.
Subclass of OnlineDPOConfig: all of its arguments can be used, and the following are added: