Accelerate documentation

Utilities for Fully Sharded Data Parallelism

You are viewing version v0.31.0. A newer version, v1.1.0, is available.

accelerate.utils.merge_fsdp_weights


( checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False )

Parameters

  • checkpoint_dir (str) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
  • output_path (str) — The path to save the merged checkpoint.
  • safe_serialization (bool, optional, defaults to True) — Whether to save the merged weights with safetensors (recommended).
  • remove_checkpoint_dir (bool, optional, defaults to False) — Whether to remove the checkpoint directory after merging.

Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Use this if the model was saved with SHARDED_STATE_DICT. Weights are saved to {output_path}/model.safetensors if safe_serialization is True, and to pytorch_model.bin otherwise.

Note: this is a CPU-bound process.
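A minimal sketch of how the merge might be called, using only the parameters listed above. The wrapper name merge_sharded_checkpoint and the directory check are hypothetical additions for illustration and are not part of Accelerate.

```python
from pathlib import Path


def merge_sharded_checkpoint(checkpoint_dir, output_path, safe_serialization=True):
    """Hypothetical helper: validate the input directory, then merge the shards."""
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.is_dir():
        raise FileNotFoundError(f"No FSDP checkpoint directory at {checkpoint_dir}")

    # Imported lazily so the check above works even without accelerate installed.
    from accelerate.utils import merge_fsdp_weights

    merge_fsdp_weights(
        checkpoint_dir=str(checkpoint_dir),
        output_path=str(output_path),
        safe_serialization=safe_serialization,
        remove_checkpoint_dir=False,  # keep the shards until the merge is verified
    )
    # Return the output directory that now holds the merged weights.
    return Path(output_path)
```

It could then be called as merge_sharded_checkpoint("path/to/sharded_checkpoint", "path/to/merged"), where both paths are placeholders. Leaving remove_checkpoint_dir at False keeps the original shards in place, which is the safer choice until the merged file has been checked.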

class accelerate.FullyShardedDataParallelPlugin


( sharding_strategy: typing.Any = None, backward_prefetch: typing.Any = None, mixed_precision_policy: typing.Any = None, auto_wrap_policy: Optional = None, cpu_offload: typing.Any = None, ignored_modules: Optional = None, state_dict_type: typing.Any = None, state_dict_config: typing.Any = None, optim_state_dict_config: typing.Any = None, limit_all_gathers: bool = True, use_orig_params: bool = True, param_init_fn: Optional = None, sync_module_states: bool = True, forward_prefetch: bool = False, activation_checkpointing: bool = False )

This plugin is used to enable fully sharded data parallelism.
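As a rough sketch, the plugin might be constructed and handed to an Accelerator as shown here. The keyword names come from the signature above. The dictionary FSDP_PLUGIN_KWARGS and the function make_fsdp_accelerator are hypothetical names for this example, and the script is assumed to run under a distributed launcher such as accelerate launch. Depending on the version, some of these fields may also be filled in from the launcher's environment, so the values passed here should be treated as a starting point.

```python
# Boolean options taken from the signature above; the values are illustrative only.
FSDP_PLUGIN_KWARGS = {
    "limit_all_gathers": True,
    "use_orig_params": True,
    "sync_module_states": True,
    "forward_prefetch": False,
    "activation_checkpointing": False,
}


def make_fsdp_accelerator(plugin_kwargs=None):
    """Hypothetical helper: build an Accelerator configured with the FSDP plugin."""
    # Imported lazily so this module loads even without accelerate or torch installed.
    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    kwargs = dict(FSDP_PLUGIN_KWARGS)
    if plugin_kwargs:
        kwargs.update(plugin_kwargs)  # caller overrides take precedence

    plugin = FullyShardedDataParallelPlugin(**kwargs)
    return Accelerator(fsdp_plugin=plugin)
```

The returned Accelerator is then used in the usual way, for example by passing the model and optimizer through its prepare method.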
