Aligning Large Language Models with BRAIn
TL;DR: We introduce BRAIn, a distribution-matching approach for RLHF that achieves SOTA performance on the Anthropic HH and TL;DR summarization tasks, outperforming DPO and other RLHF methods!
Note: This work has been accepted for publication at ICML-2024 (main conference)
Link to arxiv version: https://arxiv.org/pdf/2402.02479
Table of Contents
- The 3 phases of LLM training
- Alignment to human preferences
- Our contribution - BRAIn
- The BRAIn objective
- Experimental results
The 3 phases of LLM training
In the past few years, large language models (LLMs) have demonstrated immense prowess on a wide variety of tasks, including multi-turn conversation, creative writing, and mathematical and logical reasoning. These models are typically trained in 3 phases:
- Self-supervised pretraining
- Supervised instruction tuning (SFT)
- Reinforcement Learning from Human feedback (RLHF)
The self-supervised pretraining phase induces language understanding and generation capabilities in the model, while the supervised instruction tuning phase teaches the model to follow natural language instructions. In the RLHF phase, the model is encouraged to follow behaviours that are desirable to us. The notion of desirability can be explicit (for instance, no profanity in the output) or implicit in human preferences for certain outputs over others.
Alignment to human preferences
So, how does one go about aligning a model to human preferences? PPO-RLHF, the RLHF approach behind GPT-3.5 and GPT-4, achieves this by first training a reward model to mimic human preferences, that is, to assign higher reward to human-preferred outputs than to the rest. Then, the LLM (also referred to as the policy) is finetuned to generate outputs that have high reward as determined by the reward model, while staying close to the SFT LLM (the supervised instruction-tuned LLM) so that it does not forget the capabilities acquired in the previous two phases.
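Concretely, this combination of reward maximization and closeness to the SFT model is usually written as the following objective (the notation here is the standard one rather than anything defined in this post: $r$ is the reward model, $\pi_{\text{SFT}}$ the SFT policy, $\pi_{\theta}$ the policy being trained, and $\beta$ a KL coefficient):

$$
\max_{\theta}\;\; \mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_{\theta}(\cdot \mid x)}\big[\, r(x, y) \,\big] \;-\; \beta\, \mathrm{KL}\!\big(\pi_{\theta}(\cdot \mid x)\,\big\|\,\pi_{\text{SFT}}(\cdot \mid x)\big)
$$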
PPO-RLHF has lately been replaced by offline contrastive techniques such as Sequence Likelihood Calibration (SLiC), Direct Preference Optimization (DPO) and its variants. These approaches train the LLM to contrast preferred/high-reward outputs against rejected/low-reward outputs. DPO has emerged as the de facto method for aligning high-performing models such as Zephyr, Mixtral and Llama-3.
While there are vast dissimilarities between the PPO-RLHF and DPO algorithms, both approaches have the same final target, referred to as the PPO-optimal policy. Another, less well-known family of methods uses distribution matching to align the LLM to this optimal policy. Ideally, this would require sampling from the optimal policy, which is known to be challenging. Hence, distribution-matching methods (DPG, GDC, GDC++) sample from a proposal distribution instead and weigh these samples by their importance weights. Despite the clear intuition behind distribution matching, these methods have not been successful for alignment with human feedback.
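For reference, this shared target has the well-known closed form of a reward-tilted SFT policy (same notation as above):

$$
\pi^{\text{PPO-opt}}(y \mid x) \;\propto\; \pi_{\text{SFT}}(y \mid x)\, \exp\!\big(r(x, y)/\beta\big)
$$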
Our contribution - BRAIn
While investigating the lack of success of distribution matching methods, we observed that the gradient estimate in distribution matching methods (GDC, GDC++) has high variance. What this means is that the update direction at every time-step varies widely depending on the outputs sampled from the LLM. This is demonstrated below for a simple toy example:
Assume that the target distribution we are trying to reach is the standard 1D normal distribution $\mathcal{N}(0, 1)$. Let the current model distribution be $\mathcal{N}(\mu, 1)$ and the proposal distribution be $\mathcal{N}(\mu_{\text{prop}}, 1)$, where $\mu_{\text{prop}}$ is varied over a range of values. Below, we plot the variance of the gradient estimate of the different distribution-matching objectives with respect to the mean parameter $\mu$ of the model distribution, with samples drawn from the proposal distribution. As can be observed, the variance of the gradient estimates of the distribution-matching methods (GDC, GDC++) is high when the proposal distribution is not the same as the target distribution.
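For readers who want to poke at this effect themselves, here is a minimal numpy sketch of the same kind of experiment. It is illustrative only: the specific means, sample sizes, and the plain-vs-self-normalized comparison are illustrative choices, not the exact setup behind the figure (which also compares GDC, GDC++ and BRAIn).

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal_pdf(y, mu):
    """Log-density of N(mu, 1)."""
    return -0.5 * (y - mu) ** 2 - 0.5 * np.log(2 * np.pi)

def gradient_variance(mu_model, mu_target, mu_prop, n_samples=64, n_trials=2000):
    """Variance (across trials) of the gradient estimate w.r.t. the model mean
    for a forward-KL distribution-matching objective, using samples from the
    proposal N(mu_prop, 1) reweighted towards the target N(mu_target, 1)."""
    plain, self_norm = [], []
    for _ in range(n_trials):
        y = rng.normal(mu_prop, 1.0, size=n_samples)
        # importance weights: target density / proposal density
        w = np.exp(log_normal_pdf(y, mu_target) - log_normal_pdf(y, mu_prop))
        score = y - mu_model                           # d/dmu log N(y; mu, 1)
        plain.append(np.mean(w * score))               # plain importance weights
        self_norm.append(np.sum(w / w.sum() * score))  # self-normalized weights
    return np.var(plain), np.var(self_norm)

for mu_prop in np.linspace(-2.0, 2.0, 9):
    v_plain, v_sn = gradient_variance(mu_model=0.5, mu_target=0.0, mu_prop=mu_prop)
    print(f"proposal mean {mu_prop:+.2f}  "
          f"var(plain IS)={v_plain:8.3f}  var(self-norm IS)={v_sn:8.3f}")
```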
This investigation is what motivated us to create BRAIn (Bayesian Reward-conditioned Amortized Inference), which extends distribution-matching methods as follows:
- We generalize the target distribution in PPO-RLHF, DPO and distribution matching methods by using Bayes' rule to incorporate the reward-modelling assumptions.
- We propose a self-normalized baseline that significantly reduces the variance of the gradient estimate in distribution matching, as shown in the figure above. By incorporating the self-normalized baseline, we achieve SOTA performance on the TL;DR summarization and Anthropic Helpful & Harmless response generation tasks, and establish DPO as a special case of BRAIn.
The BRAIn objective
Posterior as target
Given an input prompt $x$, the different RLHF algorithms attempt to reach a target distribution over the set of outputs $y$. This target distribution depends on 2 factors:
- The base distribution. This is often an SFT model, referred to as $\pi_{\text{SFT}}$
- The reward function $r(x, y)$
BRAIn uses Bayes' rule to combine the information from the above two factors. Specifically, the SFT model acts as the prior, while the reward function is used to define a likelihood term. The resulting posterior is referred to as the target $\pi^{*}(y \mid x)$.
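Schematically, the construction looks as follows, where $g = 1$ denotes the event that the response is judged to be good; the precise likelihood term, an increasing function of the reward $r(x, y)$, depends on the reward-modelling assumptions and is derived in the paper:

$$
\pi^{*}(y \mid x) \;\propto\; \underbrace{\pi_{\text{SFT}}(y \mid x)}_{\text{prior}} \;\times\; \underbrace{p(g = 1 \mid x, y)}_{\text{likelihood}}
$$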
Training with importance weights
Let $\pi_{\theta}$ be the model that we wish to align to the target $\pi^{*}$. Ideally, one could achieve this by sampling from the target and training $\pi_{\theta}$ on these samples, as sketched below.
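In symbols, this ideal objective is simply maximum likelihood on samples from the target (equivalently, minimizing the forward KL divergence to it):

$$
\max_{\theta}\;\; \mathbb{E}_{y \sim \pi^{*}(\cdot \mid x)}\big[\log \pi_{\theta}(y \mid x)\big]
$$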
However, since sampling from the target can be challenging, we instead sample from a proposal distribution $q$ and reweigh those samples by their importance weights $\pi^{*}(y \mid x)/q(y \mid x)$. Since the normalization constant of $\pi^{*}$ is intractable, we self-normalize the weights, as sketched below.
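Concretely, with samples $y_1, \ldots, y_n \sim q(\cdot \mid x)$ and $\tilde{\pi}^{*}$ denoting the unnormalized target, the self-normalized importance-sampling version of the objective takes the following form (a sketch consistent with the description above; see the paper for the exact objective):

$$
\max_{\theta}\;\; \sum_{i=1}^{n} \bar{w}_i \,\log \pi_{\theta}(y_i \mid x),
\qquad
\bar{w}_i \;=\; \frac{\tilde{\pi}^{*}(y_i \mid x)\,/\,q(y_i \mid x)}{\sum_{j=1}^{n} \tilde{\pi}^{*}(y_j \mid x)\,/\,q(y_j \mid x)}
$$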
A note on the proposal distribution: what would be the ideal distribution to generate samples from? Since we are trying to reach the target, ideally we should sample from the target $\pi^{*}$ itself. However, since this is challenging, we choose to sample from the distribution that is closest to the target. At the beginning of training, we sample from the SFT model $\pi_{\text{SFT}}$; as training proceeds, we also include samples from the latest policy.
A self-normalized baseline to reduce variance
The gradient of the above objective is a weighted sum of score functions, sketched below. This gradient estimate has been used in GDC for LLM alignment, with the difference that the weights are not self-normalized. As we showed earlier, the GDC gradient estimate has high variance, which translates into poor performance.
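Written out in the notation above, the self-normalized gradient estimate is

$$
\widehat{\nabla}_{\theta} \;=\; \sum_{i=1}^{n} \bar{w}_i \,\nabla_{\theta} \log \pi_{\theta}(y_i \mid x),
\qquad y_i \sim q(\cdot \mid x),
$$

while GDC uses the same expression with the unnormalized weights $\tilde{\pi}^{*}(y_i \mid x)/q(y_i \mid x)$ in place of $\bar{w}_i$.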
To reduce the variance, we propose to subtract a self-normalized baseline from the above gradient estimate as shown below:
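As an illustration of the shape of the resulting estimator, one natural choice of self-normalized baseline is built from the current policy's own importance weights; the precise baseline used by BRAIn and its derivation are given in the paper:

$$
\widehat{\nabla}^{\,\text{base}}_{\theta} \;=\; \sum_{i=1}^{n} \big(\bar{w}_i - \bar{u}_i\big)\,\nabla_{\theta} \log \pi_{\theta}(y_i \mid x),
\qquad
\bar{u}_i \;=\; \frac{\pi_{\theta}(y_i \mid x)\,/\,q(y_i \mid x)}{\sum_{j=1}^{n} \pi_{\theta}(y_j \mid x)\,/\,q(y_j \mid x)}
$$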
While the connection with the distribution-matching objective of GDC is obvious, we establish the connection with DPO in the paper.
Experimental results
We evaluate BRAIn on two tasks:
- Summarization: We use the Reddit TL;DR dataset for this task.
- Helpful & Harmless response generation: We use the Anthropic HH dataset for this task.
We evaluate the various models based on win-rate against gold, that is, the fraction of test samples on which the generated response is preferred over the gold response. We compute this quantity using two evaluators: 1) Train RM, the reward model used for aligning the SFT model, and 2) LLM eval, in which we prompt Mixtral 8x7B to compare the two outputs and declare a winner. The performance against the other baselines is displayed in the figure below:
As can be observed, BRAIn outperforms other baselines on both the evaluation measures.
We also study the impact of the self-normalized baseline subtraction on performance. The table below lists the win-rate of BRAIn with and without the self-normalized baseline; the "w/o self-norm" column corresponds to baseline subtraction without self-normalization, while the "w/o baseline" column drops the baseline altogether. As can be observed from the table, self-normalization is crucial for achieving reasonable performance in distribution matching.
| Win-rate (%) | BRAIn | w/o self-norm | w/o baseline |
|---|---|---|---|
| TL;DR | 95.2 | 61.4 | 61.1 |
| Anthropic HH | 95.4 | 59.1 | 58.3 |
Other Blogs by our team
From Fiction to Fact: Making Chatbots Grounded in Reality
Acknowledgements
This work was done in collaboration with Ramón Fernandez Astudillo, Yatin Nandwani, Tahira Naseem, Mayank Mishra, Guangxuan Xu, Dinesh Raghu, Sachindra Joshi and Asim Munawar