Callbacks
SyncRefModelCallback
RichProgressCallback
A TrainerCallback that displays the progress of training or evaluation using Rich.
WinRateCallback
class trl.WinRateCallback( judge: BasePairwiseJudge, trainer: Trainer, generation_config: Optional[GenerationConfig] = None, num_prompts: Optional[int] = None, shuffle_order: bool = True )
Parameters
- judge (BasePairwiseJudge) — The judge to use for comparing completions.
- trainer (Trainer) — Trainer to which the callback will be attached. The trainer’s evaluation dataset must include a "prompt" column containing the prompts for generating completions. If the Trainer has a reference model (via the ref_model attribute), it will use this reference model for generating the reference completions; otherwise, it defaults to using the initial model.
- generation_config (GenerationConfig, optional) — The generation config to use for generating completions.
- num_prompts (int or None, optional, defaults to None) — The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
- shuffle_order (bool, optional, defaults to True) — Whether to shuffle the order of the completions before judging.
A TrainerCallback that computes the win rate of a model based on a reference.
It generates completions using prompts from the evaluation dataset and compares the trained model’s outputs against
a reference. The reference is either the initial version of the model (before training) or the reference model, if
available in the trainer. During each evaluation step, a judge determines how often the trained model’s completions
win against the reference. The win rate is then logged in the trainer’s logs under the key "eval_win_rate".
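The win-rate computation described above can be illustrated with a small, library-free sketch. This is not TRL's implementation: toy_judge and win_rate are hypothetical helpers, the judge simply prefers the longer completion, and shuffling is modeled as a per-pair coin flip on presentation order. It only shows how a pairwise verdict is mapped back to the trained model once the order may have been swapped.

```python
import random


def toy_judge(prompt, first, second):
    """Hypothetical pairwise judge: returns 0 if `first` wins, 1 if `second` wins.

    Assumption for illustration only: the longer completion wins,
    and ties go to the first completion.
    """
    return 0 if len(first) >= len(second) else 1


def win_rate(prompts, ref_completions, model_completions, shuffle_order=True, seed=0):
    """Fraction of prompts where the trained model's completion beats the reference."""
    rng = random.Random(seed)
    wins = 0
    for prompt, ref, new in zip(prompts, ref_completions, model_completions):
        # Optionally swap the presentation order of the two completions.
        flip = shuffle_order and rng.random() < 0.5
        pair = (new, ref) if flip else (ref, new)
        winner = toy_judge(prompt, *pair)
        # Map the winning position back to the trained model's position.
        model_position = 0 if flip else 1
        wins += int(winner == model_position)
    return wins / len(prompts)


prompts = ["p1", "p2", "p3", "p4"]
ref_completions = ["x", "xxxx", "xx", "x"]
model_completions = ["xxx", "xx", "xxx", "xx"]

print(win_rate(prompts, ref_completions, model_completions))  # → 0.75
```

Because this toy judge has no preference for either position and the example data contains no ties, the result is the same with shuffle_order set to True or False. Shuffling would only change the outcome for a judge whose verdict depends on the order in which completions are presented.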
LogCompletionsCallback
class trl.LogCompletionsCallback( trainer: Trainer, generation_config: Optional[GenerationConfig] = None, num_prompts: Optional[int] = None, freq: Optional[int] = None )
Parameters
- trainer (Trainer) — Trainer to which the callback will be attached. The trainer’s evaluation dataset must include a "prompt" column containing the prompts for generating completions.
- generation_config (GenerationConfig, optional) — The generation config to use for generating completions.
- num_prompts (int or None, optional) — The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
- freq (int or None, optional) — The frequency at which to log completions. If not provided, defaults to the trainer’s eval_steps.
A TrainerCallback that logs completions to Weights & Biases.
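The defaulting rules for num_prompts and freq can be sketched as plain functions. The helper names below are hypothetical and this is not the library's code. In particular, it assumes that "frequency" means logging whenever the global step is a positive multiple of freq.

```python
from typing import Optional


def resolve_num_prompts(num_prompts: Optional[int], eval_dataset_size: int) -> int:
    """Fall back to the full evaluation dataset when num_prompts is not given."""
    return eval_dataset_size if num_prompts is None else num_prompts


def should_log(step: int, freq: Optional[int], eval_steps: int) -> bool:
    """Assumed rule: log when `step` is a positive multiple of the effective frequency."""
    effective_freq = eval_steps if freq is None else freq
    return step > 0 and step % effective_freq == 0


# With freq unset, logging follows the trainer's eval_steps (here 50).
print(should_log(step=100, freq=None, eval_steps=50))
# With freq=10, logging happens every 10 steps regardless of eval_steps.
print(should_log(step=30, freq=10, eval_steps=50))
```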