AI on Multiple GPUs: Gradient Accumulation & Data Parallelism

This post is part of a series about distributed AI across multiple GPUs.

Introduction

Distributed Data Parallelism (DDP) is the main parallelization method we'll look at. It's the baseline approach that's almost always used in distributed training settings, and it's commonly combined with other parallelization techniques.

A Quick Neural Network Refresher

Training a neural network means running a forward pass, calculating the loss, backpropagating the gradients of every weight with respect to the loss function, and finally updating the weights (what we call an optimization step). In PyTorch, it typically looks like this:

import torch

def training_loop(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: callable,
):
    for i, batch in enumerate(dataloader):
        inputs, targets = batch
        output = model(inputs)  # Forward pass
        loss = loss_fn(output, targets)  # Compute loss
        loss.backward()  # Backward pass (compute gradients)
        optimizer.step()  # Update weights
        optimizer.zero_grad()  # Clear gradients for the next step

Performing the optimization step on large amounts of training data generally gives more accurate gradient estimates, resulting in smoother training and potentially faster convergence. So ideally we would take each step after computing the gradients on the complete training dataset. In practice, that's rarely feasible in Deep Learning scenarios, as it would take too long to compute. Instead, we work with small chunks like mini-batches and micro-batches.

  • Batch: Refers to the complete training set used for one optimization step.
  • Mini-batch: Refers to a small subset of the training data used for one optimization step.
  • Micro-batch: Refers to a subset of the mini-batch; we combine multiple micro-batches for one optimization step.
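
As a concrete sketch of these terms (all sizes below are illustrative, not recommendations):

```python
# Illustrative sketch of the batch terminology above (sizes are made up).
dataset = list(range(1_000))          # "batch": the entire training set

mini_batch_size = 64                  # samples contributing to one optimization step
grad_accum_steps = 4                  # micro-batches combined per step
micro_batch_size = mini_batch_size // grad_accum_steps  # 16 samples per forward/backward

mini_batch = dataset[:mini_batch_size]
micro_batches = [
    mini_batch[i:i + micro_batch_size]
    for i in range(0, mini_batch_size, micro_batch_size)
]

assert len(micro_batches) == grad_accum_steps
assert sum(len(mb) for mb in micro_batches) == mini_batch_size
```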

This is where Gradient Accumulation and Data Parallelism come into play. Although we don't use the complete dataset for every step, we can use these techniques to substantially increase our mini-batch size.

Gradient Accumulation

Here’s how it works: pick a large mini-batch that won’t fit in GPU memory, then split it into micro-batches that do fit. For each micro-batch, run forward and backward passes, adding (accumulating) the computed gradients. Once all micro-batches are processed, perform a single optimization step using the averaged gradients.

Note that Gradient Accumulation isn’t a parallelization technique and doesn’t require multiple GPUs.

Image by author: Gradient Accumulation animation
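
To see why accumulation reproduces the large-batch gradient, here's a tiny pure-Python sketch with a one-weight linear model (the numbers are made up, no PyTorch involved): scaling each micro-batch's mean-loss gradient by 1/num_micro before summing yields exactly the full mini-batch gradient.

```python
# Gradient of the mean squared error of y = w*x with respect to w
def grad_mse(w, xs, ts):
    n = len(xs)
    return sum(2 * x * (w * x - t) for x, t in zip(xs, ts)) / n

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ts = [2.0, 4.0, 6.0, 8.0]

# One big pass over the whole mini-batch
full_grad = grad_mse(w, xs, ts)

# Two micro-batches of two samples each; scale each micro-batch
# gradient by 1/num_micro so the accumulated sum is the average
num_micro = 2
accum = 0.0
for i in range(num_micro):
    accum += grad_mse(w, xs[2*i:2*i+2], ts[2*i:2*i+2]) / num_micro

assert abs(accum - full_grad) < 1e-12  # identical up to float error
```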

Implementing Gradient Accumulation from scratch is straightforward. Here’s what it looks like in a simple training loop:

import torch

def training_loop(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: callable,
    grad_accum_steps: int,
):
    for i, batch in enumerate(dataloader):
        inputs, targets = batch
        output = model(inputs)
        loss = loss_fn(output, targets) / grad_accum_steps  # Scale so the accumulated gradients average out
        loss.backward()  # Gradients get accumulated (summed)

        # Only update weights after `grad_accum_steps` micro-batches
        if (i+1) % grad_accum_steps == 0:  # i+1 avoids a step in the first iteration when i=0
            optimizer.step()
            optimizer.zero_grad()

Notice we’re now performing multiple forward and backward passes sequentially before each optimization step, which increases training time. It would be nice if we could speed this up by processing multiple micro-batches in parallel… that’s exactly what DDP does!

Distributed Data Parallelism (DDP)

For a relatively small number of GPUs (up to ~8), DDP scales almost linearly, which is ideal. That means that if you double the number of GPUs, you can almost halve the training time (we already discussed Linear Scaling previously).

With DDP, multiple GPUs work together to process a bigger effective mini-batch, handling each micro-batch in parallel. The workflow looks like this:

  1. Split the mini-batch across GPUs.
  2. Each GPU runs its own forward and backward passes to compute gradients for its own data shard (micro-batch).
  3. Use an AllReduce operation (we previously learned about it in Collective operations) to average gradients across all GPUs.
  4. Each GPU applies the same weight updates, keeping models in perfect sync.

This lets us train with much larger effective mini-batch sizes, resulting in more stable training and potentially faster convergence.
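
Steps 3 and 4 can be sketched on the CPU without real GPUs or torch.distributed (the world size and gradient values below are made up): an all-reduce with SUM followed by division by the world size leaves every rank holding the same averaged gradient.

```python
# CPU-only sketch of the gradient-averaging steps (illustrative values).
world_size = 4

# The gradient each rank computed for one parameter (2 elements)
per_rank_grads = [
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0],
    [7.0, 8.0],
]

# Toy all-reduce(SUM): afterwards every rank holds the elementwise sum
summed = [sum(vals) for vals in zip(*per_rank_grads)]
per_rank_grads = [list(summed) for _ in range(world_size)]

# Divide by the world size to turn the sum into an average (step 3),
# so every rank applies the same weight update (step 4)
averaged = [[g / world_size for g in grads] for grads in per_rank_grads]
assert all(grads == [4.0, 5.0] for grads in averaged)
```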

Image by author: Distributed Data Parallel animation

Implementing DDP from scratch in PyTorch

Let’s do this step by step. In this first iteration, we’re only syncing the gradients.

import torch


class DDPModelWrapper:
    def __init__(self, model: torch.nn.Module):
        self.model = model

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def sync_gradients(self):
        # Iterate over the parameter tensors in the model
        for param in self.model.parameters():
            # Some parameters may be frozen and don't have gradients
            if param.grad is not None:
                # Sum the gradients across all ranks, then divide to get the average
                torch.distributed.all_reduce(param.grad.data, op=torch.distributed.ReduceOp.SUM)
                # Assuming each GPU received an equally sized micro-batch, we can
                # average the gradients by dividing by the number of GPUs (aka world size).
                # The loss function already averages over the micro-batch by default.
                param.grad.data /= torch.distributed.get_world_size()

Before we start training, we obviously need our model to be identical across all GPUs; otherwise, we would be training different models! Let’s improve our implementation by checking that all weights are identical during instantiation (if you don’t know what ranks are, check the first blog post of the series).

import torch


class DDPModelWrapper:
    def __init__(self, model: torch.nn.Module):
        self.model = model
        for param in self.model.parameters():
            # We create a new tensor so it can receive the broadcast
            rank_0_param = param.data.clone()
            # Initially rank_0_param contains the values for the current rank
            torch.distributed.broadcast(rank_0_param, src=0)
            # After the broadcast, rank_0_param is overwritten with the parameters from rank 0
            if not torch.equal(param.data, rank_0_param):  # Now we compare rank_x with rank_0
                raise ValueError("Model parameters are not the same across all processes.")

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def sync_gradients(self):
        for param in self.model.parameters():
            if param.grad is not None:
                torch.distributed.all_reduce(param.grad.data, op=torch.distributed.ReduceOp.SUM)
                param.grad.data /= torch.distributed.get_world_size()

Combining DDP with GA

You can combine DDP with GA to achieve even larger effective batch sizes. This is especially useful when your model is so large that only a few samples fit per GPU.

The key benefit is reduced communication overhead: instead of syncing gradients after every micro-batch, you only sync once per grad_accum_steps micro-batches. This means:

  • Global effective batch size = num_gpus × micro_batch_size × grad_accum_steps
  • Fewer synchronization points = less time spent on inter-GPU communication
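
Plugging illustrative numbers into that formula (these values are made up for the example):

```python
num_gpus = 8           # world size
micro_batch_size = 4   # samples per GPU per forward/backward pass
grad_accum_steps = 16  # micro-batches accumulated before each optimizer step

# Global effective batch size = num_gpus * micro_batch_size * grad_accum_steps
effective_batch_size = num_gpus * micro_batch_size * grad_accum_steps
assert effective_batch_size == 512

# One gradient sync per optimization step instead of one per micro-batch
syncs_saved_per_step = grad_accum_steps - 1
```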

A training loop using our DDPModelWrapper with Gradient Accumulation looks like this:

def training_loop(
    ddp_model: DDPModelWrapper,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: callable,
    grad_accum_steps: int,
):
    for i, batch in enumerate(dataloader):
        inputs, targets = batch
        output = ddp_model(inputs)
        loss = loss_fn(output, targets) / grad_accum_steps  # Scale so the accumulated gradients average out
        loss.backward()

        if (i+1) % grad_accum_steps == 0:
            # Must sync gradients across GPUs *BEFORE* the optimization step
            ddp_model.sync_gradients()
            optimizer.step()
            optimizer.zero_grad()

Pro-tips and advanced usage

  • Use data prefetching. You can speed up training by loading the next batch of data while the current one is being processed. PyTorch’s DataLoader provides a prefetch_factor argument that controls how many batches to prefetch in the background. Properly leveraging prefetching with CUDA can be a bit tricky, so we’ll leave it for a future post.
  • Don’t max out GPU memory. Counter-intuitively, leaving some free memory can lead to faster training throughput. When you leave at least ~15% of GPU memory free, the GPU can better manage memory by avoiding fragmentation.
  • PyTorch DDP overlaps communication with computation. By default, DDP communicates gradients as they’re computed during backpropagation rather than waiting for the entire backward pass to finish. Here’s how:
    • PyTorch organizes model gradients into buckets of bucket_cap_mb megabytes. During the backward pass, PyTorch marks gradients as ready for reduction as they’re computed. Once all gradients in a bucket are ready, DDP kicks off an asynchronous allreduce to average those gradients across all ranks. The loss.backward() call returns only after all allreduce operations have completed, so immediately calling optimizer.step() is safe.
    • The bucket_cap_mb parameter creates a tradeoff: smaller values trigger more frequent allreduce operations, but each communication kernel launch incurs some overhead that can hurt performance. Larger values reduce communication frequency but also reduce overlap; at the extreme, if buckets are too large, you’re waiting for the entire backward pass to finish before communicating. The optimal value depends on your model architecture and hardware, so profile with different values to find what works best.
Source: PyTorch Tutorial
  • Here’s a complete PyTorch implementation of DDP:
"""
Launch with:
  torchrun --nproc_per_node=NUM_GPUS ddp.py
"""
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch import optim


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1024, 1024), nn.ReLU(),
            nn.Linear(1024, 1024), nn.ReLU(),
            nn.Linear(1024, 256),
        )

    def forward(self, x):
        return self.net(x)


def train():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # Create dummy dataset
    x_data = torch.randn(1000, 1024)
    y_data = torch.randn(1000, 256)
    dataset = TensorDataset(x_data, y_data)

    # DistributedSampler ensures each rank gets different data
    sampler = DistributedSampler(dataset, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

    model = ToyModel().to(device)

    # gradient_as_bucket_view: avoids an additional grad tensor copy per bucket.
    ddp_model = DDP(
        model,
        device_ids=[rank],
        bucket_cap_mb=25,
        gradient_as_bucket_view=True,
    )

    optimizer = optim.AdamW(ddp_model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    for epoch in range(2):
        sampler.set_epoch(epoch)  # Ensures different shuffling each epoch

        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            output = ddp_model(x)
            loss = loss_fn(output, y)

            # Backward automatically overlaps the per-bucket allreduce with computation.
            # By the time this returns, all allreduce ops are done.
            loss.backward()
            optimizer.step()

            if rank == 0 and batch_idx % 5 == 0:
                print(f"epoch {epoch}  batch {batch_idx}  loss={loss.item():.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    train()
  • Here’s a complete PyTorch implementation combining DDP with GA:
"""
Launch with:
  torchrun --nproc_per_node=NUM_GPUS ddp_ga.py
"""
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch import optim
from contextlib import nullcontext


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1024, 1024), nn.ReLU(),
            nn.Linear(1024, 1024), nn.ReLU(),
            nn.Linear(1024, 256),
        )

    def forward(self, x):
        return self.net(x)


def train():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # Create dummy dataset
    x_data = torch.randn(1000, 1024)
    y_data = torch.randn(1000, 256)
    dataset = TensorDataset(x_data, y_data)

    # DistributedSampler ensures each rank gets different data
    sampler = DistributedSampler(dataset, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)

    model = ToyModel().to(device)

    ddp_model = DDP(
        model,
        device_ids=[rank],
        bucket_cap_mb=25,
        gradient_as_bucket_view=True,
    )

    optimizer = optim.AdamW(ddp_model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    ACCUM_STEPS = 4

    for epoch in range(2):
        sampler.set_epoch(epoch)  # Ensures different shuffling each epoch

        optimizer.zero_grad()
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)

            is_last_micro_step = (batch_idx + 1) % ACCUM_STEPS == 0

            # no_sync() suppresses allreduce on accumulation steps.
            # On the last microstep we exit no_sync() so DDP fires
            # the allreduce overlapped with that backward pass.
            ctx = ddp_model.no_sync() if not is_last_micro_step else nullcontext()

            with ctx:
                output = ddp_model(x)
                loss = loss_fn(output, y) / ACCUM_STEPS
                loss.backward()

            if is_last_micro_step:
                optimizer.step()
                optimizer.zero_grad()

                if rank == 0:
                    print(f"epoch {epoch}  batch {batch_idx}  loss={loss.item() * ACCUM_STEPS:.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    train()

Conclusion

Follow me on X for more free AI content @l_cesconetto

Congratulations on making it to the end! In this post you learned about:

  • The importance of large batch sizes
  • How Gradient Accumulation works and its limitations
  • The DDP workflow and its advantages
  • How to implement GA and DDP from scratch in PyTorch
  • How to combine GA and DDP

In the next article, we’ll explore ZeRO (Zero Redundancy Optimizer), a more advanced technique that builds upon DDP to further optimize VRAM usage.
