Training large models across multiple GPUs can be difficult due to the complexities of the various parallelism strategies. In Accelerate, together with Axolotl, we have now integrated a fast and simple way to use any combination of parallelism strategies in your training script!
Here is how to add it to your training script:
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

# Compose FSDP sharding, DP replication, CP, and TP (2 * 2 * 2 * 2 = 16 GPUs).
pc = ParallelismConfig(
    dp_shard_size=2,
    dp_replicate_size=2,
    cp_size=2,
    tp_size=2,
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
)

accelerator = Accelerator(
    parallelism_config=pc,
    fsdp_plugin=fsdp_plugin,
)

# The device mesh tells transformers how to place and shard the weights across devices.
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Hermes-3-Llama-3.1-8B",
    device_mesh=accelerator.torch_device_mesh,
)

model = accelerator.prepare(model)
We have also included a more comprehensive end-to-end training script in the Accelerate repo which demonstrates how to set up your dataloader, optimizer, and training loop, and how to save your model after training.
To further streamline fine-tuning models at scale and compose parallelism strategies with a wide range of fine-tuning techniques, we have also integrated this system into Axolotl. To let you get started straight away, we have tested some example configs which you can modify to fit your needs – try one out with:
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
You can also take a look at the Axolotl ND-Parallelism docs for more details – adding ND parallel techniques to your existing configs is as simple as adding one or more of the following fields to your Axolotl config file:
dp_shard_size: 2
dp_replicate_size: 2
context_parallel_size: 2
tensor_parallel_size: 2
We have made it easy to configure the degrees of the various parallelism strategies and the way they are combined through the ParallelismConfig class in Accelerate, or through config fields in Axolotl, but how do we know which configuration will work best for our use case? As we scale to training models with tens or even hundreds of billions of parameters, the primary challenge comes from understanding the various parallelism strategies and how they interact to minimise communication overhead across devices. In this post, we'll walk through how the different parallelism strategies work, and when and how you should compose them.
Data Parallelism

Data parallelism (DP) is the most common technique for training models across multiple GPUs. It involves replicating the model, gradients, and optimizer states on each device, whilst evenly distributing data batches between GPUs and synchronising gradients across devices before updating parameters. This can significantly increase throughput compared with single-device training, but requires that your model is able to fit on a single device.
We can control the number of replicas of the model with the dp_replicate_size parameter in Accelerate's ParallelismConfig or the corresponding config field in Axolotl. It's worth noting that DP is a top-most-level parallelism strategy, meaning that if we use dp_replicate_size=2 and compose it with other parallelism strategies, there will be 2 replicas of the model, each also influenced by the other parallelism strategies. For example, if we use dp_replicate_size=2 and tp_size=2, we would have 2 replicas of the model, each with 2 tensor parallel shards, as shown in the sketch below.
We use the term shard to describe data on a single device which is a partition of a larger piece of data.
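As a minimal sketch of the dp_replicate_size=2, tp_size=2 example above (assuming a single machine with 4 GPUs, and assuming this replicate-plus-TP combination is supported on your version of Accelerate; the model would then be loaded and prepared as in the opening snippet):

from accelerate.parallelism_config import ParallelismConfig

# 2 model replicas, each split into 2 tensor parallel shards,
# so this expects a world size of 2 * 2 = 4 GPUs.
pc = ParallelismConfig(
    dp_replicate_size=2,
    tp_size=2,
)
# Pass `pc` to Accelerator and load the model with
# device_mesh=accelerator.torch_device_mesh, as in the opening snippet.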
Fully Sharded Data Parallelism

What if our model is too large to fit on a single device? Fully sharded data parallel (FSDP) addresses this issue by sharding (distributing evenly) the model's weights, gradients, and optimizer states across GPUs (this is inspired by DeepSpeed's ZeRO-3), whilst each device still receives its portion of the full batch of data. As you may notice from the diagram above, rather than requiring a full copy of the entire model on each device, we only gather the weights for a single layer at a time before the forward pass, after which the weights may be sharded again.
In this way, we trade memory usage for the communication overhead of gathering sharded parameters before each forward and backward pass, and reduce-scattering local gradients. We can control this trade-off in FSDP by tuning the granularity at which parameters are gathered. At one extreme, we can gather and re-shard every layer of our model, which leads to the lowest peak memory usage but incurs the highest communication costs. In practice, a common approach is to gather the weights for an entire transformer decoder block at a time.
Whilst we can make further memory-compute trade-offs and offload model parameters and gradients to the CPU to train larger models, this can be prohibitively slow. Instead, let's consider how we can effectively utilise even more devices to train larger models whilst maintaining high data throughput.
We use the term node to refer to a single machine which hosts multiple GPUs (up to a maximum of 8), with fast intra-node communication channels between GPUs using e.g. NVLink. When using multiple nodes for training, we rely on relatively slower inter-node communication channels between machines using e.g. InfiniBand. We also refer to the total number of devices in the process pool as the world size – e.g. a single node with 8 GPUs represents a world size of 8, and 4 such nodes would represent a world size of 32.
When using FSDP across multiple nodes, we treat the entire set of devices across nodes as if we were training on a single node. For instance, with 4 nodes containing 8 GPUs each, we perform our sharding across 32 devices, and carry out our collective all-gather and reduce-scatter operations using both inter- and intra-node communication backends. In this way, FSDP alone can scale to a considerable number of GPUs with a large global batch size to increase data throughput. However, there comes a point where several challenges arise which may require composing FSDP with other parallelism techniques. We usually try to avoid applying FSDP across more than a full node, as the communication overhead can become too high; we'll talk about how to address this in the section on Hybrid Sharded Data Parallelism.
You can use the dp_shard_size parameter in Accelerate's ParallelismConfig along with a prepared FullyShardedDataParallelPlugin, or set the dp_shard_size config field in Axolotl, to control the degree of FSDP applied to your model.
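For instance, a minimal single-node FSDP sketch in Accelerate, assuming 8 GPUs and a Llama-style model, might look like:

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

# Shard parameters, gradients, and optimizer states across the 8 GPUs of one node.
pc = ParallelismConfig(dp_shard_size=8)

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],  # gather one decoder block at a time
    state_dict_type="SHARDED_STATE_DICT",
)

accelerator = Accelerator(parallelism_config=pc, fsdp_plugin=fsdp_plugin)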
Tensor Parallelism

Tensor parallelism (TP) is a type of model parallelism, where shards of the model permanently live on separate devices, and, in contrast to data parallel techniques, each device receives the same batch of data. TP works by distributing the computation of linear layers across devices, so each device only computes a portion of the matrix multiplication. This technique works best when there are large linear layers, such as the feed-forward layers in transformer models, which can be split across devices. We can also use TP on each of the query, key, value, and output projections in the attention layers with almost no extra communication cost.
To achieve the best performance, the parameters of consecutive layers can be distributed in a specific fashion, minimising the required communication. When working with pairs of linear layers, we can split the first layer column-wise and the following layer row-wise, allowing us to compute the output with only a single all-reduce operation to combine the sharded outputs.
Unlike the dynamic sharding behaviour of FSDP, TP creates static memory partitions which lead to a constant memory usage reduction that scales with the TP group size. This becomes crucial for enormous models where even a single decoder layer is too large to fit into memory during the FSDP all-gather (recall that common practice in FSDP is to gather the weights of an entire decoder layer at a time). However, unlike FSDP, which scales relatively linearly across nodes (up to a point – ~512 GPUs on a homogeneous cluster, significantly fewer across lower-bandwidth connections), TP is only effective within the boundaries of a single node. TP requires frequent activation synchronisation between devices during computation, as each device computes only a portion of the output, requiring the outputs from other devices to be communicated before continuing the forward pass. Thus, if we want to utilise TP in a multi-node setup, we must consider composing TP with other parallelism techniques, while keeping TP within a single node. Due to its large communication overhead, TP is also not recommended for PCIe-connected GPUs.
In Accelerate, the TP size is configured through tp_size in ParallelismConfig, whilst in Axolotl you can use the tensor_parallel_size config field.
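As an illustrative sketch, a TP-only setup on a single node assumed to have 8 GPUs, reusing the model from the opening example, could look like:

from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig

# Keep TP within a single node (assumed here to have 8 GPUs) to use the fast intra-node links.
pc = ParallelismConfig(tp_size=8)
accelerator = Accelerator(parallelism_config=pc)

# Passing the device mesh lets the linear layers be sharded across the TP group at load time.
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Hermes-3-Llama-3.1-8B",
    device_mesh=accelerator.torch_device_mesh,
)
model = accelerator.prepare(model)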
Context Parallelism
Recently, reasoning capabilities in LLMs have caused sequence lengths to skyrocket as models use more and more tokens to solve complex tasks. To achieve this behaviour through fine-tuning, we need a way to train models on very large sequence lengths – which can sometimes reach up to 1,000,000 tokens!
Because the attention operation in transformers scales quadratically with context length, this becomes infeasible on a single GPU. For instance, when fine-tuning a relatively small model such as Mistral-7B (which uses 32 attention heads), with a sequence length of 128k a single attention matrix will use 128k * 128k * 2 bytes * num_heads=32 = ~32GB * 32 = ~1TB of activation memory! Whilst this example is not realistic when using optimised attention implementations such as FlashAttention, it helps illustrate the growth in memory requirements from increasing the context length.
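As a quick sanity check on those numbers:

# Back-of-the-envelope activation memory for one full attention matrix in
# bf16/fp16 (no FlashAttention), matching the numbers quoted above.
seq_len = 128_000
num_heads = 32
bytes_per_element = 2

per_head = seq_len * seq_len * bytes_per_element   # ~32.8 GB per head
total = per_head * num_heads                       # ~1.05 TB across 32 heads
print(f"{per_head / 1e9:.1f} GB per head, {total / 1e12:.2f} TB total")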
With context parallelism (CP), we can shard the inputs along the sequence dimension, so that each device only processes a chunk of the full context and computes a smaller portion of the full, prohibitively large, attention matrix. To see how this works, recall that the attention computation is described by the equation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $Q$, $K$, and $V$ are the query, key, and value matrices respectively. Each query vector (row, or input embedding) of $Q$ must compute the attention scores against every key vector of $K$ in the entire sequence to correctly apply the softmax normalisation. These attention scores are then used to weight all of the value vectors in $V$.
The crucial detail here lies in the fact that each row in $Q$ can compute its attention scores independently of the others, but each query vector still requires the full $K$ and $V$ matrices. In other words, given an input with sequence length $n$, we can expand our attention equation above as:

$$\text{Attention}(Q, K, V) = \begin{bmatrix} \text{softmax}\left(\frac{q_1 K^\top}{\sqrt{d_k}}\right)V \\ \vdots \\ \text{softmax}\left(\frac{q_n K^\top}{\sqrt{d_k}}\right)V \end{bmatrix}$$

where we denote each row of the query matrix $Q$ as $q_i$. This can be generalized as:

$$\text{Attention}(Q, K, V) = \begin{bmatrix} \text{softmax}\left(\frac{Q_0 K^\top}{\sqrt{d_k}}\right)V \\ \vdots \\ \text{softmax}\left(\frac{Q_{N-1} K^\top}{\sqrt{d_k}}\right)V \end{bmatrix}$$

where $Q_0, \dots, Q_{N-1}$ are contiguous blocks of rows (shards) of $Q$.
Once we shard the inputs across devices, the resulting $Q$, $K$, and $V$ matrices (computed from these input shards) are also automatically sharded along the sequence dimension – each GPU computes queries, keys, and values only for its portion of the sequence. For instance, with a world size of $N$ GPUs and sequence length $n$:
- GPU 0 computes $Q_0$, $K_0$, $V_0$
- GPU 1 computes $Q_1$, $K_1$, $V_1$
- …
- GPU $N-1$ computes $Q_{N-1}$, $K_{N-1}$, $V_{N-1}$
How can we make sure the attention is computed correctly? As established above, each device only needs its own shard of $Q$, but requires the full $K$ and $V$ matrices to compute the attention correctly. We can achieve this by using a technique called RingAttention, which works as follows:
- Initially, each GPU holds its shard of $Q$, $K$, $V$ (e.g., GPU 0 holds $Q_0$, $K_0$, $V_0$).
- Each GPU then computes a partial attention matrix for its shard of $Q$ and its local shard of $K$, $V$.
- Each GPU sends its shard of $K$, $V$ to the next GPU in the ring.
- Each GPU receives a different shard of $K$, $V$ from the previous GPU in the ring.
- Each GPU computes additional partial attention matrices using each received $K$, $V$ shard in turn.
- Each GPU repeats this process until all shards of $K$, $V$ have been received and all partial attention matrices have been computed.

Accelerate enables this with the accelerator.maybe_context_parallel decorator, which is also showcased in the Accelerate example script. You can also learn more about how it works and its limitations in our CP concept guide.
Similar to TP, in Accelerate the CP size is configured through cp_size in ParallelismConfig, whilst in Axolotl you can use the context_parallel_size config field.
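For illustration, a minimal CP-only sketch, assuming a single node with 8 GPUs, might be:

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig

# Each of the 8 GPUs processes 1/8 of every sequence.
pc = ParallelismConfig(cp_size=8)
accelerator = Accelerator(parallelism_config=pc)
# Inside the training loop, the input buffers are sharded along the sequence
# dimension under accelerator.maybe_context_parallel -- see the Accelerate
# example script and the CP concept guide for the exact usage.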
ND Parallelisms
In the multi-node setting, data parallel techniques such as FSDP treat the entire network topology as if it existed along a single dimension. You may find this approach limiting for a variety of reasons:
- When scaling to more nodes, FSDP's collective operations become bottlenecked by inter-node latency, making training prohibitively slow.
- As we mentioned above, massive models may have decoder layers which cannot fit into GPU memory, or which may be too large to perform a forward pass with, even in a sharded state.
- It may be impossible to achieve your ideal batch size – either the batch becomes too large for pure data parallelism to handle efficiently, or too small due to memory constraints from model size.
To try to address some of these problems, we can think of multi-node clusters as having a two-dimensional topology: fast intra-node communication between devices along one axis, and comparatively slower inter-node communication along another axis. Let's consider how we can compose the parallelism techniques we've introduced so far to take advantage of this.
Hybrid Sharded Data Parallelism

Hybrid Sharded Data Parallelism (HSDP) is a type of 2D parallelism which performs FSDP within a node and DP across nodes – that is to say, the model is replicated across nodes and sharded using FSDP within each node. This allows the greater communication overhead of FSDP to make use of the faster intra-node links, whilst DP reduces the slower inter-node communication to a single gradient synchronisation step. You might consider this approach if you were facing problem 1 and wanted to speed up training at the cost of increased memory usage.
It's important to note that we can freely configure the shape of our 2D network topology, as we aren't constrained to dimensions aligned with physical node boundaries – you might apply FSDP across 2 nodes whilst replicating across groups of 2 nodes, which would lead to lower memory usage but slower throughput, yet still reduce the inter-node FSDP communication overhead by a factor of 2. This is a knob we encourage you to tune to your specific hardware setup and fine-tuning needs.
You can enable HSDP by defining both dp_shard_size and dp_replicate_size in Accelerate's ParallelismConfig or through Axolotl's config fields.
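Putting this together, a sketch of an HSDP setup, assuming 4 nodes of 8 GPUs (world size 32) and a Llama-style model, might look like:

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

# HSDP on 4 nodes of 8 GPUs (world size 32): FSDP-shard over the 8 fast
# intra-node links, replicate across the 4 nodes with a single gradient sync.
pc = ParallelismConfig(
    dp_shard_size=8,
    dp_replicate_size=4,
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
)

accelerator = Accelerator(parallelism_config=pc, fsdp_plugin=fsdp_plugin)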
Fully Sharded Data Parallelism + Tensor Parallelism
As we mentioned earlier, TP should be applied within a node to make use of the high-bandwidth intra-node communication; thus, combining TP and FSDP involves sharding the model across nodes using FSDP, and within each node using TP. To a certain degree, this potentially offers a neat solution to all three of the problems above: the latency costs from FSDP can be reduced by a factor of 8 (with 8 GPUs per node), layers which are too large to fit on a single device are now evenly distributed across devices, and since each TP group receives the same batch of data, we can also reduce our global batch size by a factor of 8. However, if this remains insufficient, we're unable to extend the TP size across nodes and must consider another approach.
In Accelerate you can combine TP and FSDP by defining both dp_shard_size and tp_size in ParallelismConfig, whilst in Axolotl you can add both the dp_shard_size and tensor_parallel_size config fields.
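As an illustrative sketch on the same assumed 4 x 8 GPU cluster:

from accelerate.parallelism_config import ParallelismConfig

# TP inside each 8-GPU node, FSDP across the 4 nodes (4 * 8 = 32 GPUs).
# Pass this to Accelerator together with a FullyShardedDataParallelPlugin and
# load the model with device_mesh=accelerator.torch_device_mesh, as shown earlier.
pc = ParallelismConfig(
    dp_shard_size=4,
    tp_size=8,
)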
Fully Sharded Data Parallelism + Context Parallelism
This is a 2D parallelism strategy that combines FSDP and CP. While it is not very commonly used, as CP already composes with FSDP (more on why in the Accelerate concept guide), it can be useful in some cases, e.g. when a very large sequence length requires a large cp_size. If this still doesn't fit into your memory budget, you can apply FSDP on top, further reducing memory usage.
In Accelerate you can combine CP and FSDP by defining both dp_shard_size and cp_size in ParallelismConfig, whilst in Axolotl you can add both the dp_shard_size and context_parallel_size config fields.
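An illustrative sketch, assuming 16 GPUs and purely hypothetical sizes:

from accelerate.parallelism_config import ParallelismConfig

# CP to spread a very long sequence over 8 GPUs, plus FSDP over 2 groups to
# fit the model itself (2 * 8 = 16 GPUs in total). Pass this to Accelerator
# together with a FullyShardedDataParallelPlugin, as shown earlier.
pc = ParallelismConfig(
    dp_shard_size=2,
    cp_size=8,
)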
Hybrid Sharded Data Parallelism + Tensor Parallelism
With a sufficiently large world size (note: while the minimum world size for 3D parallelism is 8, it works best at much larger scales), we can consider combining HSDP with TP, which creates a hierarchy where DP first replicates the model across groups of nodes, FSDP then shards the model within each group, and TP splits individual layers within each node. You might consider this approach when facing all of the scaling constraints we mentioned above, as it provides the most flexibility to adapt to your specific training setup by trading off memory usage against throughput.
In Accelerate you can combine HSDP and TP by defining all of dp_shard_size, dp_replicate_size, and tp_size in ParallelismConfig. Similarly, in Axolotl you can add all of the dp_shard_size, dp_replicate_size, and tensor_parallel_size config fields.
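An illustrative sketch, assuming 8 nodes of 8 GPUs:

from accelerate.parallelism_config import ParallelismConfig

# 3D parallelism on 8 nodes of 8 GPUs (world size 64 = 2 * 4 * 8): replicate
# across 2 groups of nodes, FSDP-shard across the 4 nodes in each group, and
# TP-split layers across the 8 GPUs inside each node. Pass this to Accelerator
# together with a FullyShardedDataParallelPlugin, as shown earlier.
pc = ParallelismConfig(
    dp_replicate_size=2,
    dp_shard_size=4,
    tp_size=8,
)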
Usage notes
There are additional ways to combine multiple parallelisms which we haven't covered, such as 4D parallelism using HSDP + TP + CP, but they operate very similarly to the techniques we have already covered. Most of all, we encourage you to play with different techniques and configurations – this is the best way to gain an intuition for the various ways in which you can make memory/throughput trade-offs.
Below are some additional tips you may find useful when working in distributed settings:
- When using FSDP and working with models which are too large to fit on a single device, enabling both CPU-RAM-efficient loading and the sharded state dict checkpointing technique is crucial. You can enable this through the cpu_ram_efficient_loading and state_dict_type parameters in Accelerate's FullyShardedDataParallelPlugin:

fsdp2_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
    cpu_ram_efficient_loading=True,
)

or through the cpu_ram_efficient_loading and state_dict_type config fields inside the fsdp_config in Axolotl:

fsdp_version: 2
fsdp_config:
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: SHARDED_STATE_DICT
  cpu_ram_efficient_loading: True
- The total batch size used during training plays a crucial role in training stability, memory usage, and data throughput. When using DP and/or FSDP, the effective batch size is calculated as:
effective_batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size

where dp_world_size = (dp_shard_size * dp_replicate_size) / tp_size. You can increase your batch size by increasing your micro batch size or gradient accumulation steps in your training loop, or setting the micro_batch_size and gradient_accumulation_steps config fields in Axolotl, or increasing dp_world_size by adding more GPUs. As we mentioned above, this imposes a minimum total batch size of dp_world_size – when using pure DP/FSDP, this will be your total world size, and if this is too high the only way to decrease the total batch size is by introducing tensor parallelism. Finally, with a fixed number of GPUs and in memory-constrained scenarios, we recommend increasing gradient_accumulation_steps instead of micro_batch_size to achieve larger effective batch sizes, and vice-versa. A worked example is sketched below.
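For instance, with purely hypothetical numbers:

# Hypothetical setup: micro batch of 1, 4 gradient accumulation steps,
# dp_shard_size=8, dp_replicate_size=2, tp_size=2.
micro_batch_size = 1
gradient_accumulation_steps = 4

dp_shard_size, dp_replicate_size, tp_size = 8, 2, 2
dp_world_size = (dp_shard_size * dp_replicate_size) / tp_size  # = 8.0

effective_batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size
print(effective_batch_size)  # 32.0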
- Correspondingly, when your effective batch size increases due to introducing data parallelism, you should scale your learning rate to maintain training stability. Common approaches include linear scaling, scaled_lr = base_lr * (effective_batch_size / base_batch_size), or square-root scaling, scaled_lr = base_lr * sqrt(effective_batch_size / base_batch_size), as sketched below.
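As a sketch, again with hypothetical numbers:

import math

# Hypothetical baseline: base_lr tuned for a batch size of 32,
# now training with an effective batch size of 256.
base_lr = 2e-5
base_batch_size = 32
effective_batch_size = 256

linear_lr = base_lr * (effective_batch_size / base_batch_size)          # 1.6e-04
sqrt_lr = base_lr * math.sqrt(effective_batch_size / base_batch_size)   # ~5.7e-05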
- When memory constraints persist even with parallelism strategies, gradient checkpointing can provide additional memory savings by trading compute for memory. During the forward pass, only a subset of activations is kept in memory (typically at transformer block boundaries), and intermediate activations are recomputed during the backward pass. This technique works seamlessly with all of the parallelism strategies covered above. In Accelerate, you can enable it by setting activation_checkpointing=True in FullyShardedDataParallelPlugin:

fsdp2_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
    cpu_ram_efficient_loading=True,
    activation_checkpointing=True,
)

and similarly in Axolotl:

fsdp_version: 2
fsdp_config:
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: SHARDED_STATE_DICT
  cpu_ram_efficient_loading: True
  activation_checkpointing: True

Note that gradient checkpointing typically increases training time by ~20-30% due to activation recomputation, but can reduce activation memory by 60-80%, making it particularly valuable when training very large models or using long sequence lengths.
