Torch shared tensors
TL;DR
Using specific functions, which should work in most cases for you. This is not without side effects.
from safetensors.torch import load_model, save_model
save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")
load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))
What are shared tensors ?
Pytorch uses shared tensors for some computation. This is extremely interesting to reduce memory usage in general.
One very classic use case is in transformers the embeddings
are shared with
lm_head
. By using the same matrix, the model uses less parameters, and gradients
flow much better to the embeddings
(which is the start of the model, so they don’t
flow easily there, whereas lm_head
is at the tail of the model, so gradients are
extremely good over there, since they are the same tensors, they both benefit)
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(100, 100)
self.b = self.a
def forward(self, x):
return self.b(self.a(x))
model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer
Why are shared tensors not saved in safetensors ?
Multiple reasons for that:
Not all frameworks support them for instance
tensorflow
does not. So if someone saves shared tensors in torch, there is no way to load them in a similar fashion so we could not keep the sameDict[str, Tensor]
API.It makes lazy loading very quickly. Lazy loading is the ability to load only some tensors, or part of tensors for a given file. This is trivial to do without sharing tensors but with tensor sharing
with safe_open("model.safetensors", framework="pt") as f: a = f.get_tensor("a") b = f.get_tensor("b")
Now it’s impossible with this given code to “reshare” buffers after the fact. Once we give the
a
tensor we have no way to give back the same memory when you ask forb
. (In this particular example we could keep track of given buffers but this is not the case in general, since you could do arbitrary work witha
like sending it to another device before asking forb
)It can lead to much larger file than necessary. If you are saving a shared tensor which is only a fraction of a larger tensor, then saving it with pytorch leads to saving the entire buffer instead of saving just what is needed.
a = torch.zeros((100, 100)) b = a[:1, :] torch.save({"b": b}, "model.bin") # File is 41k instead of the expected 400 bytes # In practice it could happen that you save several 10GB instead of 1GB.
Now with all those reasons being mentioned, nothing is set in stone in there. Shared tensors do not cause unsafety, or denial of service potential, so this decision could be revisited if current workarounds are not satisfactory.
How does it work ?
The design is rather simple. We’re going to look for all shared tensors, then looking for all tensors covering the entire buffer (there can be multiple such tensors). That gives us multiple names which can be saved, we simply choose the first one
During load_model
, we are loading a bit like load_state_dict
does, except
we’re looking into the model itself, to check for shared buffers, and ignoring
the “missed keys” which were actually covered by virtue of buffer sharing (they
were properly loaded since there was a buffer that loaded under the hood).
Every other error is raised as-is
Caveat: This means we’re dropping some keys within the file. meaning if you’re
checking for the keys saved on disk, you will see some “missing tensors” or if you’re
using load_state_dict
. Unless we start supporting shared tensors directly in
the format there’s no real way around it.