8-bit optimizers
With 8-bit optimizers, large models can be finetuned with 75% less GPU memory for the optimizer state, without losing accuracy compared to training with standard 32-bit optimizers. The reduced memory requirements also mean 8-bit optimizers can be up to 4x faster than a standard optimizer, and no hyperparameter tuning is required.
This guide will show you how to use 8-bit optimizers.
8-bit optimizers reduce memory usage and accelerate optimization on a wide range of tasks. However, because 8-bit optimizers only reduce memory in proportion to the number of parameters, models that use large amounts of activation memory, such as convolutional networks, benefit little from them. 8-bit optimizers are most beneficial for training or finetuning models with many parameters on highly memory-constrained GPUs.
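To see why the savings scale with parameter count rather than with activations, here is a back-of-the-envelope sketch. It assumes Adam keeps two state tensors per parameter (first and second moments), each stored at either 32 or 8 bits, and it ignores weights, gradients, activations, and any small bookkeeping data such as quantization scales. The function name adam_state_bytes and the 1B-parameter model are hypothetical.

```python
# Adam keeps two state tensors per parameter: first and second moments.
ADAM_STATES = 2


def adam_state_bytes(num_params: int, state_bits: int) -> int:
    """Optimizer-state memory in bytes.

    Simplifying assumption: ignores weights, gradients, activations and any
    small per-block bookkeeping (e.g. quantization scales).
    """
    return num_params * ADAM_STATES * state_bits // 8


if __name__ == "__main__":
    n = 1_000_000_000  # hypothetical 1B-parameter model
    fp32 = adam_state_bytes(n, 32)
    int8 = adam_state_bytes(n, 8)
    print(f"32-bit Adam state: {fp32 / 1e9:.1f} GB")  # 8.0 GB
    print(f"8-bit Adam state:  {int8 / 1e9:.1f} GB")  # 2.0 GB
    print(f"saving: {1 - int8 / fp32:.0%}")  # 75%
```

Activations do not appear anywhere in this calculation, which is why activation-heavy models such as convolutional networks gain relatively little from 8-bit optimizer states.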
8-bit optimizers are a drop-in replacement for regular optimizers, which means they accept the same arguments as a regular optimizer. For NLP models, we recommend using the StableEmbedding class instead of torch.nn.Embedding to improve stability and results.
```diff
import torch
import bitsandbytes as bnb

- adam = torch.optim.Adam(...)
+ adam = bnb.optim.Adam8bit(...)

# recommended for NLP models
- torch.nn.Embedding(...)
+ bnb.nn.StableEmbedding(...)
```
By default, all parameter tensors with fewer than 4096 elements keep 32-bit optimizer state, even when you use an 8-bit optimizer. This is because small tensors save little memory when quantized, and they often hold highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm).
You can change this threshold with the min_8bit_size parameter. For example, to use 8-bit optimizer state only for tensors with at least 16384 elements (multiples of 4096 are recommended):
```python
import bitsandbytes as bnb

adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)
```
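The size threshold amounts to a simple per-tensor rule. The sketch below illustrates the behavior described above and is not bitsandbytes' own code; state_bits_for and the parameter shapes are hypothetical.

```python
from math import prod


def state_bits_for(shape: tuple[int, ...], min_8bit_size: int = 4096) -> int:
    """Illustrative rule: tensors with fewer than min_8bit_size elements keep
    32-bit optimizer state; larger tensors get 8-bit state."""
    return 32 if prod(shape) < min_8bit_size else 8


# Hypothetical parameter shapes for a small transformer block.
params = {
    "attn.qkv.weight": (768, 2304),
    "attn.qkv.bias": (2304,),
    "ln.weight": (768,),
    "mlp.fc1.weight": (768, 3072),
    "proj.weight": (64, 128),  # 8192 elements
}

for threshold in (4096, 16384):
    plan = {name: state_bits_for(shape, threshold) for name, shape in params.items()}
    print(threshold, plan)
```

With the default threshold, only the bias and the layer-norm weight stay at 32 bits; raising min_8bit_size to 16384 also moves the 8192-element proj.weight to 32-bit state.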
Other parameters you can configure include the learning rate (lr), the decay rates (betas), the number of bits of the optimizer state (optim_bits), and percentile clipping (percentile_clipping), which can increase stability. For example, to initialize a 32-bit Adam optimizer with 5th percentile clipping:
```python
import bitsandbytes as bnb

adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5)
```
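Percentile clipping adapts the clipping threshold instead of using a fixed maximum norm. The sketch below assumes the behavior described in the bitsandbytes optimizer documentation: track the last 100 gradient norms and rescale the current gradient when its norm exceeds the chosen percentile of that history. PercentileClipper is a hypothetical pure-Python simplification that works on a flat list of floats; it is not the library's GPU implementation.

```python
import math
from collections import deque


class PercentileClipper:
    """Hypothetical sketch of percentile clipping on a flat gradient.

    Keeps the norms of the last `history` gradients and rescales the current
    gradient when its norm exceeds the requested percentile of that history.
    """

    def __init__(self, percentile: int = 5, history: int = 100):
        self.percentile = percentile
        self.norms = deque(maxlen=history)  # oldest norm drops out when full

    def __call__(self, grad: list[float]) -> list[float]:
        norm = math.sqrt(sum(g * g for g in grad))
        self.norms.append(norm)
        ranked = sorted(self.norms)
        # pick the value at the requested percentile of the recorded norms
        idx = min(len(ranked) - 1, len(ranked) * self.percentile // 100)
        threshold = ranked[idx]
        if norm > threshold:
            scale = threshold / norm  # shrink the gradient to the threshold norm
            return [g * scale for g in grad]
        return grad


clipper = PercentileClipper(percentile=5)
for _ in range(100):
    clipper([1.0, 1.0])  # a steady history of gradient norms equal to sqrt(2)
spike = clipper([30.0, 40.0])  # norm 50, far above the recorded history
print(spike)  # rescaled so that its norm equals sqrt(2)
```

Because the threshold comes from recent history, a sudden gradient spike is scaled back to a typical norm while its direction is preserved.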
Optimize unstable parameters
To optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, use the GlobalOptimManager class to override the specific hyperparameters for a particular layer. You’ll need to:
- Register the parameters while they're on the CPU.

```python
import torch
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()

model = MyModel()
mng.register_parameters(model.parameters())
```
- Override the config with the new desired hyperparameters. For example, let's override the model.fc1.weight parameter to use 32-bit Adam.
Check the optimizer API documentation for more information about other hyperparameters you can override.
```python
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)

# override: the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, "optim_bits", 32)
```
You can also override multiple layers at once by passing them as a list and the new hyperparameters as a dictionary. For example, let's override the model.special.weight and model.also_special.weight layers to use sparse optimization, a lower learning rate, and different Adam betas.
```python
mng.override_config([model.special.weight, model.also_special.weight],
                    key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)})
```
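Conceptually, GlobalOptimManager layers per-parameter overrides on top of the optimizer's defaults. The pure-Python sketch below models only that merging step, covering both the single key/value call and the list-plus-dictionary call. OverrideRegistry, Param, and the default values are hypothetical; the real manager does more (for example, the registration step above), which this sketch omits.

```python
from typing import Any


class OverrideRegistry:
    """Hypothetical stand-in for the override bookkeeping: per-parameter
    settings are merged on top of optimizer-wide defaults."""

    def __init__(self) -> None:
        # keyed by object identity in this sketch (an assumption)
        self._overrides: dict[int, dict[str, Any]] = {}

    def override_config(self, params, key=None, value=None, key_value_dict=None):
        # accept a single parameter or a list of parameters
        if not isinstance(params, (list, tuple)):
            params = [params]
        updates = dict(key_value_dict or {})
        if key is not None:
            updates[key] = value
        for p in params:
            self._overrides.setdefault(id(p), {}).update(updates)

    def resolve(self, param, defaults: dict[str, Any]) -> dict[str, Any]:
        # per-parameter overrides take precedence over the defaults
        return {**defaults, **self._overrides.get(id(param), {})}


class Param:
    """Stand-in for a tensor; only its identity matters here."""


fc1, special, also_special, other = Param(), Param(), Param(), Param()
defaults = {"optim_bits": 8, "lr": 1e-3, "betas": (0.9, 0.999), "is_sparse": False}

reg = OverrideRegistry()
reg.override_config(fc1, "optim_bits", 32)
reg.override_config([special, also_special],
                    key_value_dict={"is_sparse": True, "lr": 1e-5, "betas": (0.9, 0.98)})

print(reg.resolve(fc1, defaults)["optim_bits"])  # 32
print(reg.resolve(special, defaults)["lr"])  # 1e-05
print(reg.resolve(other, defaults) == defaults)  # True
```

Parameters without an override simply fall back to the optimizer-wide settings, so only the layers you name change behavior.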
To override a specific layer, we recommend doing it locally inside the module that defines it. Pass the module that owns the parameter, the parameter's attribute name, and the config dictionary to GlobalOptimManager.register_module_override:
```python
import torch
from bitsandbytes.optim import GlobalOptimManager

class MyModule(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.linear = torch.nn.Linear(d_in, d_out)
        # optimization of self.linear.weight will happen in 32-bit, and its
        # learning rate will be set to 0.0001 independent of the main learning rate
        config = {'optim_bits': 32, 'lr': 0.0001}
        GlobalOptimManager.get_instance().register_module_override(self.linear, 'weight', config)
```
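A module-level override can be pictured as a record that is resolved to an actual parameter later. The sketch below assumes the manager stores each override as a (module, attribute name, config) triple and looks the parameter up with getattr when the optimizer is created; ModuleOverrides, Linear, and MyModule here are hypothetical stand-ins, not bitsandbytes or PyTorch classes.

```python
from typing import Any


class ModuleOverrides:
    """Hypothetical sketch: store (module, attribute name, config) triples and
    resolve them to parameters later with getattr."""

    def __init__(self) -> None:
        self.triples: list[tuple[object, str, dict[str, Any]]] = []

    def register(self, module: object, param_name: str, config: dict[str, Any]) -> None:
        self.triples.append((module, param_name, config))

    def config_for(self, param: object) -> dict[str, Any]:
        for module, name, config in self.triples:
            # getattr raises AttributeError if `module` has no attribute `name`,
            # so the registered module must be the one that owns the parameter
            if getattr(module, name) is param:
                return config
        return {}


class Linear:
    """Stand-in for torch.nn.Linear with placeholder parameters."""

    def __init__(self) -> None:
        self.weight = object()
        self.bias = object()


class MyModule:
    """Stand-in module that registers an override for its own layer."""

    def __init__(self, overrides: ModuleOverrides) -> None:
        self.linear = Linear()
        overrides.register(self.linear, "weight", {"optim_bits": 32, "lr": 0.0001})


overrides = ModuleOverrides()
m = MyModule(overrides)
print(overrides.config_for(m.linear.weight))  # {'optim_bits': 32, 'lr': 0.0001}
print(overrides.config_for(m.linear.bias))  # {}
```

Registering inside the module keeps the override next to the layer it affects, so it travels with the module wherever it is reused.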
Next steps
For more conceptual details and explanation about 8-bit optimizers, take a look at the 8-bit optimizers guide.