FSDP-QLoRA
FSDP-QLoRA combines data parallelism (FSDP shards model parameters, optimizer states, and gradients across GPUs), 4-bit quantization, and LoRA to train LLMs up to 70B parameters on a dual 24GB GPU system. This technique was released by Answer.AI in collaboration with bitsandbytes to make training LLMs more efficient and accessible for everyone.
This guide briefly explains how bitsandbytes stores quantized weights to enable FSDP-QLoRA, and how to run training with the Hugging Face libraries.
Other changes required for bitsandbytes to support FSDP-QLoRA, such as reconstructing the weights from the quantization metadata and preventing already-quantized weights from being quantized again when they're moved from the CPU to the GPU, are documented in this Pull Request and described in the Enabling 70B Finetuning on Consumer GPUs blog post. We highly recommend reading these resources for a better understanding of FSDP-QLoRA!
Quantized data storage
FSDP only supports sharding float data types, which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses StoreChar to read and write quantized weights regardless of the underlying storage data type. This made it simple to add a quant_storage parameter to the Linear4bit and Params4bit classes, which defaults to torch.uint8 to maintain backward compatibility with the codebase. With the quant_storage parameter, you can select any of the FSDP supported float data types to shard Linear4bit with, such as bfloat16, float16, or float32.
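For example, here is a minimal sketch of constructing a standalone Linear4bit layer that stores its packed 4-bit weights in bfloat16 instead of the default torch.uint8. The layer sizes are arbitrary and a CUDA device is assumed:

import torch
import bitsandbytes as bnb

# Store the packed 4-bit weights in bfloat16 instead of the default torch.uint8
layer = bnb.nn.Linear4bit(
    1024,
    1024,
    compute_dtype=torch.bfloat16,
    quant_type="nf4",
    quant_storage=torch.bfloat16,
)
layer = layer.to("cuda")  # the weights are quantized when the layer is moved to the GPU
print(layer.weight.dtype)  # torch.bfloat16, a float dtype that FSDP can shard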
You’ll typically access and configure this option from transformers.BitsAndBytesConfig by setting the bnb_4bit_quant_storage parameter. It is very important that the quant_storage data type matches the data types used throughout the rest of the model, because FSDP can only wrap layers and modules that share the same floating point data type. Making sure the data types are aligned ensures the model is correctly sharded.
The compute_dtype is the data type used for computation inside the CUDA kernel, where the 4-bit quantized weights are unpacked from the data type in quant_storage and dequantized to compute_dtype. We recommend using torch.bfloat16 (if available on your hardware) for better numerical stability.
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_storage=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
)
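After loading, you can sanity-check that every parameter reports a single floating point dtype before wrapping the model with FSDP. This is an illustrative check, not a required step:

# All parameters, including the 4-bit quantized Linear4bit weights stored in
# bfloat16, should report the same floating point dtype.
print({param.dtype for param in model.parameters()})  # e.g. {torch.bfloat16}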
Check out this section of the PEFT documentation for the config file and training code to run FSDP-QLoRA training.
Training
FSDP is a distributed training framework that needs to be launched as a distributed training job with a launcher such as Accelerate or torchrun. The launch command referenced in this section uses Accelerate to launch the training script.
bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like Transformers, PEFT, and TRL.
PEFT provides a configuration file (fsdp_config_qlora.yaml), launch command (run_peft_qlora_fsdp.sh), and training script (train.py) for running FSDP-QLoRA. To learn more, check out the Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs documentation. This section briefly covers the steps to run FSDP-QLoRA training.
Before you begin, make sure you have the latest libraries installed.
pip install -U bitsandbytes accelerate transformers peft trl
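If you want to confirm the installed versions before launching a long job, a quick check could look like the following (illustrative only; FSDP-QLoRA support requires a recent bitsandbytes release):

# Print the installed versions of the libraries used below.
import accelerate, bitsandbytes, peft, transformers, trl

for lib in (accelerate, bitsandbytes, peft, transformers, trl):
    print(lib.__name__, lib.__version__)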
The important change that enables FSDP-QLoRA training is the bnb_4bit_quant_storage parameter in the BitsAndBytesConfig class. This allows you to set the storage data type of the quantized weights to a float data type.
import torch
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_storage=torch.bfloat16,
)
Pass the BitsAndBytesConfig to a model to set it up for FSDP-QLoRA. You should set the torch_dtype parameter to match bnb_4bit_quant_storage so that the Linear4bit layers are wrapped identically to the Linear layers. If the storage types do not match, then each Linear4bit layer is wrapped individually.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
)
Configure the LoraConfig class from PEFT for QLoRA training by setting target_modules="all-linear".
from peft import LoraConfig
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear",
)
Now you can pass everything to the SFTTrainer for training.
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_arguments,
)
trainer.train()
Resources
To learn more about FSDP and QLoRA, check out the following resources:
- The AnswerDotAI/fsdp_qlora repository.
- The introductory You can now train a 70b language model at home blog post by Answer.AI.
- For an introduction to FSDP, read the Introducing PyTorch Fully Sharded Data Parallel (FSDP) API blog post.
- For more details about QLoRA, take a look at the Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA blog post.