Cutting LLM Memory by 84%: A Deep Dive into Fused Kernels


If you’ve ever trained or fine-tuned an LLM, you’ve likely hit a wall at the very last step: the Cross-Entropy Loss.

The culprit is the logit bottleneck. To predict the next token, we project a hidden state into a large vocabulary space. For Llama 3 (128,256 tokens), the weight matrix alone is over 525 million parameters. While that’s only ~1GB in bfloat16, the intermediate logit tensor is the real issue. For large batches, it can easily exceed 80GB of VRAM just to compute a single scalar loss.

Optimising this layer is how libraries like Unsloth and Liger-Kernel achieve such massive memory reductions. In this article, we’ll build a fused Linear + Cross Entropy kernel from scratch in Triton. We’ll derive the maths and implement a tiled forward and backward pass that slashes peak memory usage by 84%.

Note on Performance: This implementation is primarily educational. We prioritise mathematical clarity and readable Triton code by using global atomic operations. While it solves the memory bottleneck, matching production-grade speeds would require significantly more complex implementations that are out of scope for this article.

This post is part of my Triton series. We’ll be using concepts like tiling and online softmax that we’ve covered previously. If those sound unfamiliar, I recommend catching up there first!

The Logit Bottleneck

To get started, let’s put some numbers on the logit bottleneck. We consider an input matrix X with shape [NxD], a weight matrix W with shape [DxV] and a logit matrix Y=X@W with shape [NxV]. In the context of an LLM, N would be the sequence length multiplied by the batch size (i.e. the total number of tokens in the batch), D the size of the hidden state and V the vocabulary size.

For a Llama3 8B model, we might have a context window of 8192 tokens, a hidden state with 4096 dimensions and a vocabulary size of 128,256 tokens. Using a modest batch size of 8, we get N = 8192x8 = 65,536.

This leads to the Y matrix having shape [NxV]=[65,536x128,256], or roughly 8.4 billion elements. In bfloat16, this would take up 16.8GB of memory. However, if we follow best practices and use float32 for the loss calculation to ensure numerical stability, the requirements double to 33.6GB.
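
These figures are easy to reproduce. The short Python check below redoes the arithmetic. It uses decimal gigabytes (1GB = 10^9 bytes), which is the convention the numbers above match. The constant and function names are illustrative only.

```python
# Back-of-the-envelope size of the logit matrix Y = X @ W for the
# Llama3 8B setting described above (8 sequences of 8192 tokens).
BATCH_SIZE = 8
CONTEXT_LENGTH = 8192
VOCAB_SIZE = 128_256

def logit_bytes(n_tokens: int, vocab_size: int, bytes_per_element: int) -> int:
    """Memory needed to materialise an [n_tokens x vocab_size] logit tensor."""
    return n_tokens * vocab_size * bytes_per_element

N = BATCH_SIZE * CONTEXT_LENGTH                 # 65,536 tokens
n_elements = N * VOCAB_SIZE                     # ~8.4 billion logits
bf16_gb = logit_bytes(N, VOCAB_SIZE, 2) / 1e9   # bfloat16: 2 bytes per element
fp32_gb = logit_bytes(N, VOCAB_SIZE, 4) / 1e9   # float32: 4 bytes per element

print(f"N = {N:,} tokens, {n_elements / 1e9:.1f}B logits")  # N = 65,536 tokens, 8.4B logits
print(f"bfloat16: {bf16_gb:.1f} GB, float32: {fp32_gb:.1f} GB")  # bfloat16: 16.8 GB, float32: 33.6 GB
```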

To put this number in perspective, we would also need around 16GB of memory to hold the weights of Llama3 8B in bfloat16. On most GPUs, this leaves no room for the huge overhead of the optimiser states (e.g. Adam’s moments) and other activations, leading to the infamous PyTorch OOM error.

Representation of the input, weight and logit matrices along with their memory footprint. (All illustrations and animations in this article were made by the author unless specified otherwise)

This problem is usually handled using:

  • Gradient accumulation: Use a smaller batch size and accumulate gradients over multiple batches between each optimiser step, emulating a bigger batch size while holding less data in memory.
  • Activation checkpointing: PyTorch normally stores all intermediate activations for reuse in the backward pass; checkpointing discards these activations and recomputes them on-the-fly during the backward pass. This results in large memory savings but increases training time, since the number of required forward passes is doubled.
  • Micro-batching the loss: Instead of computing the loss over the entire N dimension at once, we can slice it and accumulate the loss over smaller chunks of size n < N. Now, we only hold a slice of size [n, V] in memory at a time.
  • Mixed precision training: Using half precision during training provides 2x memory reduction and significant speedups on Tensor Cores.
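
The micro-batching idea can be made concrete with a minimal NumPy sketch. It assumes a plain linear layer with no bias. The helper names (`cross_entropy_rows`, `chunked_linear_ce`) are hypothetical and do not come from any library. The sketch computes the same mean loss as the unchunked version, but only ever materialises an [n x V] slice of logits.

```python
import numpy as np

def cross_entropy_rows(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Per-row cross-entropy, using the max-shift for numerical stability."""
    m = logits.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return lse - logits[np.arange(len(labels)), labels]

def chunked_linear_ce(X, W, labels, chunk_size):
    """Mean loss over N tokens, materialising only [chunk_size x V] logits at a time."""
    total = 0.0
    for start in range(0, X.shape[0], chunk_size):
        stop = start + chunk_size
        logits = X[start:stop] @ W  # shape [n, V] with n <= chunk_size
        total += cross_entropy_rows(logits, labels[start:stop]).sum()
    return total / X.shape[0]

rng = np.random.default_rng(0)
N, D, V = 64, 16, 50
X = rng.standard_normal((N, D))
W = rng.standard_normal((D, V))
labels = rng.integers(0, V, size=N)

full = cross_entropy_rows(X @ W, labels).mean()
chunked = chunked_linear_ce(X, W, labels, chunk_size=8)
print(np.isclose(full, chunked))  # True
```

The loop over chunks runs at the Python level. That is the kind of framework-level iteration that makes this approach slower than a fused kernel.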

While these solutions seem attractive, they all have significant drawbacks: gradient accumulation and activation checkpointing slow down training, mixed precision can be unstable, and micro-batching requires (slow) PyTorch-level iteration. Moreover, even though n is chosen to be smaller than N, the vocabulary size remains huge in comparison.

More importantly, these solutions don't address the issue we've tackled repeatedly throughout this series: data movement. Indeed, we're still wasting time writing billions of logits to VRAM only to read them back milliseconds later.

The Kernel Solution

As we’ll see in a minute, the forward and backward passes of the cross-entropy loss involve dot products, matrix multiplications and a softmax. As we learned in this series, these are all operations that can be tiled efficiently. In other words, we can perform them iteratively while only holding a small piece of the inputs in memory at any time.

Moreover, cross-entropy is usually preceded by a matrix multiplication: the linear projection from the hidden state into the vocabulary space. This is a great opportunity for operator fusion, i.e. combining multiple operations into a single kernel, which can lead to large speedups and potential memory gains.

In the next sections, we’ll look at how to derive and efficiently fuse the forward and backward passes of a kernel combining a linear layer with cross-entropy.

As mentioned in the last article, Triton kernels don't natively register in PyTorch’s autograd. Therefore, we need to derive the gradients ourselves, a wonderful occasion to brush up on some calculus 😉

The maths behind Fused Linear Cross-Entropy

Definition and Forward Pass

In this section, we derive the mathematical expression for our Fused Linear Cross-Entropy layer to see how it naturally lends itself to tiling.

For two discrete probability distributions p and q, cross-entropy is defined as:

$$H(p, q) = -\sum_{i} p_i \log q_i$$

In our context, p is the one-hot vector representing the target token, while q is the model’s distribution over the vocabulary. We obtain q by applying a softmax to the logits l, themselves the outputs of the preceding linear layer.

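Since p is non-zero only for the target token y, the summation collapses to a single term. We can then substitute the numerically stable softmax (as discussed in the last article) to derive the final expression:

$$\mathcal{L} = -\log q_y = -l_y + \log \sum_{j} e^{l_j} = -l_y + m + \log \sum_{j} e^{l_j - m}, \qquad m = \max_j l_j$$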

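Substituting the logits with the outputs of the linear layer, l_j = x . w_j, we get:

$$\mathcal{L} = -\,x \cdot w_y + m + \log \sum_{j} e^{x \cdot w_j - m}, \qquad m = \max_j \, x \cdot w_j$$

The forward pass therefore boils down to three main quantities: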

  1. The target logit x . w_y.
  2. The log-sum-exp (LSE) of all dot products.
  3. The global maximum logit, used for numerical stability.
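
The role of the maximum is easy to see numerically. The NumPy sketch below compares a naive softmax-based loss with the max-shifted form built from the three quantities above. The helper names are hypothetical, and the float32 logits are chosen on purpose to force an overflow.

```python
import numpy as np

def naive_ce(logits, label):
    # Direct softmax: exp() overflows to inf for large logits, giving inf / inf = nan.
    probs = np.exp(logits) / np.exp(logits).sum()
    return -np.log(probs[label])

def stable_ce(logits, label):
    # Target logit, global maximum m, and log-sum-exp shifted by m.
    target = logits[label]
    m = logits.max()
    lse = m + np.log(np.exp(logits - m).sum())
    return lse - target

logits = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_ce(logits, 2))   # nan
print(stable_ce(logits, 2))      # approximately 0.4076
```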

Thanks to the online softmax algorithm, we can compute these quantities without ever materialising the full vocabulary in memory. Instead of an O(V) memory bottleneck, we iterate over the hidden dimension D and the vocabulary V in small tiles (D_block and V_block). This turns the calculation into an O(1) register problem for each row.

To parallelise this effectively, we launch one GPU program per row of the input matrix. Each program independently executes the following steps:

  1. Pre-compute the target logit: Perform a tiled dot product between the current row of X and the column of W corresponding to the target token y.
  2. Online reduction: Iterate through the hidden and vocabulary blocks to:
     1. Track the running maximum (m)
     2. Update the running sum of exponentials (d) using the online softmax formula:

$$m_{\text{new}} = \max\Big(m, \max_{j \in \text{block}} l_j\Big), \qquad d_{\text{new}} = d \, e^{m - m_{\text{new}}} + \sum_{j \in \text{block}} e^{l_j - m_{\text{new}}}$$

An example of tiled matrix multiplication for a single GPU program processing a row of X. The colored squares represent elements loaded in memory and the colored outline represents the entire tile being iterated over. Tiling trades off speed for large memory gains.
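
The running-statistics update can be sketched in plain Python for a single row. This is a simplified illustration that assumes the row's logits are already available as a list. In the fused kernel, each block of logits would instead be produced by a tiled dot product over D. `online_lse` is a hypothetical helper name.

```python
import math

def online_lse(logits, block_size):
    """Log-sum-exp of a row of logits, visited one block at a time.

    Only the running maximum m and running sum d survive between blocks,
    which mirrors the per-row state kept by one GPU program.
    """
    m = -math.inf   # running maximum
    d = 0.0         # running sum of exp(l - m)
    for start in range(0, len(logits), block_size):
        block = logits[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the previous sum to the new maximum, then add this block.
        d = d * math.exp(m - m_new) + sum(math.exp(l - m_new) for l in block)
        m = m_new
    return m + math.log(d)

logits = [0.5, -1.2, 3.3, 2.0, 0.0, -0.7, 1.1]
target = 2
loss = online_lse(logits, block_size=3) - logits[target]
reference = math.log(sum(math.exp(l) for l in logits)) - logits[target]
print(abs(loss - reference) < 1e-12)  # True
```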

Now that we have a better understanding of the forward pass, let’s look at the derivation of the backward pass.

Backward Pass

Notation

To derive our gradients efficiently, we’ll use Einstein notation and the Kronecker delta.

In Einstein notation, repeated indices are implicitly summed over. For instance, a standard matrix multiplication Y = X@W simplifies from a verbose summation to a clean index pairing:

$$Y_{ij} = \sum_{k} X_{ik} W_{kj} \quad \longrightarrow \quad Y_{ij} = X_{ik} W_{kj}$$

The Kronecker delta (δ_ij) is used alongside this notation to handle identity logic. It equals 1 if i=j and 0 otherwise. As we’ll see, this is especially useful for collapsing indices during differentiation.
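
Both conventions map directly onto `numpy.einsum`, which uses the same index-pairing notation. The sketch below uses arbitrary random shapes. It checks that the implicit sum over k reproduces `X @ W`. It also checks that a Kronecker delta, represented here as an identity matrix, simply collapses an index.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))   # [N x D]
W = rng.standard_normal((3, 5))   # [D x V]

# Y_ij = X_ik W_kj: the repeated index k is summed over implicitly.
Y_einsum = np.einsum("ik,kj->ij", X, W)
print(np.allclose(Y_einsum, X @ W))  # True

# The Kronecker delta as an identity matrix: delta_ij W_jk collapses to W_ik.
delta = np.eye(3)
print(np.allclose(np.einsum("ij,jk->ik", delta, W), W))  # True
```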

Matrix Multiplication

In this section, we derive the back-propagated gradients for matrix multiplication. We assume the existence of an upstream gradient ∂L/∂y_ij, i.e. the gradient of the loss L with respect to the outputs of the multiplication.

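To determine how this gradient back-propagates through matrix multiplication, we apply the chain rule with respect to the inputs x and the weight matrix w. Here y represents the multiplication’s outputs:

$$\frac{\partial \mathcal{L}}{\partial x_{mn}} = \frac{\partial \mathcal{L}}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial x_{mn}}, \qquad \frac{\partial \mathcal{L}}{\partial w_{mn}} = \frac{\partial \mathcal{L}}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial w_{mn}}$$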

We start by deriving the partial derivatives of y with respect to x, following these steps:

  1. Express y in terms of x and w.
  2. Notice that w is constant with respect to x, so we can pull it out of the derivative.
  3. Express the fact that the partial derivative of x_ik with respect to x_mn is 1 only when i=m and k=n using the Kronecker delta.
  4. Notice that δ_kn enforces k=n, therefore w_kj * δ_kn reduces to w_nj.

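Applying these steps gives:

$$\frac{\partial y_{ij}}{\partial x_{mn}} = \frac{\partial (x_{ik} w_{kj})}{\partial x_{mn}} = w_{kj} \frac{\partial x_{ik}}{\partial x_{mn}} = w_{kj} \, \delta_{im} \delta_{kn} = \delta_{im} w_{nj}$$

Then, we consider the full expression and obtain the gradient:

$$\frac{\partial \mathcal{L}}{\partial x_{mn}} = \frac{\partial \mathcal{L}}{\partial y_{ij}} \, \delta_{im} w_{nj} = \frac{\partial \mathcal{L}}{\partial y_{mj}} w_{nj}$$

We derive the last step by noticing once again that ∂L/∂y_ij * δ_im reduces to ∂L/∂y_mj.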

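However, matrix notation is conceptually closer to our Triton kernel, so we rewrite this expression as a matrix multiplication using the identity X_ij = [X^T]_ji:

$$\frac{\partial \mathcal{L}}{\partial x_{mn}} = \frac{\partial \mathcal{L}}{\partial y_{mj}} [W^T]_{jn} \quad \Longrightarrow \quad \frac{\partial \mathcal{L}}{\partial X} = \frac{\partial \mathcal{L}}{\partial Y} W^T$$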

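We follow the exact same steps to derive the gradient with respect to W:

$$\frac{\partial y_{ij}}{\partial w_{mn}} = \frac{\partial (x_{ik} w_{kj})}{\partial w_{mn}} = x_{ik} \, \delta_{km} \delta_{jn} = x_{im} \delta_{jn}$$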

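Then, the back-propagated gradient follows:

$$\frac{\partial \mathcal{L}}{\partial w_{mn}} = \frac{\partial \mathcal{L}}{\partial y_{ij}} \, x_{im} \delta_{jn} = x_{im} \frac{\partial \mathcal{L}}{\partial y_{in}} = [X^T]_{mi} \frac{\partial \mathcal{L}}{\partial y_{in}}$$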

Which is equivalent to the matrix notation:

$$\frac{\partial \mathcal{L}}{\partial W} = X^T \frac{\partial \mathcal{L}}{\partial Y}$$
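
Both matrix-form results can be checked numerically. The NumPy sketch below assumes an arbitrary upstream gradient G. It uses the scalar loss sum(G * (X @ W)), whose gradient with respect to Y is exactly G. It then compares the closed-form gradients with central finite differences. The helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, V = 3, 4, 5
X = rng.standard_normal((N, D))
W = rng.standard_normal((D, V))
G = rng.standard_normal((N, V))   # an arbitrary upstream gradient dL/dY

def loss(X, W):
    # A scalar loss whose gradient with respect to Y = X @ W is exactly G.
    return np.sum(G * (X @ W))

dX = G @ W.T   # dL/dX = dL/dY W^T
dW = X.T @ G   # dL/dW = X^T dL/dY

def numerical_grad(f, A, eps=1e-6):
    """Central finite differences, one entry of A at a time."""
    grad = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        A_plus, A_minus = A.copy(), A.copy()
        A_plus[idx] += eps
        A_minus[idx] -= eps
        grad[idx] = (f(A_plus) - f(A_minus)) / (2 * eps)
    return grad

print(np.allclose(dX, numerical_grad(lambda A: loss(A, W), X)))  # True
print(np.allclose(dW, numerical_grad(lambda A: loss(X, A), W)))  # True
```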

Cross-Entropy

In this section, we focus on cross-entropy applied to discrete probability distributions. Considering a vector of logits x_j with a label y, the cross-entropy is computed as follows:

$$\mathcal{L} = -\log \frac{e^{x_y}}{\sum_j e^{x_j}} = -x_y + \log \sum_j e^{x_j}$$

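Where x_y corresponds to the logit associated with the label.
Once again, we are interested in the partial derivative of the loss with respect to any input logit x_k. Because of the normalising factor, every logit affects the value of every softmax output. The partial derivative is therefore defined piecewise, depending on whether k is the label y:

$$\frac{\partial \mathcal{L}}{\partial x_k} = \begin{cases} \dfrac{e^{x_k}}{\sum_j e^{x_j}} - 1 = \text{softmax}(x)_k - 1 & \text{if } k = y \\[6pt] \dfrac{e^{x_k}}{\sum_j e^{x_j}} = \text{softmax}(x)_k & \text{if } k \neq y \end{cases}$$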

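Combining both cases with the Kronecker delta, we obtain the gradient:

$$\frac{\partial \mathcal{L}}{\partial x_k} = \text{softmax}(x)_k - \delta_{ky}$$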

And in matrix notation:

$$\nabla_x \mathcal{L} = \text{softmax}(x) - y_{\text{one hot}}$$

Where y_{one hot} is a vector of zeros with the entry corresponding to the label set to one. This result tells us that the gradient is simply the difference between the prediction and the ground truth.
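
As a quick sanity check, the NumPy sketch below compares softmax(x) - y_one_hot with a central finite-difference estimate of the gradient. The logit vector and label are arbitrary. Note that the gradient sums to zero, since the probabilities sum to one.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def cross_entropy(x, label):
    m = x.max()
    return m + np.log(np.exp(x - m).sum()) - x[label]

x = np.array([2.0, -1.0, 0.5, 3.0])
label = 1
one_hot = np.eye(len(x))[label]
analytic = softmax(x) - one_hot

eps = 1e-6
numeric = np.array([
    (cross_entropy(x + eps * e_k, label) - cross_entropy(x - eps * e_k, label)) / (2 * eps)
    for e_k in np.eye(len(x))
])
print(np.allclose(analytic, numeric))  # True
```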

Fused Linear Cross-Entropy

Combining the linear projection with cross-entropy in a single expression, we get:

$$\mathcal{L}(x, W) = -\, y_{\text{one hot}} \cdot \log \text{softmax}(xW)$$

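Thanks to the chain rule, deriving the gradient of this expression boils down to multiplying the gradients we computed previously:

$$\frac{\partial \mathcal{L}}{\partial x} = \big(\text{softmax}(y) - y_{\text{one hot}}\big) W^T, \qquad \frac{\partial \mathcal{L}}{\partial W} = x^T \big(\text{softmax}(y) - y_{\text{one hot}}\big), \qquad y = xW$$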

Where x and y refer to the inputs and outputs of the linear layer respectively, and W to the associated weight matrix.

Note: in a batched setting, we’ll need to reduce the W gradients over the batch dimension. Generally, we use a sum or mean reduction.
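
The batch reduction can be illustrated with a NumPy sketch of the fused gradient for a small batch. It assumes a mean reduction, so each row's contribution is divided by N. It checks one entry of dW against finite differences. It also shows that dW is a sum of per-token outer products, which is why the kernel must aggregate dW across programs.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, V = 6, 4, 10
X = rng.standard_normal((N, D))
W = rng.standard_normal((D, V))
labels = rng.integers(0, V, size=N)

def mean_loss(X, W):
    """Mean cross-entropy over the batch, using the max-shift for stability."""
    Y = X @ W
    m = Y.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(Y - m).sum(axis=1))
    return np.mean(lse - Y[np.arange(len(labels)), labels])

# P = dL/dY = (softmax(Y) - one_hot) / N for a mean reduction over the batch.
Y = X @ W
P = np.exp(Y - Y.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
P[np.arange(N), labels] -= 1.0
P /= N

dX = P @ W.T   # one row per token: no reduction across the batch
dW = X.T @ P   # sums N per-token contributions into one [D x V] gradient

# Check one entry of dW against central finite differences.
eps = 1e-6
E = np.zeros_like(W)
E[2, 3] = eps
numeric = (mean_loss(X, W + E) - mean_loss(X, W - E)) / (2 * eps)
print(np.isclose(dW[2, 3], numeric))  # True

# dW is a sum of per-token outer products x_i^T p_i.
print(np.allclose(dW, sum(np.outer(X[i], P[i]) for i in range(N))))  # True
```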

Kernel Implementation

With the theory established, we can implement the fused kernel in Triton. Since cross-entropy is usually the final layer in a language model, we can combine the forward and backward passes into a single kernel. This fusion offers two benefits: it minimises the overhead of multiple kernel launches and significantly improves data locality by keeping intermediate values on-chip.

We'll analyse the kernel step by step from the perspective of a single program instance, which, in our parallelisation strategy, handles one specific row of the input matrix.

1. Setup and Target Logit Pre-computation

The initial phase involves standard Triton setup:

  • Program Identification: We use tl.program_id to determine which row of the input matrix the current program is responsible for.
  • Parameter Initialisation: We define tiles using D_BLOCK and V_BLOCK and initialise the running maximum (m) and sum (d) required for the online softmax algorithm.
  • Pointer Arithmetic: We calculate the base memory addresses for our tensors. Pointers for X (input) and dX (gradient) are offset using the row stride so each program accesses its unique token vector. Conversely, the W (weight) pointer stays at the base address because every program must eventually iterate through the whole vocabulary space.
  • Masking and Early Exit: We define an ignore_index (defaulting to -100). If a program encounters this label (e.g. for padding tokens), it terminates early with a loss of 0 to save cycles.

2. Computing the Target Logit

Before the main loop, we must isolate the target logit x . w_y. We iterate over the hidden dimension D in D_BLOCK chunks, performing a dot product between the input row X and the specific column of W corresponding to the ground-truth label Y.

Because W is a 2D matrix, calculating the pointers for these specific column tiles requires careful stride manipulation. The illustration below helps visualise how we “jump” through memory to extract only the weights needed for the target token.

Representation of the pointer arithmetic executed to compute the target logit. Here, we consider that the label is 4, meaning that the target logit is the dot product of X with W’s fifth column. Vectors of different colors represent different steps of the iteration along D (i.e. different values of d_idx). Numbers refer to the memory address of each element, assuming a row-major layout.
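
The same pointer arithmetic can be mimicked in plain Python. The sketch below assumes a row-major W of shape [D x V], so the row stride is V and the column stride is 1. The small dimensions and the helper name are hypothetical. In Triton, comparable offsets would typically be formed from `d_idx + tl.arange(0, D_BLOCK)` times the row stride, plus `label` times the column stride.

```python
def target_column_offsets(label, D, V, d_block):
    """Flat memory offsets of W[:, label] for a row-major [D x V] matrix,
    grouped into the D_BLOCK-sized tiles visited by the loop over D."""
    stride_d, stride_v = V, 1  # row-major: moving down one row skips V elements
    tiles = []
    for d_idx in range(0, D, d_block):
        rows = range(d_idx, min(d_idx + d_block, D))  # the kernel would mask the edge
        tiles.append([r * stride_d + label * stride_v for r in rows])
    return tiles

print(target_column_offsets(label=4, D=6, V=8, d_block=2))
# [[4, 12], [20, 28], [36, 44]]
```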

Once the tiles are loaded, we cast them to float32 to ensure numerical stability and add their dot product to an accumulator variable before moving to the next iteration.

Here’s the code thus far:

Next, we execute the forward pass, which processes the vocabulary space in two nested stages:

  1. Tiled Logit Computation: We compute the logits one V_BLOCK at a time. This is achieved by iterating over the vocabulary dimension V (outer loop) and the hidden dimension D (inner loop). Within the inner loop, we load a tile of X and a block of W, accumulating their partial dot products into a high-precision register.
  2. Online Softmax Update: Once the full dot product for a logit tile is finalised, we don’t store it to VRAM. Instead, we immediately update our running statistics, the maximum value m and the running sum of exponentials d, using the online softmax formula. By doing this “on the fly”, we ensure that we only ever hold a small V_BLOCK of logits in the GPU’s registers at any given moment.

Following these iterations, the final values of m and d are used to reconstruct the LSE. The final scalar loss for the row is then computed by subtracting the target logit (x . w_y) from this LSE value.
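
Putting the two stages together, here is a NumPy emulation of what a single program computes in the forward pass. It is an illustrative sketch, not the Triton kernel itself. The function name and default block sizes are hypothetical. NumPy slicing stands in for masked tile loads, and sequential loops stand in for the program's iterations.

```python
import numpy as np

def fused_forward_row(x, W, label, D_BLOCK=4, V_BLOCK=8):
    """Cross-entropy loss for one row x (shape [D]) without building the [V] logit vector.

    Mirrors one program instance: an outer loop over vocabulary tiles, an inner
    loop over hidden-dimension tiles, and online-softmax statistics (m, d).
    """
    D, V = W.shape
    # Cast to float32 up front for brevity (the kernel casts each loaded tile instead).
    x = x.astype(np.float32)
    W = W.astype(np.float32)

    # Target logit x . w_label, accumulated tile by tile along D.
    target = np.float32(0.0)
    for d_idx in range(0, D, D_BLOCK):
        target += x[d_idx:d_idx + D_BLOCK] @ W[d_idx:d_idx + D_BLOCK, label]

    m, d = -np.inf, 0.0  # running maximum and running sum of exponentials
    for v_idx in range(0, V, V_BLOCK):
        # Accumulate one tile of logits over the hidden dimension.
        acc = np.zeros(min(V_BLOCK, V - v_idx), dtype=np.float32)
        for d_idx in range(0, D, D_BLOCK):
            acc += x[d_idx:d_idx + D_BLOCK] @ W[d_idx:d_idx + D_BLOCK, v_idx:v_idx + V_BLOCK]
        # Online softmax update: rescale the old sum, then add the new tile.
        m_new = max(m, float(acc.max()))
        d = d * np.exp(m - m_new) + np.exp(acc - m_new).sum()
        m = m_new

    lse = m + np.log(d)  # log-sum-exp of the full row
    return float(lse - target)

rng = np.random.default_rng(0)
D, V = 10, 21
x = rng.standard_normal(D)
W = rng.standard_normal((D, V))
logits = x @ W
reference = logits.max() + np.log(np.exp(logits - logits.max()).sum()) - logits[3]
print(np.isclose(fused_forward_row(x, W, label=3), reference, atol=1e-4))  # True
```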

Here’s a visual representation of the forward pass:

Visual representation of the tiled matrix multiplication with running statistics updates. At each step, we load elements colored in green or dark blue and compute the dot products of vectors highlighted in green. Elements of Y are accumulated by iterating over the D dimension; once this is done (i.e. the cells are green), we update m and d based on the freshly computed tile.

Here’s the code for the forward pass:

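We are now down to the last part of the kernel: the backward pass. Our goal is to compute the gradients with respect to X and W using the expression we derived earlier:

$$\frac{\partial \mathcal{L}}{\partial X} = P \, W^T, \qquad \frac{\partial \mathcal{L}}{\partial W} = X^T P, \qquad P = \text{softmax}(XW) - Y_{\text{one hot}}$$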

To stay memory-efficient, we once again process the vocabulary in tiles using a two-stage approach:

  1. Recomputing Normalised Probabilities (P): Because we didn’t store the full logit matrix during the forward pass, we must recompute the activations for every tile. By reusing the Log-Sum-Exp calculated in the forward pass, we can normalise these activations on-the-fly. If the target token falls within the current tile, subtracting 1 from its probability (i.e. subtracting the one-hot label) gives us a local chunk of the logit gradient, P.
  2. Gradient Accumulation: With a tile of P in hand, we calculate the partial gradients. For dX, we perform a dot product with blocks of W^T; for dW, we multiply by tiles of X^T. To safely aggregate these values across the whole batch, we use Triton’s tl.atomic_add. This operation acts as a thread-safe +=, ensuring that different programs updating the same weight gradient don't overwrite each other.
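
The per-tile backward logic can be emulated in NumPy, one row at a time. The names are hypothetical, and the loss is summed rather than averaged over the batch to keep things short. The in-place `+=` on the shared `dW` stands in for `tl.atomic_add`. Each row writes only to its own slice of `dX`.

```python
import numpy as np

def fused_backward_row(x, W, label, lse, dX_row, dW, V_BLOCK=8):
    """Add one row's contribution to dX_row (shape [D]) and dW (shape [D x V]).

    `lse` is the log-sum-exp saved for this row by the forward pass.
    """
    _, V = W.shape
    for v_idx in range(0, V, V_BLOCK):
        cols = slice(v_idx, min(v_idx + V_BLOCK, V))
        logits = x @ W[:, cols]        # recompute this tile of logits
        P = np.exp(logits - lse)       # normalised probabilities for the tile
        if v_idx <= label < v_idx + V_BLOCK:
            P[label - v_idx] -= 1.0    # subtract the one-hot target
        dX_row += W[:, cols] @ P       # P . W^T: private to this row
        dW[:, cols] += np.outer(x, P)  # X^T . P: shared by all rows (atomic_add in the kernel)

rng = np.random.default_rng(0)
N, D, V = 5, 6, 19
X = rng.standard_normal((N, D))
W = rng.standard_normal((D, V))
labels = rng.integers(0, V, size=N)
Y = X @ W
LSE = np.log(np.exp(Y).sum(axis=1))

dX = np.zeros_like(X)
dW = np.zeros_like(W)
for i in range(N):  # one iteration per GPU program
    fused_backward_row(X[i], W, labels[i], LSE[i], dX[i], dW)

# Reference: the unfused gradient of the summed loss.
P_full = np.exp(Y - LSE[:, None])
P_full[np.arange(N), labels] -= 1.0
print(np.allclose(dX, P_full @ W.T), np.allclose(dW, X.T @ P_full))  # True True
```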

Here are some additional details on the implementation:

  • The Stride Swap: When computing P . W^T, we don’t need to physically transpose the huge W matrix in memory. Instead, we invert the shapes and strides in W’s block pointer to read the rows of W as columns of W^T. This results in a “free” transpose that saves both time and VRAM.
  • Numerical Precision: It's worth noting that while X and W might be in bfloat16, the accumulation of dW and dX via atomic_add is usually performed in float32 to prevent tiny rounding errors from building up across thousands of rows.
  • Contention Note: While atomic_add is essential for dW (because every program updates the same weights), dX is private to each program, meaning there's zero contention between program IDs for that specific tensor.
  • Atomic Add Masking: atomic_add doesn’t support block pointers. Therefore, we implement the pointer and mask logic for dW explicitly.
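
The stride swap can be illustrated without a GPU. The plain-Python sketch below reads the same flat buffer twice. The first read uses the shape and strides of W. The second swaps both, yielding W^T without copying any data. `read_2d` is a hypothetical helper standing in for a block pointer, whose `shape` and `strides` arguments play the same role in `tl.make_block_ptr`.

```python
def read_2d(flat, shape, strides):
    """Read a 2D view of a flat buffer given (rows, cols) shape and strides."""
    rows, cols = shape
    return [[flat[r * strides[0] + c * strides[1]] for c in range(cols)] for r in range(rows)]

D, V = 2, 3
flat_W = [0, 1, 2,
          3, 4, 5]                       # W = [[0, 1, 2], [3, 4, 5]] stored row-major
W = read_2d(flat_W, (D, V), (V, 1))      # normal view: shape [D, V], strides (V, 1)
W_T = read_2d(flat_W, (V, D), (1, V))    # swapped shape and strides: a "free" transpose
print(W_T)  # [[0, 3], [1, 4], [2, 5]]
```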

The following figure represents the backward pass for one iteration of the outer loop (i.e. one block along V and all blocks along D):

Representation of the backward pass for a single step along the V dimension and a full iteration along the D dimension. In stage 4, we highlight how dX is accumulated across iterations (every program updates its private row once per step along V), whereas dW is accumulated across programs (N programs update the values of a single block in dW at every step along V).

Here’s the complete code for the backward pass:

This concludes the implementation of our kernel! The full code, including the kernel and benchmark script, is available here.

Memory Benchmark

Finally, we compare our kernel with the PyTorch baseline using hyperparameters inspired by Llama3, on an A100 GPU. Specifically, we consider a sequence length of S=16,384, a batch size of B=1 and an embedding dimension of D=4096; the vocabulary size is set to V=128,256.

As expected, the PyTorch baseline allocates a large intermediate tensor to store the activations, resulting in a peak memory usage of 36.02GB. In comparison, our Triton kernel reduces peak memory usage by 84%, allocating only 5.04GB with D_BLOCK=64 and V_BLOCK=64!

Using even smaller block sizes would allow for further memory gains at the cost of efficiency.

Atomic Limitations and Production Scaling

In this article, we focused on the technical and mathematical intuition behind fused Linear Cross-Entropy kernels. We used atomic operations like tl.atomic_add to keep the code minimal and readable. However, while our kernel successfully slashed memory usage by a staggering 86%, the Triton kernel is significantly slower than native PyTorch.

Unfortunately, the same atomic operations that make this kernel easier to write and understand come at the cost of a massive traffic jam, since thousands of threads try to modify the same memory address at once. Generally, tl.atomic_add performs well when contention is low. In our current implementation, we have:

  1. High Contention: For the weight gradient, every program in the batch (up to 16,384 in our test) attempts to update the same memory tiles concurrently. The hardware must serialise these updates, forcing thousands of threads to wait in line.
  2. Numerical Non-associativity: In computers, floating-point addition is non-associative. Rounding errors can accumulate differently depending on the order of operations. This is why correctness tests might pass on a T4 but fail on an A100: the latter has more streaming multiprocessors (SMs) performing more concurrent, non-deterministic additions.
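
A minimal illustration of non-associativity, using standard Python (double-precision) floats. The same three numbers summed in a different order give different results.

```python
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)  # 0.6000000000000001
print(a + (b + c))  # 0.6
```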

Note on Precision: On Ampere and newer architectures, the TF32 format can further contribute to these discrepancies. For strict numerical parity, one should set allow_tf32=False or use higher precision types during the accumulation steps.

Path to Production

To move beyond this educational implementation towards a production-ready kernel (I recommend looking at the Liger-Kernel implementation), one could implement several optimisations:

  • Replacing dX Atomics: Since each program “owns” its row of X, we can use simple register accumulation followed by a tl.store, eliminating atomics for the input gradients entirely.
  • A Dedicated dW Kernel: To optimise the computation of dW, production kernels generally use a different grid strategy where each program handles a block of W and iterates through the batch dimension, accumulating gradients locally before a single global write.
  • Micro-batching: Advanced implementations, such as those in the Liger-Kernel library, process the sequence in blocks along the N dimension, making the memory scaling constant in the sequence length rather than linear. This allows the use of much larger batch sizes at a reduced memory cost.
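
The dedicated dW grid strategy can be sketched in NumPy. The function below computes dW = X^T P so that each output tile is owned by exactly one "program", which accumulates over the batch in a local buffer and writes once. Sequential loops stand in for parallel programs, and the name and block sizes are hypothetical. For simplicity the sketch takes the logit gradient P as an input. It is an illustration of the idea above, not Liger-Kernel's actual implementation.

```python
import numpy as np

def dw_by_weight_blocks(X, P, D_BLOCK=4, V_BLOCK=8, N_CHUNK=16):
    """dW = X^T P, organised so each 'program' owns one [D_BLOCK x V_BLOCK] tile.

    Each tile is accumulated locally over the batch dimension and written once,
    so no two programs ever touch the same output tile (no atomics needed).
    """
    N, D = X.shape
    V = P.shape[1]
    dW = np.zeros((D, V), dtype=np.float32)
    for d_idx in range(0, D, D_BLOCK):          # each (d_idx, v_idx) pair would
        for v_idx in range(0, V, V_BLOCK):      # be one program in the grid
            acc = np.zeros((min(D_BLOCK, D - d_idx), min(V_BLOCK, V - v_idx)), dtype=np.float32)
            for n_idx in range(0, N, N_CHUNK):  # iterate through the batch
                x_tile = X[n_idx:n_idx + N_CHUNK, d_idx:d_idx + D_BLOCK]
                p_tile = P[n_idx:n_idx + N_CHUNK, v_idx:v_idx + V_BLOCK]
                acc += x_tile.T @ p_tile
            dW[d_idx:d_idx + D_BLOCK, v_idx:v_idx + V_BLOCK] = acc  # single write
    return dW

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10)).astype(np.float32)
P = rng.standard_normal((50, 21)).astype(np.float32)
print(np.allclose(dw_by_weight_blocks(X, P), X.T @ P, atol=1e-5))  # True
```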

Conclusion

This concludes our deep dive into fused linear cross-entropy kernels. Thanks for reading all the way through, and I hope this article gave you both the intuition and the practical understanding needed to build on these ideas and explore them further.

If you found this useful, consider sharing the article; it genuinely helps support the time and effort that goes into producing this work. And as always, feel free to contact me if you have questions, thoughts, or ideas for follow-ups.

Until next time! 👋

Sources

  1. Introducing Meta Llama 3: The most capable openly available LLM to date
  2. Liger-Kernel (lecture)
  3. Liger-Kernel Linear Cross-Entropy Implementation
  4. Unsloth Implementation (cross-entropy only)