A guide to Efficient Multi-GPU Training




Training large models across multiple GPUs can be difficult due to the complexities of the various parallelism strategies. In Accelerate, together with Axolotl, we have integrated a fast and simple way to use any combination of parallelism strategies in your training script!

Here is how to add it to your training script:

from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin



pc = ParallelismConfig(
    dp_shard_size=2, 
    dp_replicate_size=2, 
    cp_size=2, 
    tp_size=2, 
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
)

accelerator = Accelerator(
    parallelism_config=pc,
    fsdp_plugin=fsdp_plugin
)

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Hermes-3-Llama-3.1-8B", 
    device_mesh=accelerator.torch_device_mesh
)

model = accelerator.prepare(model)

We have also included a more comprehensive end-to-end training script in the Accelerate repo which demonstrates how to set up your dataloader, optimizer, and training loop, and how to save your model after training.

To further streamline fine-tuning models at scale and compose parallelism strategies with a wide range of fine-tuning techniques, we have also integrated this system into Axolotl. To help you get started straight away, we have tested some example configs which you can modify to fit your needs – try one out with:


axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml

You can also take a look at the Axolotl ND-Parallelism docs for more details – adding ND parallel techniques to your existing configs is as simple as adding one or more of the following fields to your Axolotl config file:



dp_shard_size: 2
dp_replicate_size: 2
context_parallel_size: 2
tensor_parallel_size: 2

We have made it easy to configure the degrees of the various parallelism strategies and how they are combined through the ParallelismConfig class in Accelerate, or through config fields in Axolotl, but how do we know which configuration will work best for our use case? As we scale to training models with tens or even hundreds of billions of parameters, the main challenge comes from understanding the various parallelism strategies and how they interact to minimise communication overhead across devices. In this post, we'll walk through how the various parallelism strategies work, and when and how you should compose them.






Data Parallelism

Diagram for Data Parallel
Distributed Data Parallel replicates the whole model on each device, and evenly divides the data into sub-batches for each device. (Source: Martynas Šubonis).

Data parallelism (DP) is the most common technique for training models across multiple GPUs. It involves replicating the model, gradients, and optimizer states on each device, evenly distributing data batches between GPUs, and synchronising gradients across devices before updating parameters. This can significantly increase throughput compared with single-device training, but requires that your model can fit on a single device.

We can control the number of replicas of the model with the dp_replicate_size parameter in Accelerate's ParallelismConfig or config field in Axolotl. It is worth noting that DP is a top-most-level parallelism strategy, meaning that if we use dp_replicate_size=2 and compose it with other parallelism strategies, there will be 2 replicas of the model, each also affected by the other parallelism strategies. For example, if we use dp_replicate_size=2 and tp_size=2, we would have 2 replicas of the model, each with 2 tensor parallel shards.
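For instance, a minimal sketch of a pure data parallel run in Accelerate (assuming a single node with 8 GPUs, so the whole world size is used for replication) could look like:

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig

# Pure DP: one full replica of the model per GPU, with gradients synchronised each step
pc = ParallelismConfig(dp_replicate_size=8)
accelerator = Accelerator(parallelism_config=pc)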

We use the term shard to describe data on a single device which is a partition of a larger piece of data.



Fully Sharded Data Parallelism

Diagram for Fully Sharded Data Parallel
Fully Sharded Data Parallel evenly divides each of the model's parameters across each device, and, like DDP, evenly divides the data into sub-batches for each device. To complete a forward and backward pass, FSDP must gather the weights of each parameter before the forward/backward pass so that each device obtains a full copy of the parameter. (Source: Martynas Šubonis).

What if our model is too large to fit on a single device? Fully sharded data parallel (FSDP) addresses this issue by sharding (distributing evenly) the model's weights, gradients, and optimizer states across GPUs (this is inspired by DeepSpeed's ZeRO-3), whilst each device still receives its portion of the full batch of data. As you may notice from the diagram above, rather than requiring a full copy of the whole model on each device, we only gather the weights for a single layer at a time before the forward pass, after which the weights may be sharded again.

In this way, we trade memory usage for the communication overhead of gathering sharded parameters before each forward and backward pass, and reduce-scattering local gradients. We can control this trade-off in FSDP by tuning the granularity at which parameters are gathered. At one extreme, we can gather and re-shard every layer of our model, which leads to the lowest peak memory usage but incurs the highest communication costs. In practice, a common approach is to gather the weights for a whole transformer decoder block at a time.

Whilst we can make further memory-compute trade-offs and offload model parameters and gradients to the CPU to train larger models, this can be prohibitively slow. Instead, let's consider how we can effectively utilise even more devices to train larger models whilst maintaining high data throughput.

We use the term node to refer to a single machine which hosts multiple GPUs (up to a maximum of 8), with fast intra-node communication channels between GPUs using e.g. NVLink. When using multiple nodes for training, we rely on relatively slower inter-node communication channels between machines using e.g. InfiniBand. We also refer to the total number of devices in the process pool as the world size – e.g. a single node with 8 GPUs represents a world size of 8, and 4 nodes would represent a world size of 32.

When using FSDP across multiple nodes, we treat the whole set of devices across nodes as if we were training on a single node. For example, with 4 nodes containing 8 GPUs each, we perform our sharding across 32 devices, and perform our collective all-gather and reduce-scatter operations using both inter- and intra-node communication backends. In this way, FSDP alone can scale to a considerable number of GPUs with a large global batch size to increase data throughput. However, there comes a point where several challenges arise which may require composing FSDP with other parallelism techniques. We usually try to avoid applying FSDP across more than a full node, as the communication overhead can become too high; we'll talk about how to address this in the section on Hybrid Sharded Data Parallelism.

You can use the dp_shard_size parameter in Accelerate's ParallelismConfig along with a prepared FullyShardedDataParallelPlugin, or set the dp_shard_size config field in Axolotl, to set the degree of FSDP applied to your model.
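For example, a minimal sketch of a pure FSDP setup sharding across the 8 GPUs of a single node (reusing the Llama decoder wrap policy from the snippet at the top of this post) could be:

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

# Shard parameters, gradients, and optimizer states across all 8 devices
pc = ParallelismConfig(dp_shard_size=8)
fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],  # gather one decoder block at a time
    state_dict_type="SHARDED_STATE_DICT",
)
accelerator = Accelerator(parallelism_config=pc, fsdp_plugin=fsdp_plugin)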



Tensor Parallelism

Diagram for Tensor Parallel
Tensor Parallelism splits large linear layers across devices, typically using column-wise sharding for the first layer and row-wise sharding for the following layer. This approach requires only a single AllReduce communication operation to combine the sharded outputs, minimizing communication overhead while distributing both memory and compute across devices within a node.

Tensor Parallel (TP) is a type of model parallelism technique, where shards of the model live permanently on separate devices, and in contrast to data parallel techniques, each device receives the same batch of data. TP works by distributing the computation of linear layers across devices, so each device only computes a portion of the matrix multiplication. This method works best when there are large linear layers, such as the feed-forward layers in transformer models, which can be split across devices. We can also use TP on each of the query, key, value, and output projections in the attention layers with almost no extra communication cost.

To achieve the best performance, parameters of consecutive layers can be distributed in a specific fashion, minimizing the required communication. When working with pairs of linear layers, we can split the first layer column-wise and the following layer row-wise, allowing us to compute the output with only a single all-reduce operation to combine the sharded outputs.
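To make the column-wise/row-wise split concrete, here is a toy single-process sketch in plain PyTorch (no real communication – the final summation stands in for the all-reduce) showing that sharding the first layer by columns and the second by rows reproduces the unsharded output:

import torch

n, d_model, d_ff, tp = 4, 16, 64, 2
x = torch.randn(n, d_model)
w1 = torch.randn(d_model, d_ff)  # first linear layer, split column-wise
w2 = torch.randn(d_ff, d_model)  # second linear layer, split row-wise

# Reference: unsharded forward pass through the two-layer MLP
ref = torch.relu(x @ w1) @ w2

# Each "device" holds a column shard of w1 and the matching row shard of w2
partials = []
for rank in range(tp):
    cols = slice(rank * d_ff // tp, (rank + 1) * d_ff // tp)
    partials.append(torch.relu(x @ w1[:, cols]) @ w2[cols, :])

# Summing the partial outputs (a single all-reduce on real hardware) recovers the result
out = sum(partials)
print(torch.allclose(ref, out, atol=1e-5))  # True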

Unlike the dynamic sharding behaviour of FSDP, TP creates static memory partitions which lead to a constant memory usage reduction scaling with the TP group size. This becomes crucial for enormous models where even a single decoder layer is too large to fit into memory during the FSDP all-gather (recall that common practice in FSDP is to gather the weights of a whole decoder layer at a time). However, unlike FSDP, which scales relatively linearly across nodes (up to a point – ~512 GPUs on a homogeneous cluster, significantly fewer across lower-bandwidth connections), TP is only effective within the boundaries of a single node. TP requires frequent activation synchronization between devices during computation, as each device computes only a portion of the output, requiring the outputs from other devices to be communicated before continuing the forward pass. Thus, if we want to utilise TP in a multi-node setup, we must consider composing TP with other parallelism techniques, while keeping TP only within a single node. Because of its large communication overhead, TP is not recommended for PCIe-linked GPUs.

In Accelerate, the TP size is configured through tp_size in ParallelismConfig, whilst in Axolotl you can use the tensor_parallel_size config field.



Context Parallelism

Recently, reasoning capabilities in LLMs have caused sequence lengths to skyrocket as models use more and more tokens to solve complex tasks. To achieve this behaviour through fine-tuning, we need a way to train models on very large sequence lengths – which may sometimes reach up to 1,000,000 tokens!

Because the attention operation in transformers scales quadratically with context length, this becomes impossible on a single GPU. For example, when fine-tuning a relatively small model such as Mistral-7B (which uses 32 attention heads), if we use a sequence length of 128k, a single attention matrix will use 128k * 128k * 2 bytes = ~32GB per head, or ~1TB of activation memory across all 32 heads! Whilst this example is not realistic when using optimised attention implementations such as FlashAttention, it helps illustrate the growth in memory requirements from increasing the context length.
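As a quick back-of-the-envelope check of those numbers:

seq_len = 128_000        # tokens
bytes_per_value = 2      # bf16 / fp16
num_heads = 32

per_head = seq_len * seq_len * bytes_per_value   # one full attention matrix
total = per_head * num_heads
print(per_head / 1e9, total / 1e12)              # ~32.8 GB per head, ~1.05 TB in total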

With context parallelism (CP), we can shard the inputs across the sequence dimension, so that each device only processes a chunk of the full context and computes a smaller portion of the full, prohibitively large, attention matrix. To see how this works, recall that the attention computation is described by the equation:
$$\text{Attention}(Q, K, V) = \text{softmax}(QK^T)V$$

where $Q$, $K$, and $V$ are the query, key, and value matrices respectively. Each query vector (row, or input embedding) of $Q$ must compute attention scores against every key vector of $K$ in the whole sequence to correctly apply the softmax normalisation. These attention scores are then used to weight all value vectors in $V$.

The crucial detail here lies in the fact that each row of $Q$ can compute its attention scores independently of the others, but each query vector still requires the full $K$ and $V$ matrices. In other words, given an input with sequence length $n$, we can expand our attention equation above as:

$$
\begin{align}
\text{Attention}(Q, K, V)_1 &= \text{softmax}(Q_1 K^T) V \\
\text{Attention}(Q, K, V)_2 &= \text{softmax}(Q_2 K^T) V \\
&\;\;\vdots \\
\text{Attention}(Q, K, V)_n &= \text{softmax}(Q_n K^T) V
\end{align}
$$

where we denote each row of the query matrix as $Q_1, Q_2, \ldots, Q_n$.

When we shard the inputs across devices, the resulting $Q$, $K$, and $V$ matrices (computed from these input shards) are also automatically sharded along the sequence dimension – each GPU computes queries, keys, and values only for its portion of the sequence. For example, with a world size of $W$ GPUs and sequence length $n$:

  • GPU 0 computes $Q_{1:n/W}$
  • GPU 1 computes $Q_{n/W+1:2n/W}$
  • GPU $(W-1)$ computes $Q_{(W-1)n/W+1:n}$

How do we make sure the attention is computed correctly? As established above, each device only needs its own shard of $Q$, but requires the full $K$ and $V$ matrices to compute the attention correctly. We can achieve this by using a technique called RingAttention, which works as follows:

  1. Initially, each GPU holds its shard of $Q$, $K$, $V$ (e.g., GPU 0 holds $Q_{1:n/W}$, $K_{1:n/W}$, $V_{1:n/W}$).
  2. Each GPU then computes a partial attention matrix $A_{i,j}$ from its $Q$ shard and the $K$, $V$ shard it currently holds.
  3. Each GPU sends its shard of $K$, $V$ to the next GPU in the ring.
  4. Each GPU receives a different shard of $K$, $V$ from the previous GPU in the ring.
  5. Each GPU computes additional partial attention matrices $A_{i,j+1}$ using the newly received shard.
  6. Each GPU repeats this process until all shards of $K$, $V$ have been received and all partial attention matrices $A_{i,*}$ have been computed and combined into the final attention output.
Diagram for Context Parallel
Context Parallelism shards the input sequence across GPUs, with each device holding queries and key-value pairs for its assigned segment. Ring attention circulates K, V shards between GPUs (shown by the arrows), allowing each query to compute attention scores against keys and values from the whole sequence. The final attention output combines information from all sequence positions while distributing memory and compute across devices.
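To make the ring concrete, here is a toy single-process simulation of the rotation. It accumulates unnormalised softmax numerators and denominators as the K, V shards move around the ring (real ring-attention kernels use a numerically stable online softmax instead), and checks the result against ordinary full attention:

import torch

W, n, d = 4, 16, 8                                    # "GPUs", sequence length, head dimension
q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
ref = torch.softmax(q @ k.T, dim=-1) @ v              # full attention, for comparison

qs = list(q.chunk(W))
held_k, held_v = list(k.chunk(W)), list(v.chunk(W))   # K,V shard currently held by each rank
num = [torch.zeros(n // W, d) for _ in range(W)]      # running numerators
den = [torch.zeros(n // W, 1) for _ in range(W)]      # running softmax denominators

for step in range(W):
    for rank in range(W):
        # Partial attention of this rank's Q shard against the K,V shard it currently holds
        scores = torch.exp(qs[rank] @ held_k[rank].T)
        num[rank] += scores @ held_v[rank]
        den[rank] += scores.sum(dim=-1, keepdim=True)
    # Each rank passes its K,V shard to the next rank and receives one from the previous rank
    held_k = [held_k[(r - 1) % W] for r in range(W)]
    held_v = [held_v[(r - 1) % W] for r in range(W)]

out = torch.cat([num[r] / den[r] for r in range(W)])
print(torch.allclose(out, ref, atol=1e-5))  # True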

Accelerate enables this with the accelerator.maybe_context_parallel context manager, which is also showcased in the Accelerate example script. You can also learn more about how it works and its limitations in our CP concept guide.

Similar to TP, in Accelerate the CP size is configured through cp_size in ParallelismConfig, whilst in Axolotl you can use the context_parallel_size config field.



ND Parallelisms

In the multi-node setting, data parallel techniques such as FSDP treat the whole network topology as if it existed along a single dimension. You might find this approach limiting for a variety of reasons:

  • When scaling to more nodes, FSDP’s collective operations become bottlenecked by inter-node latency, making training prohibitively slow.
  • As we mentioned above, massive models may have decoder layers which cannot fit into GPU memory, or which may be too large to perform a forward pass with, even in a sharded state.
  • It can be impossible to achieve your ideal batch size – either the batch becomes too large for pure data parallelism to handle efficiently, or too small due to memory constraints from model size.

To try to address some of these problems, we can think of multi-node clusters as having a two-dimensional topology: fast intra-node communication between devices along one axis, and comparatively slower inter-node communication along another axis. Let's consider how we can compose the parallelism techniques we've introduced so far to take advantage of this.



Hybrid Sharded Data Parallelism

Diagram for Hybrid Sharded Data Parallel
Hybrid Sharded Data Parallelism performs FSDP inside each replica group and synchronizes gradients across replica groups via AllReduce, combining the memory efficiency of FSDP with the communication efficiency of DP across nodes.

Hybrid Sharded Data Parallelism (HSDP) is a type of 2D parallelism which performs FSDP within a node and DP across nodes – that is to say, the model is replicated across each node and sharded using FSDP within each node. This allows the greater communication overhead of FSDP to make use of the faster intra-node links, whilst DP reduces the slower inter-node communication overhead to a single gradient synchronisation step. You might consider this approach if you were facing problem 1 and wanted to speed up training at the cost of increased memory usage.

It's important to note that we can freely configure the shape of our 2D network topology, as we are not constrained to dimensions being aligned with physical node boundaries – you might apply FSDP across 2 nodes whilst replicating across groups of 2 nodes, which would lead to lower memory usage but slower throughput, while still reducing the inter-node FSDP communication overhead by a factor of 2. This is a knob we encourage you to tune to your specific hardware setup and fine-tuning needs.

You can enable HSDP by defining both dp_shard_size and dp_replicate_size in Accelerate's ParallelismConfig or through Axolotl's config fields.
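As an illustration, a sketch of an HSDP layout for 4 nodes of 8 GPUs each (32 GPUs in total), sharding within each node and replicating across nodes, could look like:

from accelerate.parallelism_config import ParallelismConfig

# 4 nodes x 8 GPUs: FSDP-shard across the 8 GPUs of each node,
# and keep 4 replicas of the sharded model (one per node)
pc = ParallelismConfig(dp_shard_size=8, dp_replicate_size=4)

The equivalent Axolotl config would set dp_shard_size: 8 and dp_replicate_size: 4.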



Fully Sharded Data Parallelism + Tensor Parallelism

As we mentioned earlier, TP should be applied within a node to make use of the high-bandwidth intra-node communications; thus, combining TP and FSDP involves sharding the model across nodes using FSDP, and within a node using TP. To a certain degree, this potentially offers a neat solution to all three of the problems above: the latency costs from FSDP can be reduced by a factor of 8, layers which are too large to fit on a single device are now evenly distributed across devices, and since each TP group receives the same batch of data, we also reduce our global batch size by a factor of 8. However, if this remains insufficient, we cannot increase the TP size across nodes and must consider another approach.

In Accelerate you can combine TP and FSDP by defining both dp_shard_size and tp_size in ParallelismConfig, whilst in Axolotl you can add both the dp_shard_size and tensor_parallel_size config fields.
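For example, on the same 4-node, 32-GPU cluster, a sketch of an FSDP + TP layout that keeps TP within node boundaries could be:

from accelerate.parallelism_config import ParallelismConfig

# TP across the 8 GPUs within each node, FSDP sharding across the 4 nodes
pc = ParallelismConfig(dp_shard_size=4, tp_size=8)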



Fully Sharded Data Parallelism + Context Parallelism

This is a 2D parallelism strategy that combines FSDP and CP. While it is not very commonly used, as CP already composes with FSDP (more on why in the Accelerate concept guide), it can be useful in some cases, e.g. when a very large sequence length requires a large cp_size. If this still doesn't fit into your memory budget, you can apply FSDP on top, further reducing memory usage.

In Accelerate you can combine CP and FSDP by defining both dp_shard_size and cp_size in ParallelismConfig, whilst in Axolotl you can add both the dp_shard_size and context_parallel_size config fields.
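For example, a sketch for 16 GPUs that splits very long sequences across 4 devices and applies FSDP over groups of 4 could be:

from accelerate.parallelism_config import ParallelismConfig

# CP splits each sequence across 4 GPUs; FSDP shards the model across groups of 4
pc = ParallelismConfig(dp_shard_size=4, cp_size=4)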



Hybrid Sharded Data Parallelism + Tensor Parallelism

With a sufficiently large world size (note: while the minimum world size for 3D parallelism is 8, it works best at much larger scales), we can consider combining HSDP with TP, which creates a hierarchy where DP first replicates the model across groups of nodes, FSDP then shards the model within each group, and TP splits individual layers within each node. You might consider this approach when facing all of the scaling constraints we mentioned above, as it provides the most flexibility to adapt to your specific training setup by making trade-offs between memory usage and throughput.

In Accelerate you can combine HSDP and TP by defining all of dp_shard_size, dp_replicate_size, and tp_size in ParallelismConfig. Similarly, in Axolotl you can add all of the dp_shard_size, dp_replicate_size, and tensor_parallel_size config fields.
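For example, a sketch of a 3D layout for 64 GPUs (8 nodes of 8 GPUs) could be:

from accelerate.parallelism_config import ParallelismConfig

# TP within each node (8), FSDP across groups of 4 nodes (4), DP across the 2 groups (2)
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=4, tp_size=8)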



Usage notes

There are additional ways to combine multiple parallelisms which we've not covered, such as 4D parallel using HSDP + TP + CP, but they operate very similarly to the techniques we have already covered. Most of all, we encourage you to play with different techniques and configurations – that is the best way to gain an intuition for the various ways in which you can make memory/throughput trade-offs.

Below are some additional tips you may find useful when working in distributed settings:

  • When using FSDP and working with models which are too large to fit on a single device, enabling both CPU-RAM-efficient loading and sharded state dict checkpointing is crucial. You can enable this through the cpu_ram_efficient_loading and state_dict_type parameters in Accelerate's FullyShardedDataParallelPlugin,

    fsdp2_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
        state_dict_type="SHARDED_STATE_DICT", 
        cpu_ram_efficient_loading=True
    )
    

    or through the cpu_ram_efficient_loading and state_dict_type config fields inside the fsdp_config in Axolotl:

    fsdp_version: 2
    fsdp_config:
      auto_wrap_policy: TRANSFORMER_BASED_WRAP
      transformer_layer_cls_to_wrap: LlamaDecoderLayer
      state_dict_type: SHARDED_STATE_DICT
      cpu_ram_efficient_loading: True
    
  • The total batch size used during training plays a vital role in training stability, memory usage, and data throughput. When using DP and/or FSDP, the effective batch size is calculated as:

    effective_batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size.

    where dp_world_size = (dp_shard_size * dp_replicate_size) / tp_size. You can increase your batch size by increasing your micro batch size or gradient accumulation steps in your training loop, or by setting the micro_batch_size and gradient_accumulation_steps config fields in Axolotl, or by increasing the total dp_world_size by adding more GPUs. As we mentioned above, this imposes a minimum total batch size of dp_world_size – when using pure DP/FSDP, this will be your total world size, and if this is too high the only way to decrease the total batch size is by introducing tensor parallelism. Finally, with a fixed number of GPUs and in memory-constrained scenarios, we recommend increasing gradient_accumulation_steps instead of micro_batch_size to achieve larger effective batch sizes, and vice versa.
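    As a quick illustrative calculation (a sketch with made-up numbers, applying the formula above):

    # Hypothetical layout: dp_shard_size=4, dp_replicate_size=4, tp_size=2 (32 GPUs)
    micro_batch_size = 2
    gradient_accumulation_steps = 4
    dp_world_size = (4 * 4) // 2
    effective_batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size
    print(effective_batch_size)  # 64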

  • Correspondingly, when your effective batch size increases due to introducing data parallelism, you should scale your learning rate to maintain training stability. Common approaches include linear scaling, scaled_lr = base_lr * (effective_batch_size / base_batch_size), or square-root scaling, scaled_lr = base_lr * sqrt(effective_batch_size / base_batch_size), as in the sketch below.
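    For example, a sketch of both scaling rules starting from a hypothetical base configuration:

    import math

    base_lr, base_batch_size = 2e-5, 32   # hypothetical values tuned on a small setup
    effective_batch_size = 256

    linear_lr = base_lr * (effective_batch_size / base_batch_size)          # 1.6e-4
    sqrt_lr = base_lr * math.sqrt(effective_batch_size / base_batch_size)   # ~5.7e-5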

  • When memory constraints persist even with parallelism strategies, gradient checkpointing can provide additional memory savings by trading compute for memory. During the forward pass, only a subset of activations are kept in memory (typically at transformer block boundaries), and intermediate activations are recomputed during the backward pass. This technique works seamlessly with all parallelism strategies covered above. In Accelerate, you can enable it by setting activation_checkpointing=True in FullyShardedDataParallelPlugin:

    fsdp2_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
        state_dict_type="SHARDED_STATE_DICT", 
        cpu_ram_efficient_loading=True,
        activation_checkpointing=True
    )
    

    and similarly in Axolotl:

    fsdp_version: 2
    fsdp_config:
      auto_wrap_policy: TRANSFORMER_BASED_WRAP
      transformer_layer_cls_to_wrap: LlamaDecoderLayer
      state_dict_type: SHARDED_STATE_DICT
      cpu_ram_efficient_loading: True
      activation_checkpointing: True
    

    Note that gradient checkpointing typically increases training time by ~20-30% due to activation recomputation, but can reduce activation memory by 60-80%, making it particularly valuable when training very large models or using long sequence lengths.


