This post is part of a series about distributed AI across multiple GPUs.
Introduction
In the previous post, we saw how Distributed Data Parallelism (DDP) speeds up training by splitting batches across GPUs. DDP solves the throughput problem, but it introduces a new challenge: memory redundancy.
In vanilla DDP, every GPU holds a full copy of the model parameters, gradients, and optimizer states. For large models like GPT-3 (175B parameters), this redundancy wastes a huge amount of precious VRAM.
ZeRO (Zero Redundancy Optimizer) solves this. It comes in three stages:
- ZeRO-1 partitions only optimizer states
- ZeRO-2 partitions optimizer states + gradients
- ZeRO-3 partitions optimizer states + gradients + model parameters
ZeRO isn’t a new parallelism technique: every GPU still runs the full forward and backward pass on its own micro-batch, exactly as in DDP. It’s a memory optimization strategy that eliminates redundancy across GPUs, letting you train larger models on the same hardware.
The Memory Problem in DDP
Let’s break down what actually consumes memory during training. For a model with $\Psi$ parameters:
- Model parameters: $\Psi$ values (the weights of your neural network)
- Gradients: $\Psi$ values (one gradient per parameter)
- Optimizer states (Adam): $2\Psi$ values (first moment $m$ and second moment $v$ for every parameter)
- Activations: intermediate outputs stored during the forward pass for use in the backward pass
The first three scale with model size and are redundant across GPUs in DDP. Activations scale with batch size, sequence length, and model width (number of neurons), and are unique per GPU since each GPU processes different data. ZeRO doesn’t touch activation memory.
Let’s calculate the memory usage for a 7B-parameter model using Adam and FP32:
- Parameters: 7 billion * 4 bytes = 28 GB
- Gradients: 7 billion * 4 bytes = 28 GB
- Optimizer states: 7 billion * 2 * 4 bytes = 56 GB
- Memory per GPU in DDP: 28 + 28 + 56 = 112 GB
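In general, with $\Psi$ parameters stored in FP32 (4 bytes each) and Adam keeping two states per parameter, every GPU in DDP holds 16 bytes per parameter before any activations are counted:

$$
\begin{aligned}
M_{\text{DDP}} &= \underbrace{4\Psi}_{\text{parameters}} + \underbrace{4\Psi}_{\text{gradients}} + \underbrace{2 \cdot 4\Psi}_{\text{Adam } m,\ v} = 16\Psi \ \text{bytes} \\
M_{\text{DDP}}\big|_{\Psi = 7 \times 10^{9}} &= 16 \cdot 7 \times 10^{9} \ \text{bytes} = 112 \ \text{GB}
\end{aligned}
$$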
Activations add significant memory on top of this, but since they’re unique per GPU, ZeRO can’t partition them. Techniques like activation checkpointing can help: they discard some activations during the forward pass and recompute them as needed during the backward pass. But that’s outside the scope of this article.
Let’s understand how ZeRO works by implementing it from the ground up, starting with ZeRO-1 and working our way up to ZeRO-3.
ZeRO-1: Optimizer State Partitioning
In ZeRO-1, only the optimizer states are partitioned. Each GPU:
- Still holds the full model parameters and gradients
- Stores only 1/N of the optimizer states (N = number of GPUs)
- Updates only the corresponding 1/N of the parameters
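Most of the bookkeeping behind "1/N of the parameters" is padding: each flattened parameter is padded with zeros until its length is divisible by N, then split into N equal contiguous shards. A minimal pure-Python sketch of that arithmetic, using the same padding scheme as the ZeRO_1 implementation in this section (the helper name `shard_bounds` is hypothetical, not from any library):

```python
def shard_bounds(numel: int, world_size: int, rank: int) -> tuple[int, int, int]:
    """Return (shard_start, shard_end, padded_numel) for one rank's slice
    of a flattened parameter with `numel` elements."""
    pad_size = (world_size - numel % world_size) % world_size  # zeros appended at the end
    padded_numel = numel + pad_size
    shard_size = padded_numel // world_size  # every rank gets the same shard size
    shard_start = rank * shard_size
    return shard_start, shard_start + shard_size, padded_numel


# A 10-element parameter split across 4 GPUs: padded to 12, 3 elements per rank.
for rank in range(4):
    print(rank, shard_bounds(10, 4, rank))
```

Here only the last rank's shard (indices 9 to 11) contains padding, which is why a gradient must be padded the same way before it is sliced into shards.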
This is the sequence of steps taken during training:
- Forward pass: each GPU processes its own micro-batch
- Backward pass: compute gradients
- All-reduce gradients: every GPU gets all the gradients
- Optimizer step: each GPU updates its parameter partition
- All-gather parameters: sync the updated model across GPUs
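To see why these steps give the same result as plain DDP, here is a single-process simulation: a Python list stands in for the N ranks, averaging stands in for all-reduce, and `torch.cat` for all-gather. As an assumption to keep it short, it uses SGD with momentum (one state per parameter) instead of Adam; no process group is created.

```python
import torch

torch.manual_seed(0)
N, numel, lr, beta, steps = 4, 8, 0.1, 0.9, 3  # 4 simulated ranks, one 8-element parameter
shard = numel // N  # numel is divisible by N here, so no padding is needed

ddp_params = torch.randn(numel)
zero1_params = ddp_params.clone()
ddp_momentum = torch.zeros(numel)                        # DDP: full state on every rank
zero1_momentum = [torch.zeros(shard) for _ in range(N)]  # ZeRO-1: rank r keeps 1/N of it

for _ in range(steps):
    # Each rank computes a gradient on its own micro-batch
    local_grads = [torch.randn(numel) for _ in range(N)]
    avg_grad = torch.stack(local_grads).mean(dim=0)  # all-reduce (average)

    # DDP: every rank updates every parameter
    ddp_momentum = beta * ddp_momentum + avg_grad
    ddp_params = ddp_params - lr * ddp_momentum

    # ZeRO-1: rank r updates only its slice, using only its slice of the state
    new_shards = []
    for r in range(N):
        sl = slice(r * shard, (r + 1) * shard)
        zero1_momentum[r] = beta * zero1_momentum[r] + avg_grad[sl]
        new_shards.append(zero1_params[sl] - lr * zero1_momentum[r])
    zero1_params = torch.cat(new_shards)  # all-gather the updated shards

print(torch.allclose(ddp_params, zero1_params))
```

The parameters match after every step, while each simulated rank stores only a quarter of the momentum buffer.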

Here’s a simplified implementation:
```python
import torch
import torch.distributed as dist


class ZeRO_1:
    def __init__(self, model, optimizer_cls):
        self.model = model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.param_shards = list()  # each rank holds only its shard of the optimizer states
        self.param_metadata = list()  # metadata to reconstruct shards

        for param in self.model.parameters():
            original_shape = param.data.shape
            flat = param.data.view(-1)
            numel = flat.numel()
            # Pad so the flattened parameter splits evenly across ranks
            remainder = numel % self.world_size
            pad_size = (self.world_size - remainder) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size
            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size

            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )

            if pad_size > 0:
                flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
            else:
                flat_padded = flat

            shard = flat_padded[shard_start:shard_end].clone()
            shard.requires_grad_(True)
            self.param_shards.append(shard)

        # The optimizer only sees this rank's shards, so it allocates 1/N of the states
        self.optimizer = optimizer_cls(self.param_shards)

    def training_step(self, inputs, targets, loss_fn):
        output = self.model(inputs)  # forward
        loss = loss_fn(output, targets)  # compute loss
        loss.backward()  # backward
        self._sync_gradients()  # all-reduce gradients across GPUs
        self.optimizer.step()  # update local shard of parameters
        self._sync_params()  # all-gather model params

        # clear gradients for the next step
        for param in self.model.parameters():
            param.grad = None

    def _sync_gradients(self):
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= self.world_size
            # Pad the flat gradient like the parameter, so the last rank's
            # slice has the same size as its (padded) shard
            flat_grad = param.grad.view(-1)
            pad_size = meta["padded_numel"] - meta["numel"]
            if pad_size > 0:
                flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])
            self.param_shards[idx].grad = flat_grad[meta["shard_start"]:meta["shard_end"]]

    def _sync_params(self):
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]
            full_flat = torch.empty(meta["padded_numel"], device=param.device, dtype=param.dtype)
            # Every rank contributes its updated shard and receives all of them
            dist.all_gather_into_tensor(
                output_tensor=full_flat,
                input_tensor=self.param_shards[idx].data,
            )
            reconstructed = full_flat[:meta["numel"]].view(meta["original_shape"])
            param.data.copy_(reconstructed)
```
Notice that the all-reduce syncs the full set of gradients, but each GPU only uses the gradients for its own parameter partition, so it’s overcommunicating. ZeRO-2 fixes this by sharding the gradients too.
In practice, you’d never use ZeRO-1, as ZeRO-2 gives you better memory savings at essentially the same cost. But it’s still worth going over for learning purposes.
Memory with ZeRO-1, 7B model, 8 GPUs:
- Parameters: 28 GB (fully replicated)
- Gradients: 28 GB (fully replicated)
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 63 GB (down from 112 GB)
ZeRO-2: Gradient Partitioning
ZeRO-2 partitions both optimizer states and gradients. Since each GPU only updates a partition of the parameters, it only needs the corresponding gradients.
ZeRO-1 uses all-reduce, which gives every GPU all of the gradients. ZeRO-2 replaces this with reduce-scatter, so each GPU receives only the gradients it actually needs. This saves both memory and communication bandwidth.
Training steps:
- Forward pass: each GPU processes its own micro-batch
- Backward pass: compute gradients
- Reduce-scatter gradients: each GPU gets only its partition
- Optimizer step: each GPU updates its parameter partition
- All-gather parameters: sync the updated model across GPUs

The implementation is very similar to ZeRO-1, except that the gradient synchronization step uses reduce-scatter instead of all-reduce.
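A single-process sketch of that difference, with a list of tensors standing in for the N ranks (no process group involved): reduce-scatter hands rank r the averaged gradient for its own slice only, which is exactly the slice ZeRO-1 cuts out after a full all-reduce. The two helper functions are hypothetical; in a real run, `torch.distributed.reduce_scatter_tensor` performs the collective.

```python
import torch


def simulated_all_reduce_mean(grads):
    """All-reduce (average): every rank ends up with the full averaged gradient."""
    avg = torch.stack(grads).mean(dim=0)
    return [avg.clone() for _ in grads]


def simulated_reduce_scatter_mean(grads):
    """Reduce-scatter (average): rank r ends up with only chunk r of the average.
    Assumes the gradient length is divisible by the number of ranks."""
    n = len(grads)
    chunks_per_rank = [g.chunk(n) for g in grads]  # split each rank's gradient into n pieces
    return [torch.stack([c[r] for c in chunks_per_rank]).mean(dim=0) for r in range(n)]


torch.manual_seed(0)
N = 4
grads = [torch.randn(12) for _ in range(N)]  # each rank's local gradient, 12 elements

full = simulated_all_reduce_mean(grads)        # ZeRO-1: 12 values kept per rank
shards = simulated_reduce_scatter_mean(grads)  # ZeRO-2: 3 values kept per rank

for r in range(N):
    zero1_slice = full[r][r * 3:(r + 1) * 3]
    print(r, shards[r].numel(), torch.allclose(zero1_slice, shards[r]))
```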
But wait: if every GPU computes all gradients during backprop, how does this actually save VRAM? Here’s how:
- As the parameter gradients are computed layer by layer, they’re immediately reduce-scattered and the full local copy is freed (our simplified implementation doesn’t do this).
- During backprop, you only need the gradient with respect to the next layer’s activations to compute the current layer’s parameter gradients, i.e., you don’t need the full set of gradients at once.
- That way you can free gradient memory as you move backwards, keeping only the assigned partition on each GPU.
Memory with ZeRO-2, 7B model, 8 GPUs:
- Parameters: 28 GB (fully replicated)
- Gradients: 28 GB / 8 = 3.5 GB
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 38.5 GB (down from 112 GB)
ZeRO-3: Parameter Partitioning
ZeRO-3 partitions optimizer states, gradients, and parameters. Each GPU stores only 1/N of the entire model state.
During forward and backward passes, each layer needs its full parameters, but each GPU only stores a fraction. So we all-gather parameters just-in-time, use them, then discard immediately after.
Training steps:
- Forward pass:
  - All-gather the layer’s parameters from all GPUs
  - Run the layer’s forward pass using the previous layer’s activations as input
  - Discard the gathered parameters (keep only the local partition)
  - Repeat these steps until all layers are done
- Backward pass (per layer, in reverse):
  - All-gather the layer’s parameters again
  - Compute gradients for the current layer using activation gradients from the next layer
  - Reduce-scatter the gradients (each GPU keeps its shard)
  - Discard the gathered parameters (keep only the local partition)
  - Repeat these steps until all layers are done
- Each GPU runs an optimizer step on its partition
- No final all-gather is needed, since parameters are gathered layer by layer during the forward pass
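A single-process sketch of the forward-pass steps for one linear layer: the weight is stored as N flat shards (a list stands in for the N ranks), gathered into a full tensor only for the matmul, and then dropped. The helper name `sharded_linear_forward` is hypothetical; the check at the end confirms the result matches an unsharded layer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 4
weight = torch.randn(6, 8)  # full weight of one layer (48 elements, divisible by N)
x = torch.randn(3, 8)       # activations from the previous layer

# At rest: each simulated rank keeps only 1/N of the flattened weight
weight_shards = list(weight.reshape(-1).chunk(N))


def sharded_linear_forward(x, shards, shape):
    full = torch.cat(shards).view(shape)  # all-gather: rebuild the full weight
    out = F.linear(x, full)               # run the layer with full parameters
    del full                              # discard it; only the shards remain
    return out


out = sharded_linear_forward(x, weight_shards, weight.shape)
print(torch.allclose(out, x @ weight.T))
print([s.numel() for s in weight_shards])
```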

Here’s a simplified implementation:
```python
import torch
import torch.distributed as dist


class ZeRO_3:
    """
    ZeRO-3: Shard optimizer states (stage 1) + gradients (stage 2) + model parameters (stage 3).

    At rest, each rank holds only param_shards[idx], a 1/world_size slice
    of every parameter. Full parameters are materialised temporarily during
    the forward and backward passes via all_gather, then immediately freed.
    """

    def __init__(self, model, optimizer_cls):
        self.model = model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.param_metadata = []
        shard_list = []
        self._param_to_idx = {}

        for idx, param in enumerate(self.model.parameters()):
            original_shape = param.data.shape
            flat = param.data.view(-1)
            numel = flat.numel()
            remainder = numel % self.world_size
            pad_size = (self.world_size - remainder) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size
            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size

            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )

            if pad_size > 0:
                flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
            else:
                flat_padded = flat

            shard = flat_padded[shard_start:shard_end].clone()
            shard_list.append(shard)

            # Replace the full tensor with only this rank's shard.
            # The model's param.data now points to a tiny slice; the full
            # weight is reconstructed on demand during forward/backward.
            param.data = shard.detach()
            self._param_to_idx[param] = idx

        self.param_shards = [s.requires_grad_(True) for s in shard_list]
        self.optimizer = optimizer_cls(self.param_shards)
        self._register_hooks()

    def _gather_param(self, idx, device, dtype):
        """All-gather the full parameter tensor for parameter `idx`."""
        meta = self.param_metadata[idx]
        full_flat = torch.empty(meta["padded_numel"], device=device, dtype=dtype)
        dist.all_gather_into_tensor(
            output_tensor=full_flat,
            input_tensor=self.param_shards[idx].data,
        )
        return full_flat[: meta["numel"]].view(meta["original_shape"])

    def _gather_module_params(self, module):
        """Gather full params for each parameter that belongs to this module only (not children)."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx[param]
            param.data = self._gather_param(idx, param.device, param.dtype)

    def _reshard_module_params(self, module):
        """Reshard params back to local shard for each direct param of this module."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx[param]
            param.data = self.param_shards[idx].data

    def _register_hooks(self):
        self._hooks = []
        for module in self.model.modules():
            # Skip container modules that don't have any direct parameters
            if not list(module.parameters(recurse=False)):
                continue
            # Forward: gather -> run -> reshard
            h1 = module.register_forward_pre_hook(
                lambda mod, _inputs: self._gather_module_params(mod)
            )
            h2 = module.register_forward_hook(
                lambda mod, _inputs, _output: self._reshard_module_params(mod)
            )
            # Backward: gather before grad computation -> reshard after
            h3 = module.register_full_backward_pre_hook(
                lambda mod, _grad_output: self._gather_module_params(mod)
            )
            h4 = module.register_full_backward_hook(
                lambda mod, _grad_input, _grad_output: self._reshard_module_params(mod)
            )
            self._hooks.extend([h1, h2, h3, h4])

    def _sync_gradients(self):
        """ZeRO-2 step: reduce-scatter each full gradient so this rank keeps only its averaged shard."""
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]
            flat_grad = param.grad.reshape(-1)
            pad_size = meta["padded_numel"] - meta["numel"]
            if pad_size > 0:
                flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])
            grad_shard = torch.empty(meta["shard_size"], device=flat_grad.device, dtype=flat_grad.dtype)
            dist.reduce_scatter_tensor(grad_shard, flat_grad, op=dist.ReduceOp.SUM)
            grad_shard /= self.world_size
            self.param_shards[idx].grad = grad_shard

    def training_step(self, inputs, targets, loss_fn):
        # Hooks handle all gather/reshard around each module automatically
        output = self.model(inputs)
        loss = loss_fn(output, targets)
        loss.backward()
        self._sync_gradients()
        # Each rank updates only its local shard
        self.optimizer.step()
        for param in self.model.parameters():
            param.grad = None
```
Each layer’s parameters are gathered right before they’re needed and freed immediately after. This keeps peak memory minimal at the cost of more communication. In practice, implementations overlap the all-gather for layer i+1 with the forward pass of layer i to hide latency.
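A back-of-the-envelope sketch of what "peak memory minimal" means for parameters alone, assuming FP32 and that at most one layer is fully gathered at a time, or two when the next layer is prefetched to overlap communication. The function name and the 32-equal-layer split are hypothetical, illustrative choices, not measurements.

```python
def zero3_peak_param_gb(layer_params, n_gpus, bytes_per_value=4, prefetch=False):
    """Upper bound on peak parameter memory per GPU under ZeRO-3, in GB (1 GB = 1e9 bytes).

    Resident: this GPU's 1/N shard of every layer, plus the fully gathered
    weights of the layer being computed (and of the next one when prefetching).
    """
    sharded = sum(layer_params) / n_gpus
    largest = sorted(layer_params, reverse=True)
    gathered = sum(largest[:2]) if prefetch else largest[0]
    return (sharded + gathered) * bytes_per_value / 1e9


layers = [7e9 / 32] * 32  # hypothetical: a 7B model split into 32 equal layers
print(zero3_peak_param_gb(layers, n_gpus=8))
print(zero3_peak_param_gb(layers, n_gpus=8, prefetch=True))
```

Under these assumptions the peak stays in the 4 to 6 GB range, far below the 28 GB of fully replicated parameters, and prefetching costs only one extra layer of memory.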
Memory with ZeRO-3, 7B model, 8 GPUs:
- Parameters: 28 GB / 8 = 3.5 GB
- Gradients: 28 GB / 8 = 3.5 GB
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 14 GB (down from 112 GB)
That’s an 8x reduction in memory usage, which is exactly what we’d expect from partitioning across 8 GPUs.
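All four budgets in this post (DDP and the three ZeRO stages) come from one formula. A small sketch, assuming FP32 values (4 bytes) and Adam’s two states per parameter, with activations excluded as above; the function name is hypothetical:

```python
def per_gpu_memory_gb(num_params, n_gpus, stage, bytes_per_value=4):
    """Model-state memory per GPU in GB (1 GB = 1e9 bytes); stage 0 means plain DDP."""
    params = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    optim = 2 * num_params * bytes_per_value  # Adam: first and second moments
    if stage >= 1:
        optim /= n_gpus   # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= n_gpus   # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= n_gpus  # ZeRO-3: also shard parameters
    return (params + grads + optim) / 1e9


for stage in range(4):
    print(stage, per_gpu_memory_gb(7e9, n_gpus=8, stage=stage))
```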
Using ZeRO in PyTorch
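PyTorch ships with two implementations of ZeRO-3: FSDP1 (older, less optimized) and FSDP2 (newer, recommended). Always use FSDP2.
FSDP (Fully Sharded Data Parallel) handles parameter gathering, gradient scattering, communication overlap, and memory management automatically:
```python
from torch.distributed.fsdp import fully_shard

model = Transformer()  # your model; assumed to expose its blocks as `model.layers`

# Shard each block separately so it is gathered and freed as its own unit
for layer in model.layers:
    fully_shard(layer)

# The root call shards the remaining parameters (e.g. embeddings, output head)
fully_shard(model)
```
Apply fully_shard layer by layer first and then to the whole model. If you only wrap the root, all parameters form a single unit that is gathered at once, so the full model is materialized during the forward and backward passes.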
Conclusion
ZeRO trades extra communication for lower memory, so it’s not a free lunch. In general it’s not worth it for smaller models (e.g. BERT), but it’s a game changer for larger models.
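To put a rough number on "not a free lunch": with ring-based collectives, a reduce-scatter or all-gather over $\Psi$ values sends about $(N-1)/N \cdot \Psi$ values per GPU, and an all-reduce is a reduce-scatter followed by an all-gather. A sketch under those assumptions, following the schedules described in this post; the function name is hypothetical and the numbers are estimates, not measurements:

```python
def comm_volume_per_gpu(num_params, n_gpus):
    """Approximate number of values each GPU sends per training step,
    assuming ring-based collectives."""
    # One ring reduce-scatter OR one ring all-gather over num_params values
    ring = (n_gpus - 1) / n_gpus * num_params
    return {
        "DDP": 2 * ring,                 # all-reduce = reduce-scatter + all-gather
        "ZeRO-1 (this post)": 3 * ring,  # all-reduce grads + all-gather params
        "ZeRO-2": 2 * ring,              # reduce-scatter grads + all-gather params
        "ZeRO-3": 3 * ring,              # all-gather (fwd) + all-gather (bwd) + reduce-scatter
    }


for name, volume in comm_volume_per_gpu(7e9, n_gpus=8).items():
    print(f"{name}: {volume / 7e9:.3f} x the parameter count")
```

Under these assumptions ZeRO-2 moves the same amount of data as DDP, while ZeRO-3 moves 1.5x as much, which is the price paid for sharding the parameters.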
Congratulations on making it to the end! In this post, you learned about:
- The memory redundancy problem in standard DDP
- How ZeRO partitions optimizer states, gradients, and parameters across GPUs
- The three stages of ZeRO and their memory/communication trade-offs
- How to use ZeRO-3 via PyTorch’s FSDP
In the next article, we’ll explore Tensor Parallelism, a model parallelism technique that speeds up each layer’s computation by distributing the work across GPUs.
