What are the `pytorch_model-XXXXX-of-00002.bin` files used for?

#9, opened by Tunan01

Thanks for your work!

I followed the tutorial you wrote on GitHub and ran the code in Colab for inference.

Using the AudioCraft API example, I downloaded the state_dict.bin file (6.51GB) for the large model, and it ran fine on a T4 GPU with 12 GB of RAM.

However, when following the 🤗 Transformers usage example, I downloaded pytorch_model-00001-of-00002.bin, pytorch_model-00002-of-00002.bin, and state_dict.bin, and the session crashed after running out of RAM.

I would like to know the difference between these two approaches. As a beginner, I appreciate your help.

Hey @Tunan01 - when the PyTorch weights exceed a size threshold (default: 10GB), we shard (split) them into multiple smaller files; see https://huggingface.co/docs/transformers/main/main_classes/model#transformers.PreTrainedModel.push_to_hub.max_shard_size
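The splitting step can be illustrated with a simplified greedy sketch: walk through the tensors in order and start a new file whenever the next tensor would push the current file past the size limit. This is not Transformers' actual implementation; the function name `shard_state_dict`, the tensor names, and their sizes are hypothetical, and the index layout (`metadata.total_size` plus a `weight_map` from parameter name to shard file) is meant to mirror the `pytorch_model.bin.index.json` file that accompanies sharded checkpoints.

```python
import json


def shard_state_dict(weight_sizes, max_shard_size):
    """Hypothetical helper: greedily pack tensors (name -> size in bytes) into
    shards of at most max_shard_size bytes, keeping insertion order.
    A single tensor larger than the limit ends up in a shard of its own."""
    shards = []
    current, current_size = {}, 0
    for name, size in weight_sizes.items():
        # Close the current shard if adding this tensor would exceed the limit.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)

    total = len(shards)
    shard_files, weight_map = {}, {}
    for i, shard in enumerate(shards, start=1):
        # Zero-padded names like pytorch_model-00001-of-00002.bin
        filename = f"pytorch_model-{i:05d}-of-{total:05d}.bin"
        shard_files[filename] = list(shard)
        for name in shard:
            weight_map[name] = filename
    index = {
        "metadata": {"total_size": sum(weight_sizes.values())},
        "weight_map": weight_map,
    }
    return shard_files, index


# Hypothetical tensor groups and sizes, just to exercise the packing logic.
GB = 10**9
sizes = {"text_encoder": 1 * GB, "decoder.part_a": 6 * GB, "decoder.part_b": 6 * GB}
shard_files, index = shard_state_dict(sizes, max_shard_size=10 * GB)
print(json.dumps(index, indent=2))
```

With a 10GB limit, this made-up 13GB checkpoint is packed into two files, matching the `-00001-of-00002` / `-00002-of-00002` naming pattern discussed in this thread.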

This keeps each file a manageable size, and Transformers loads the shards one after another, so peak RAM while loading stays around the model size plus the largest shard instead of requiring a second full copy of the checkpoint. The recommended way to load model weights with Hugging Face Transformers is the `.from_pretrained` method:

```python
from transformers import MusicgenForConditionalGeneration

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-large")
```

This automatically loads the sharded weights for you and loads the resulting state dict into the MusicGen model class.
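Conceptually, loading a sharded checkpoint means reading the index file and then processing one shard at a time, copying its tensors into the model before moving on to the next. The sketch below shows that loop with the actual I/O passed in as callables so it runs without PyTorch; in real code `load_shard` might be `torch.load(path, map_location="cpu")` and `apply_shard` might be `model.load_state_dict(shard, strict=False)`. The function name `load_sharded` and the fake shard contents are hypothetical, and this is an illustration under those assumptions rather than Transformers' internals (which also handle dtypes, devices, and safetensors files). For loading a sharded folder into an already-built model, Transformers provides `load_sharded_checkpoint` in `transformers.modeling_utils`.

```python
import json


def load_sharded(index, load_shard, apply_shard):
    """Hypothetical helper: load a sharded checkpoint shard by shard.

    index:       parsed index file (has a "weight_map": param name -> shard file)
    load_shard:  callable(filename) -> dict of param name -> tensor
    apply_shard: callable(dict) that copies the shard's tensors into the model
    Returns the set of parameter names that were loaded.
    """
    weight_map = index["weight_map"]
    loaded = set()
    # Zero-padded shard names sort in order: 00001, 00002, ...
    for filename in sorted(set(weight_map.values())):
        shard = load_shard(filename)  # read only this shard's tensors
        apply_shard(shard)            # copy them into the model
        loaded.update(shard)
        del shard                     # drop our reference before the next shard
    missing = set(weight_map) - loaded
    if missing:
        raise KeyError(f"listed in the index but not found in any shard: {sorted(missing)}")
    return loaded


# In-memory stand-ins for the two shard files (hypothetical contents).
fake_files = {
    "pytorch_model-00001-of-00002.bin": {"text_encoder": "w1", "decoder.part_a": "w2"},
    "pytorch_model-00002-of-00002.bin": {"decoder.part_b": "w3"},
}
index_text = json.dumps(
    {"weight_map": {name: fn for fn, shard in fake_files.items() for name in shard}}
)
model_params = {}
loaded = load_sharded(json.loads(index_text), lambda fn: dict(fake_files[fn]), model_params.update)
print(sorted(loaded))
```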

Thank you for your reply!

But does downloading only state_dict.bin (with the AudioCraft API example), without the pytorch_model-xxxxx-of-00002.bin files, make the results worse?
Or would downloading state_dict.bin together with the pytorch_model-xxxxx-of-00002.bin files (with Hugging Face Transformers) improve the results?

Because these two files (pytorch_model-xxxxx-of-00002.bin) are quite large, I'm not sure how important they are for the results.
Specifically, are the weights of the musicgen_large model contained entirely in state_dict.bin, or are they split across state_dict.bin and the pytorch_model-xxxxx-of-00002.bin files?

Hey @Tunan01 - if you use the AudioCraft repository, you only need state_dict.bin: it holds the MusicGen language model weights, which is all AudioCraft needs to download for the language model (the text encoder is downloaded under the hood with Transformers' `.from_pretrained`). If you use Transformers, you need both pytorch_model-xxxxx-of-00002.bin files instead, and state_dict.bin is not used: the shards contain the text encoder as well as the language model, hence the larger total file size.
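To fetch only the files one loading path needs instead of everything in the repo, `huggingface_hub.snapshot_download` accepts an `allow_patterns` argument. The sketch below mimics that kind of filtering locally with `fnmatch`, using the weight files named in this thread plus the standard `pytorch_model.bin.index.json`. The repo listing, the `WEIGHT_PATTERNS` table, and the `select_weight_files` helper are illustrative assumptions (config, tokenizer, and other auxiliary files are left out), not an authoritative manifest for either library.

```python
from fnmatch import fnmatch

# Illustrative weight-file patterns per loading path (assumptions for this sketch).
WEIGHT_PATTERNS = {
    "audiocraft": ["state_dict.bin"],
    "transformers": ["pytorch_model-*-of-*.bin", "pytorch_model.bin.index.json"],
}


def select_weight_files(approach, repo_files):
    """Hypothetical helper: return the files in repo_files that match the
    patterns for `approach`, preserving the listing order."""
    patterns = WEIGHT_PATTERNS[approach]
    return [f for f in repo_files if any(fnmatch(f, p) for p in patterns)]


# Abbreviated, hypothetical listing of the model repo.
repo_files = [
    "README.md",
    "config.json",
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
    "pytorch_model.bin.index.json",
    "state_dict.bin",
]

for approach in WEIGHT_PATTERNS:
    print(approach, select_weight_files(approach, repo_files))
```

In practice `from_pretrained` and the AudioCraft API already download what they need on their own, so manual filtering like this only matters if you are fetching files yourself.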

Thanks a lot for your patience! I think I get it. 🤗
