Command Line Interfaces (CLIs)
You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO), or even chat with your model, using the TRL CLIs.
Currently supported CLIs are:
- trl sft: fine-tune an LLM on a text/instruction dataset
- trl dpo: fine-tune an LLM with DPO on a preference dataset
- trl chat: quickly spin up an LLM fine-tuned for chatting
- trl env: get the system information
Fine-tuning with the CLI
Before getting started, pick a language model from the Hugging Face Hub. Supported models can be found with the "text-generation" filter on the models page. Also make sure to pick a relevant dataset for your task.
Before using the sft or dpo commands, make sure to run:

accelerate config

and pick the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of accelerate config before running any CLI command.
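If you simply want a sensible single-machine default without answering the interactive prompts, Accelerate also offers a non-interactive shortcut (optional, shown here only as a convenience):

accelerate config default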
We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the trl sft command.
model_name_or_path: trl-internal-testing/tiny-random-LlamaForCausalLM
dataset_name: stanfordnlp/imdb
report_to: none
learning_rate: 0.0001
lr_scheduler_type: cosine
Save that config in a .yaml file and get started immediately! An example CLI config is available as examples/cli_configs/example_config.yaml. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
This will force the use of cosine_with_restarts for lr_scheduler_type.
Supported Arguments
We support all arguments from transformers.TrainingArguments. For loading your model, we support all arguments from trl.ModelConfig:
class trl.ModelConfig( model_name_or_path: typing.Optional[str] = None, model_revision: str = 'main', torch_dtype: typing.Optional[typing.Literal['auto', 'bfloat16', 'float16', 'float32']] = None, trust_remote_code: bool = False, attn_implementation: typing.Optional[str] = None, use_peft: bool = False, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, lora_target_modules: typing.Optional[typing.List[str]] = None, lora_modules_to_save: typing.Optional[typing.List[str]] = None, lora_task_type: str = 'CAUSAL_LM', use_rslora: bool = False, load_in_8bit: bool = False, load_in_4bit: bool = False, bnb_4bit_quant_type: typing.Literal['fp4', 'nf4'] = 'nf4', use_bnb_nested_quant: bool = False )
Parameters
- model_name_or_path (Optional[str], optional, defaults to None) — Model checkpoint for weights initialization.
- model_revision (str, optional, defaults to "main") — Specific model version to use. It can be a branch name, a tag name, or a commit id.
- torch_dtype (Optional[Literal["auto", "bfloat16", "float16", "float32"]], optional, defaults to None) — Override the default torch.dtype and load the model under this dtype. Possible values are "bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32, and "auto": automatically derive the dtype from the model's weights.
- trust_remote_code (bool, optional, defaults to False) — Whether to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- attn_implementation (Optional[str], optional, defaults to None) — Which attention implementation to use. You can run --attn_implementation=flash_attention_2, in which case you must install this manually by running pip install flash-attn --no-build-isolation.
- use_peft (bool, optional, defaults to False) — Whether to use PEFT for training.
- lora_r (int, optional, defaults to 16) — LoRA R value.
- lora_alpha (int, optional, defaults to 32) — LoRA alpha.
- lora_dropout (float, optional, defaults to 0.05) — LoRA dropout.
- lora_target_modules (Optional[Union[str, List[str]]], optional, defaults to None) — LoRA target modules.
- lora_modules_to_save (Optional[List[str]], optional, defaults to None) — Model layers to unfreeze & train.
- lora_task_type (str, optional, defaults to "CAUSAL_LM") — Task type to pass for LoRA (use "SEQ_CLS" for reward modeling).
- use_rslora (bool, optional, defaults to False) — Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to lora_alpha/√r instead of the original default value of lora_alpha/r.
- load_in_8bit (bool, optional, defaults to False) — Whether to use 8 bit precision for the base model. Works only with LoRA.
- load_in_4bit (bool, optional, defaults to False) — Whether to use 4 bit precision for the base model. Works only with LoRA.
- bnb_4bit_quant_type (str, optional, defaults to "nf4") — Quantization type ("fp4" or "nf4").
- use_bnb_nested_quant (bool, optional, defaults to False) — Whether to use nested quantization.
Configuration class for the models.
Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.
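For example, here is a minimal sketch (not part of the TRL codebase; the script name is arbitrary) of exposing these fields as command-line flags with HfArgumentParser:

# parse_model_config.py - minimal sketch, assuming trl exports ModelConfig
from transformers import HfArgumentParser
from trl import ModelConfig

parser = HfArgumentParser(ModelConfig)
(model_config,) = parser.parse_args_into_dataclasses()

# Print a few parsed fields to show they arrived from the command line
print(model_config.model_name_or_path, model_config.use_peft, model_config.lora_r)

Running python parse_model_config.py --model_name_or_path facebook/opt-125m --use_peft --lora_r 8 would then populate the dataclass from those flags.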
You can pass any of these arguments either to the CLI or the YAML file.
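For instance, the earlier YAML example could be extended with some of the ModelConfig fields above to train a LoRA adapter on a 4-bit quantized base model; this is only an illustrative sketch and the values are not tuned recommendations:

model_name_or_path: facebook/opt-125m
dataset_name: stanfordnlp/imdb
output_dir: opt-sft-imdb-peft
learning_rate: 0.0001
use_peft: true
lora_r: 16
lora_alpha: 32
load_in_4bit: true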
Supervised Fine-tuning (SFT)
Follow the basic instructions above and run trl sft --output_dir <output_dir> <*args>:
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
The SFT CLI is based on the examples/scripts/sft.py script.
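Since the CLI also accepts transformers.TrainingArguments fields, you can set common training hyperparameters directly on the command line; the values below are purely illustrative:

trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb --num_train_epochs 1 --per_device_train_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 1e-4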
Direct Preference Optimization (DPO)
To use the DPO CLI, you need to have a dataset in the TRL format such as
- TRL’s Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
- TRL’s OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
These datasets always have at least three columns: prompt, chosen, rejected.

- prompt is a list of strings.
- chosen is the chosen response in chat format.
- rejected is the rejected response in chat format.
For a quick start, you can run the following command:
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
The DPO CLI is based on the examples/scripts/dpo.py script.
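The ModelConfig arguments listed above apply here as well. For example, a sketch of running DPO with a LoRA adapter (values are illustrative only):

trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir trl-hh-rlhf-peft --use_peft --lora_r 16 --lora_alpha 32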
Custom preference dataset
Format the dataset into TRL format (you can adapt examples/datasets/anthropic_hh.py):
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
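If you prefer to write your own conversion instead of adapting that script, the sketch below shows the general idea. The source column names (question, better, worse) and the dataset ids are hypothetical, and you should compare your result against the trl-style datasets linked above for the exact expected schema:

from datasets import load_dataset

def to_trl_style(row):
    # Hypothetical source columns; adapt to your data.
    return {
        "prompt": row["question"],
        "chosen": [
            {"role": "user", "content": row["question"]},
            {"role": "assistant", "content": row["better"]},
        ],
        "rejected": [
            {"role": "user", "content": row["question"]},
            {"role": "assistant", "content": row["worse"]},
        ],
    }

raw = load_dataset("my-org/my-preference-data")  # hypothetical dataset id
formatted = raw.map(to_trl_style, remove_columns=["question", "better", "worse"])
formatted.push_to_hub("your-hf-org/my-preference-data-trl-style")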
Chat interface
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
<quentin_gallouedec>:
What is the best programming language?
<Qwen/Qwen1.5-0.5B-Chat>:
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
and scalability. Ultimately, it depends on personal preference, needs, and goals.
Note that the chat interface relies on the tokenizer’s chat template to format the inputs for the model. Make sure your tokenizer has a chat template defined.
Besides talking to the model, there are a few commands you can use:

- clear: clears the current conversation and starts a new one
- example {NAME}: load example named {NAME} from the config and use it as the user input
- set {SETTING_NAME}={SETTING_VALUE};: change the system prompt or generation settings (multiple settings are separated by a ;)
- reset: same as clear but also resets the generation configs to defaults if they have been changed by set
- save or save {SAVE_NAME}: save the current chat and settings to file, by default to ./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml, or to {SAVE_NAME} if provided
- exit: closes the interface
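For example, a session might lower the sampling temperature, cap the generation length, and then save the conversation; the setting names below are common generation parameters and are shown only as an illustration:

set temperature=0.7; max_new_tokens=256;
save my_chat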
The default examples are defined in examples/scripts/config/default_chat_config.yaml, but you can pass your own with --config CONFIG_FILE, where you can also specify the default generation parameters.
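The exact schema is defined by that default file, so treat the following only as a hypothetical sketch of what a custom config might contain, and adjust it against examples/scripts/config/default_chat_config.yaml:

examples:
  recipe:
    text: How do I bake a loaf of sourdough bread at home?
max_new_tokens: 256
temperature: 0.7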
Getting the system information
You can get the system information by running the following command:
trl env
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the Transformers version, the TRL version, and any optional dependencies that are installed.
Copy-paste the following information when reporting an issue:

- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: no
  - use_cpu: False
  - debug: False
  - num_processes: 4
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
This information is required when reporting an issue.