Instantiating a big model
When you want to use a very big pretrained model, one challenge is to minimize the use of the RAM. The usual workflow from PyTorch is:
- Create your model with random weights.
- Load your pretrained weights.
- Put those pretrained weights in your random model.
Step 1 and 2 both require a full version of the model in memory, which is not a problem in most cases, but if your model starts weighing several GigaBytes, those two copies can make you got our of RAM. Even worse, if you are using torch.distributed
to launch a distributed training, each process will load the pretrained model and store these two copies in RAM.
Note that the randomly created model is initialized with “empty” tensors, which take the space in memory without filling it (thus the random values are whatever was in this chunk of memory at a given time). The random initialization following the appropriate distribution for the kind of model/parameters instatiated (like a normal distribution for instance) is only performed after step 3 on the non-initialized weights, to be as fast as possible!
In this guide, we explore the solutions Transformers offer to deal with this issue. Note that this is an area of active development, so the APIs explained here may change slightly in the future.
Sharded checkpoints
Since version 4.18.0, model checkpoints that end up taking more than 10GB of space are automatically sharded in smaller pieces. In terms of having one single checkpoint when you do model.save_pretrained(save_dir)
, you will end up with several partial checkpoints (each of which being of size < 10GB) and an index that maps parameter names to the files they are stored in.
You can control the maximum size before sharding with the max_shard_size
parameter, so for the sake of an example, we’ll use a normal-size models with a small shard size: let’s take a traditional BERT model.
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-cased")
If you save it using save_pretrained(), you will get a new folder with two files: the config of the model and its weights:
>>> import os
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir)
... print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model.bin']
Now let’s use a maximum shard size of 200MB:
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin', 'pytorch_model.bin.index.json']
On top of the configuration of the model, we see three different weights files, and an index.json
file which is our index. A checkpoint like this can be fully reloaded using the from_pretrained() method:
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... new_model = AutoModel.from_pretrained(tmp_dir)
The main advantage of doing this for big models is that during step 2 of the workflow shown above, each shard of the checkpoint is loaded after the previous one, capping the memory usage in RAM to the model size plus the size of the biggest shard.
Beind the scenes, the index file is used to determine which keys are in the checkpoint, and where the corresponding weights are stored. We can load that index like any json and get a dictionary:
>>> import json
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f:
... index = json.load(f)
>>> print(index.keys())
dict_keys(['metadata', 'weight_map'])
The metadata just consists of the total size of the model for now. We plan to add several other informations in the future:
>>> index["metadata"]
{'total_size': 433245184}
The weights map is the main part of this index, which maps each parameter name (as usually found in a PyTorch model state_dict
) to the file it’s stored in:
>>> index["weight_map"]
{'embeddings.LayerNorm.bias': 'pytorch_model-00001-of-00003.bin',
'embeddings.LayerNorm.weight': 'pytorch_model-00001-of-00003.bin',
...
If you want to directly load such a sharded checkpoint inside a model without using from_pretrained() (like you would do model.load_state_dict()
for a full checkpoint) you should use load_sharded_checkpoint():
>>> from transformers.modeling_utils import load_sharded_checkpoint
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... load_sharded_checkpoint(model, tmp_dir)
Low memory loading
Sharded checkpoints reduce the memory usage during step 2 of the workflow mentioned above, but in order to use that model in a low memory setting, we recommend leveraging our tools based on the Accelerate library.
Please read the following guide for more information: Large model loading using Accelerate