Utilities for Fully Sharded Data Parallelism
accelerate.utils.merge_fsdp_weights
( checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False )
Parameters
- checkpoint_dir (`str`) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
- output_path (`str`) — The path to save the merged checkpoint to.
- safe_serialization (`bool`, *optional*, defaults to `True`) — Whether to save the merged weights with safetensors (recommended).
- remove_checkpoint_dir (`bool`, *optional*, defaults to `False`) — Whether to remove the checkpoint directory after merging.
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if `safe_serialization`, else to `pytorch_model.bin`.
Note: this is a CPU-bound process.
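A minimal usage sketch (both paths are illustrative placeholders; point `checkpoint_dir` at the sharded checkpoint directory your FSDP run actually produced):

```python
from accelerate.utils import merge_fsdp_weights

# Paths below are placeholders for this example.
merge_fsdp_weights(
    checkpoint_dir="checkpoints/pytorch_model_fsdp_0",  # sharded FSDP checkpoint directory
    output_path="merged_checkpoint",                    # directory for the combined weights
    safe_serialization=True,                            # writes model.safetensors
    remove_checkpoint_dir=False,                        # keep the original shards
)
```

Because the merge is CPU-bound, it can be run on a single machine after training finishes.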
class accelerate.FullyShardedDataParallelPlugin
( sharding_strategy: typing.Any = None, backward_prefetch: typing.Any = None, mixed_precision_policy: typing.Any = None, auto_wrap_policy: Optional = None, cpu_offload: typing.Any = None, ignored_modules: Optional = None, state_dict_type: typing.Any = None, state_dict_config: typing.Any = None, optim_state_dict_config: typing.Any = None, limit_all_gathers: bool = True, use_orig_params: bool = True, param_init_fn: Optional = None, sync_module_states: bool = True, forward_prefetch: bool = False, activation_checkpointing: bool = False )
This plugin is used to enable fully sharded data parallelism.
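A minimal sketch of constructing the plugin directly in code and handing it to the `Accelerator` (the state-dict config values here are illustrative assumptions; in most workflows these options come from `accelerate config` instead):

```python
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Illustrative choice: gather full (unsharded) state dicts, CPU-offloaded
# and only on rank 0, when saving model and optimizer state.
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```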