Optimum Neuron Distributed
The optimum.neuron.distributed module provides a set of tools to perform distributed training and inference.
Parallelization
The main task in distributed training and inference is sharding things such as the model weights, the gradients, and/or the optimizer state across devices. We built Parallelizer classes to handle this sharding.
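To make "sharding" concrete, here is a minimal, library-independent sketch that splits a weight matrix row-wise across tensor-parallel (TP) ranks. It uses plain Python lists and a hypothetical helper, shard_rows. It illustrates the concept only and does not describe how the Parallelizer classes are implemented. For reference, torch.nn.Linear stores its weight with shape (out_features, in_features), so a row-wise split divides the output features.

```python
from typing import List

Matrix = List[List[float]]


def shard_rows(weight: Matrix, tp_size: int) -> List[Matrix]:
    """Split `weight` into `tp_size` equal row blocks, one per rank (hypothetical helper)."""
    num_rows = len(weight)
    if tp_size < 1 or num_rows % tp_size != 0:
        raise ValueError(f"{num_rows} rows cannot be split evenly across {tp_size} ranks")
    rows_per_rank = num_rows // tp_size
    # Rank r keeps rows [r * rows_per_rank, (r + 1) * rows_per_rank).
    return [weight[r * rows_per_rank:(r + 1) * rows_per_rank] for r in range(tp_size)]


weight = [[float(4 * i + j) for j in range(4)] for i in range(4)]  # a 4x4 matrix
shards = shard_rows(weight, tp_size=2)
for rank, shard in enumerate(shards):
    # Each rank only holds its own block, so per-device memory drops by tp_size.
    print(rank, shard)
```

Each rank then computes only its slice of the layer output; the gradients and optimizer state for that slice can live on the same rank.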
Base Parallelizer
The Parallelizer class is the abstract base class from which the parallelizer of every model supporting model parallelism is derived. It provides methods to parallelize the model and to save and load sharded checkpoints.
Abstract base class that handles model parallelism.
_parallelize
(model: PreTrainedModel, device: Optional[torch.device] = None, parallelize_embeddings: bool = True, sequence_parallel_enabled: bool = False, should_parallelize_layer_predicate_func: Optional[Callable[[torch.nn.Module], bool]] = None, **parallel_layer_specific_kwargs) → PreTrainedModel
Parameters
- model (PreTrainedModel): The model to parallelize.
- device (Optional[torch.device], defaults to None): The device where the new parallel layers should be put.
- parallelize_embeddings (bool, defaults to True): Whether or not the embeddings should be parallelized. This can be disabled when the tensor parallel (TP) size does not divide the vocabulary size.
- sequence_parallel_enabled (bool, defaults to False): Whether or not sequence parallelism is enabled.
- should_parallelize_layer_predicate_func (Optional[Callable[[torch.nn.Module], bool]], defaults to None): A function that takes a layer as input and returns a boolean specifying whether that layer should be parallelized. This is useful to skip unnecessary parallelization, for instance with pipeline parallelism.
- **parallel_layer_specific_kwargs (Dict[str, Any]): Keyword arguments specific to some parallel layers; the other parallel layers ignore them.
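The parallelize_embeddings note above boils down to a divisibility check. The hypothetical helper below returns the contiguous token-id range each TP rank would own when the vocabulary splits evenly, and None otherwise (the case where disabling parallelize_embeddings is the documented way out). The even, contiguous split is an assumption made for this sketch.

```python
from typing import List, Optional, Tuple


def vocab_shards(vocab_size: int, tp_size: int) -> Optional[List[Tuple[int, int]]]:
    """Return the [start, end) token-id range per TP rank, or None if the split is uneven.

    Hypothetical helper: assumes an even, contiguous partition of the vocabulary.
    """
    if vocab_size % tp_size != 0:
        # TP size does not divide the vocabulary: keep the embeddings unsharded instead.
        return None
    per_rank = vocab_size // tp_size
    return [(rank * per_rank, (rank + 1) * per_rank) for rank in range(tp_size)]


print(vocab_shards(32000, 8)[:2])  # [(0, 4000), (4000, 8000)]
print(vocab_shards(50257, 8))      # None, since 50257 is odd
```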
Returns
PreTrainedModel: The parallelized model.
Parallelizes the model by transforming regular layers into their parallel counterparts. Each concrete class must implement it.
parallelize
(model: Union[PreTrainedModel, NeuronPeftModel], device: Optional[torch.device] = None, parallelize_embeddings: bool = True, sequence_parallel_enabled: bool = False, kv_size_multiplier: Optional[int] = None, pipeline_parallel_input_names: Union = None, pipeline_parallel_num_microbatches: int = 1, pipeline_parallel_use_zero1_optimizer: bool = False, pipeline_parallel_gradient_checkpointing_enabled: bool = False, checkpoint_dir: Optional[Union[str, Path]] = None, num_local_ranks_per_step: int = 8) → PreTrainedModel
Parameters
- model (Union[PreTrainedModel, NeuronPeftModel]): The model to parallelize.
- device (Optional[torch.device], defaults to None): The device where the new parallel layers should be put.
- parallelize_embeddings (bool, defaults to True): Whether or not the embeddings should be parallelized. This can be disabled when the TP size does not divide the vocabulary size.
- sequence_parallel_enabled (bool, defaults to False): Whether or not sequence parallelism is enabled.
- kv_size_multiplier (Optional[int], defaults to None): The number of times to replicate the KV heads when the TP size is bigger than the number of KV heads. If left unspecified, the smallest multiplier that makes the number of KV heads divisible by the TP size is used.
- pipeline_parallel_num_microbatches (int, defaults to 1): The number of microbatches used for pipeline execution.
- pipeline_parallel_use_zero1_optimizer (bool, defaults to False): Set this to True when the ZeRO-1 optimizer is used, so that the pipeline-parallel (PP) model knows that the ZeRO-1 optimizer handles data-parallel gradient averaging.
- pipeline_parallel_gradient_checkpointing_enabled (bool, defaults to False): Whether or not gradient checkpointing should be enabled when doing pipeline parallelism.
- checkpoint_dir (Optional[Union[str, Path]], defaults to None): Path to a sharded checkpoint. If specified, the checkpoint weights are loaded into the parallelized model.
- num_local_ranks_per_step (int, defaults to 8): The number of local ranks that can initialize and load the model weights at the same time. If the value is less than 0, the maximum number of ranks is used.
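Two of the numeric parameters above involve a bit of arithmetic, sketched here as hypothetical helpers that follow the parameter descriptions, not the library source. default_kv_size_multiplier computes the smallest m such that num_kv_heads * m is divisible by the TP size, which works out to tp_size // gcd(num_kv_heads, tp_size). loading_groups shows one way to stagger weight loading so that at most num_local_ranks_per_step ranks load at once; treating 0 as an error is a choice made only for this sketch.

```python
import math
from typing import List


def default_kv_size_multiplier(num_kv_heads: int, tp_size: int) -> int:
    """Smallest m >= 1 such that num_kv_heads * m is divisible by tp_size."""
    return tp_size // math.gcd(num_kv_heads, tp_size)


def loading_groups(num_local_ranks: int, num_local_ranks_per_step: int) -> List[List[int]]:
    """Group local ranks so that each group loads its weights in one step."""
    if num_local_ranks_per_step < 0:
        per_step = num_local_ranks  # "maximum number of ranks": everyone loads at once
    elif num_local_ranks_per_step == 0:
        raise ValueError("num_local_ranks_per_step must be non-zero")  # sketch-only choice
    else:
        per_step = num_local_ranks_per_step
    return [
        list(range(start, min(start + per_step, num_local_ranks)))
        for start in range(0, num_local_ranks, per_step)
    ]


print(default_kv_size_multiplier(num_kv_heads=8, tp_size=32))  # 4, giving 32 KV heads
print(default_kv_size_multiplier(num_kv_heads=8, tp_size=8))   # 1, no replication needed
print(len(loading_groups(32, 8)))   # 4 steps of 8 ranks each
print(len(loading_groups(32, -1)))  # 1 step with all 32 ranks
```

Lower values of num_local_ranks_per_step trade loading time for lower peak host memory, since fewer ranks hold full weights at the same moment.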
Returns
PreTrainedModel: The parallelized model.
Parallelizes the model by transforming regular layers into their parallel counterparts using cls._parallelize().
It also makes sure that each parameter has either loaded its weights or been initialized if there are no pre-trained weights associated with it.
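The split between _parallelize (model-specific, implemented by each subclass) and parallelize (shared, calling cls._parallelize()) is the classic template-method pattern. The sketch below shows that pattern with Python's abc module. ParallelizerSketch, ToyParallelizer, and the dict-based "model" are all hypothetical stand-ins, and the weights_ready flag is only a placeholder for the real weight loading and initialization step.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class ParallelizerSketch(ABC):
    """Hypothetical stand-in for the Parallelizer base class (not the real one)."""

    @classmethod
    @abstractmethod
    def _parallelize(cls, model: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        """Model-specific layer replacement; every concrete subclass provides it."""

    @classmethod
    def parallelize(cls, model: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        # Delegate the model-specific transformation to the subclass...
        model = cls._parallelize(model, **kwargs)
        # ...then run the shared post-processing (placeholder for loading/initializing weights).
        model["weights_ready"] = True
        return model


class ToyParallelizer(ParallelizerSketch):
    @classmethod
    def _parallelize(cls, model: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        # Pretend to swap every "linear" layer for its parallel counterpart.
        model["layers"] = ["parallel_" + name if name == "linear" else name for name in model["layers"]]
        return model


result = ToyParallelizer.parallelize({"layers": ["embedding", "linear", "linear"]})
print(result)
```

Keeping the shared steps in parallelize means a new model only has to describe which of its layers become parallel layers.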
optimizer_for_mp
(optimizer: torch.optim.Optimizer, orig_param_to_parallel_param_on_xla: Mapping[int, torch.nn.Parameter]) → torch.optim.Optimizer
Parameters
- optimizer (torch.optim.Optimizer): The original optimizer.
- orig_param_to_parallel_param_on_xla (Mapping[int, torch.nn.Parameter]): A mapping (e.g. dict-like) from the id of a parameter in optimizer to its parallelized counterpart on an XLA device.
Returns
torch.optim.Optimizer: The optimizer, ready for tensor parallelism.
Creates an optimizer ready for a parallelized model from an existing optimizer.
There are two cases:
- The optimizer was created via a lazy constructor from optimum.neuron.distributed.utils.make_optimizer_constructor_lazy, in which case exactly the intended optimizer is created for tensor parallelism.
- The optimizer was created with a regular constructor. In this case, the optimizer for tensor parallelism is created as close as possible to what was intended, but this is not guaranteed.
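The core of the id-based mapping can be sketched without torch. A torch.optim.Optimizer exposes param_groups, a list of dicts holding a "params" list plus per-group hyperparameters; the hypothetical remap_param_groups below rebuilds that structure with every mapped parameter swapped for its parallel counterpart. Keeping unmapped parameters unchanged is an assumption of this sketch, and the Param class is a toy stand-in for torch.nn.Parameter.

```python
from typing import Any, Dict, List, Mapping


def remap_param_groups(
    param_groups: List[Dict[str, Any]],
    orig_id_to_parallel: Mapping[int, Any],
) -> List[Dict[str, Any]]:
    """Return new param groups where each parameter is replaced by its parallel counterpart."""
    new_groups = []
    for group in param_groups:
        new_group = dict(group)  # copies hyperparameters such as lr or weight_decay
        # Look parameters up by id(), as orig_param_to_parallel_param_on_xla does.
        new_group["params"] = [orig_id_to_parallel.get(id(p), p) for p in group["params"]]
        new_groups.append(new_group)
    return new_groups


class Param:
    """Toy stand-in for torch.nn.Parameter."""

    def __init__(self, name: str) -> None:
        self.name = name


w, b, w_parallel = Param("w"), Param("b"), Param("w_parallel")
groups = [{"params": [w, b], "lr": 1e-3}]
new_groups = remap_param_groups(groups, {id(w): w_parallel})
print([p.name for p in new_groups[0]["params"]], new_groups[0]["lr"])
```

With real torch objects, such groups could be passed back to the optimizer class, e.g. type(optimizer)(new_groups, **optimizer.defaults). That mirrors the "regular constructor" case above: constructor arguments that are not recorded in param_groups or defaults cannot be recovered this way.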
save_model_sharded_checkpoint
(model: Union, output_dir: Union, optimizer: Optional = None, use_xser: bool = True, async_save: bool = False, num_local_ranks_per_step: int = 8)
Selecting Model-Specific Parallelizer Classes
Each model that supports parallelization in optimum-neuron has its own derived Parallelizer class. The factory class ParallelizersManager lets you retrieve these model-specific Parallelizers easily.
Provides the list of supported model types for parallelization.
is_model_supported
(model_type_or_model: Union)
Returns a tuple of 3 booleans where:
- the first element indicates whether tensor parallelism can be used for this model,
- the second element indicates whether sequence parallelism can be used on top of tensor parallelism for this model,
- the third element indicates whether pipeline parallelism can be used for this model.
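A natural use of that tuple is to validate a requested setup before parallelizing. The hypothetical check_parallel_config below unpacks the (tensor, sequence, pipeline) flags in the order described above and rejects unsupported combinations. The support tuple in the example is made up and does not describe any real model.

```python
from typing import Tuple


def check_parallel_config(
    support: Tuple[bool, bool, bool],
    tensor_parallel_size: int = 1,
    sequence_parallel_enabled: bool = False,
    pipeline_parallel_size: int = 1,
) -> None:
    """Raise ValueError if the requested setup is not allowed by `support`."""
    tp_ok, sp_ok, pp_ok = support  # same order as the tuple described above
    if tensor_parallel_size > 1 and not tp_ok:
        raise ValueError("tensor parallelism is not supported for this model")
    if sequence_parallel_enabled and not sp_ok:
        raise ValueError("sequence parallelism is not supported for this model")
    if pipeline_parallel_size > 1 and not pp_ok:
        raise ValueError("pipeline parallelism is not supported for this model")


made_up_support = (True, True, False)  # hypothetical result, for illustration only
check_parallel_config(made_up_support, tensor_parallel_size=8, sequence_parallel_enabled=True)
try:
    check_parallel_config(made_up_support, tensor_parallel_size=8, pipeline_parallel_size=2)
except ValueError as err:
    print(err)  # pipeline parallelism is not supported for this model
```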
parallelizer_for_model
(model_type_or_model: Union)
Returns the parallelizer class associated with the model.
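A factory like this is commonly built on a registry keyed by model type. The sketch below is a hypothetical illustration of that pattern, not the ParallelizersManager source: _PARALLELIZERS, ToyParallelizer, and parallelizer_for are invented names, and the "toy" model type is not a real supported model. It accepts either a model type string or an object exposing config.model_type, as Hugging Face Transformers models do.

```python
from types import SimpleNamespace
from typing import Dict, Type, Union


class BaseParallelizer:
    """Hypothetical stand-in for the Parallelizer base class."""


class ToyParallelizer(BaseParallelizer):
    """Hypothetical model-specific parallelizer."""


# Hypothetical registry: model_type string -> parallelizer class.
_PARALLELIZERS: Dict[str, Type[BaseParallelizer]] = {"toy": ToyParallelizer}


def parallelizer_for(model_type_or_model: Union[str, object]) -> Type[BaseParallelizer]:
    if isinstance(model_type_or_model, str):
        model_type = model_type_or_model
    else:
        # Transformers models carry their type in model.config.model_type.
        model_type = model_type_or_model.config.model_type
    try:
        return _PARALLELIZERS[model_type]
    except KeyError:
        raise NotImplementedError(f"{model_type} is not supported for parallelization") from None


fake_model = SimpleNamespace(config=SimpleNamespace(model_type="toy"))
print(parallelizer_for("toy").__name__, parallelizer_for(fake_model).__name__)
```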
Utils
Lazy Loading
Distributed training and inference are usually needed when the model is too big to fit on one device. Tools that allow lazy loading of model weights and optimizer states are thus needed to avoid running out of memory before parallelization.
optimum.neuron.distributed.lazy_load_for_parallelism
(tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1)
Context manager that makes the loading of a model lazy for model parallelism:
- Every torch.nn.Linear is put on the torch.device("meta") device, meaning that it takes no memory to instantiate.
- Every torch.nn.Embedding is also put on the torch.device("meta") device.
- No state dict is actually loaded; instead, a weight map is created and attached to the model. For more information, read the optimum.neuron.distributed.utils.from_pretrained_for_mp docstring.
If both tensor_parallel_size and pipeline_parallel_size are set to 1, no lazy loading is performed.
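The behavior above can be mimicked in pure Python with a context manager that temporarily swaps a layer constructor so that new layers land on a "meta" placeholder and allocate nothing, and that does nothing when both sizes are 1. Linear and lazy_layers are toy, hypothetical names; the real context manager acts on torch.nn.Linear and torch.nn.Embedding with torch.device("meta"), and this sketch does not claim to reproduce its implementation.

```python
import contextlib
from typing import Iterator


class Linear:
    """Toy layer that eagerly 'allocates' its weights unless created on 'meta'."""

    def __init__(self, in_features: int, out_features: int, device: str = "cpu") -> None:
        self.device = device
        # On "meta" nothing is materialized, mirroring torch's meta tensors.
        self.weight = None if device == "meta" else [[0.0] * in_features for _ in range(out_features)]


@contextlib.contextmanager
def lazy_layers(tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1) -> Iterator[None]:
    if tensor_parallel_size == 1 and pipeline_parallel_size == 1:
        yield  # no parallelism requested: keep eager loading
        return
    original_init = Linear.__init__

    def meta_init(self: Linear, in_features: int, out_features: int, device: str = "cpu") -> None:
        original_init(self, in_features, out_features, device="meta")

    Linear.__init__ = meta_init
    try:
        yield
    finally:
        Linear.__init__ = original_init  # always restore the eager constructor


with lazy_layers(tensor_parallel_size=8):
    lazy = Linear(4096, 4096)  # no 4096x4096 weight is built
eager = Linear(2, 2)
print(lazy.device, lazy.weight is None, eager.device)  # meta True cpu
```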
Transforms an optimizer constructor (optimizer class) into a lazy one that does not initialize the parameters. This makes the optimizer lightweight, and it can later be used to create a “real” optimizer once the model has been parallelized.
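The idea of a lazy optimizer constructor is deferred construction: record the optimizer class, its parameters, and its hyperparameters now, and build the real optimizer later. The sketch below illustrates this with hypothetical names (make_lazy, LazyOptimizer, ToySGD); it does not reproduce the actual signature or behavior of make_optimizer_constructor_lazy.

```python
from typing import Any, Callable, Dict, Iterable


class ToySGD:
    """Stand-in optimizer that just stores its params and learning rate."""

    def __init__(self, params: Iterable[Any], lr: float = 0.01) -> None:
        self.params = list(params)
        self.lr = lr


class LazyOptimizer:
    """Records how an optimizer should be built without creating any optimizer state."""

    def __init__(self, optimizer_cls: Callable[..., Any], params: Iterable[Any], kwargs: Dict[str, Any]) -> None:
        self.optimizer_cls = optimizer_cls
        self.params = list(params)  # only references are kept
        self.kwargs = kwargs

    def materialize(self, param_map: Dict[int, Any]) -> Any:
        # Build the "real" optimizer once parallel parameters exist, swapping by id().
        new_params = [param_map.get(id(p), p) for p in self.params]
        return self.optimizer_cls(new_params, **self.kwargs)


def make_lazy(optimizer_cls: Callable[..., Any]) -> Callable[..., LazyOptimizer]:
    def constructor(params: Iterable[Any], **kwargs: Any) -> LazyOptimizer:
        return LazyOptimizer(optimizer_cls, params, kwargs)

    return constructor


LazySGD = make_lazy(ToySGD)
w, w_parallel = object(), object()
lazy_opt = LazySGD([w], lr=0.1)           # cheap: nothing is initialized yet
real_opt = lazy_opt.materialize({id(w): w_parallel})
print(type(real_opt).__name__, real_opt.lr)  # ToySGD 0.1
```

Because the exact class and keyword arguments are recorded, the rebuilt optimizer matches what was intended, which is why the lazy-constructor case of optimizer_for_mp can guarantee the intended optimizer.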