If you’ve ever trained or fine-tuned an LLM, you’ve likely hit a wall at the very last step: the Cross-Entropy Loss.
The culprit is the logit bottleneck. To predict the next token, we project a hidden state into a large vocabulary space. For Llama 3 (128,256 tokens), the weight matrix alone holds over 525 million parameters. While that’s only ~1GB in bfloat16, the intermediate logit tensor is the true issue: for large batches, it can easily exceed 80GB of VRAM just to compute a single scalar loss.
Optimising this layer is how libraries like Unsloth and Liger-Kernel achieve such massive memory reductions. In this article, we’ll build a fused Linear + Cross-Entropy kernel from scratch in Triton. We’ll derive the maths and implement a tiled forward and backward pass that slashes peak memory usage by 84%.
Note on Performance: This implementation is primarily educational. We prioritise mathematical clarity and readable Triton code by using global atomic operations. While it solves the memory bottleneck, matching production-grade speeds would require a significantly more complex implementation that is out of scope for this article.
This post is part of my Triton series. We’ll be using concepts like tiling and online softmax that we’ve covered previously. If those sound unfamiliar, I recommend catching up there first!
The Logit Bottleneck
To get us started, let’s put some more numbers on the logit bottleneck. We consider an input matrix X with shape [NxD], a weight matrix W with shape [DxV] and a logit matrix Y=X@W with shape [NxV]. In the context of an LLM, N is the sequence length multiplied by the batch size (i.e. the total number of tokens in the batch), D the size of the hidden state and V the vocabulary size.
For a Llama 3 8B model, we might have a context window of 8192 tokens, a hidden state with 4096 dimensions and a vocabulary size of 128,256 tokens. Using a modest batch size of 8, we get N = 8192x8 = 65,536.
This results in the Y matrix having shape [NxV]=[65,536x128,256], or roughly 8.4 billion elements. In bfloat16, this would take up 16.8GB of memory. However, if we follow best practice and use float32 for the loss calculation to ensure numerical stability, the requirement doubles to 33.6GB.
To put this number in perspective, we would also need around 16GB of memory just to hold the weights of Llama 3 8B in bfloat16. On most GPUs, this leaves no space for the huge overhead of the optimiser states (e.g. Adam’s moments) and other activations, leading to the infamous PyTorch OOM error.
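A quick back-of-the-envelope check of these numbers (a standalone sketch, not taken from the original post):

```python
# Back-of-the-envelope check of the logit-tensor footprint for a Llama 3-like setup
batch, seq_len, hidden, vocab = 8, 8192, 4096, 128_256
n_tokens = batch * seq_len                                    # N = 65,536

logit_elements = n_tokens * vocab                             # ~8.4 billion
print(f"logit elements:  {logit_elements / 1e9:.1f}B")
print(f"bfloat16 logits: {logit_elements * 2 / 1e9:.1f} GB")  # ~16.8 GB
print(f"float32 logits:  {logit_elements * 4 / 1e9:.1f} GB")  # ~33.6 GB
print(f"lm_head params:  {hidden * vocab / 1e6:.0f}M")        # ~525M
```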
Generally, this problem is handled using:
- Gradient accumulation: Use a smaller batch size and accumulate gradients over multiple batches between each optimiser step, emulating a larger batch size while holding less data in memory.
- Activation checkpointing: PyTorch stores all intermediate activations for reuse in the backward pass; checkpointing clears these activations and recomputes them on-the-fly during the backward pass. This yields large memory savings but increases training time, since the number of required forward passes is doubled.
- Micro-batching the loss: Instead of computing the loss over the N dimension at once, we can slice it and accumulate the loss over smaller chunks of size n < N. Now, we only hold a slice of size [n, V] in memory at a time (see the sketch after this list).
- Mixed precision training: Using half precision during training provides a 2x memory reduction and significant speedups on Tensor Cores.
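To make the micro-batching idea concrete, here is a minimal PyTorch sketch. The helper name and chunk size are illustrative, not from the original code; calling backward per chunk means only one [n_chunk, V] logit tensor is alive at a time, with gradients accumulating in x.grad and w.grad.

```python
import torch
import torch.nn.functional as F

def chunked_linear_ce_backward(x, w, targets, n_chunk=4096):
    """Micro-batched loss: x is [N, D] and w is [D, V], both leaf tensors
    with requires_grad=True; targets is [N]. Returns the mean loss value."""
    n_tokens = x.shape[0]
    total = 0.0
    for start in range(0, n_tokens, n_chunk):
        sl = slice(start, start + n_chunk)
        logits = x[sl] @ w                           # only [n_chunk, V] is materialised
        loss = F.cross_entropy(logits.float(), targets[sl], reduction="sum") / n_tokens
        loss.backward()                              # frees this chunk's graph immediately
        total += loss.item()
    return total
```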
While these solutions seem attractive, they all have significant drawbacks: gradient accumulation and activation checkpointing slow down training, mixed precision can be unstable, and micro-batching requires (slow) PyTorch-level iteration; even though n is chosen to be smaller than N, the vocabulary size remains huge in comparison.
More importantly, these solutions don’t address the issue we’ve dealt with repeatedly throughout this series: data movement. Indeed, we’re still wasting time writing billions of logits to VRAM only to read them back milliseconds later.
The Kernel Solution
As we’ll see in a minute, the forward and backward pass of the cross-entropy loss involve dot products, matrix multiplication and a softmax. As we learned in this series, these are all operations that can be tiled efficiently. In other words, we can perform them iteratively while only holding a small piece of the inputs in memory at any time.
Moreover, cross-entropy is usually preceded by a matrix multiplication: the linear projection from the hidden state into the vocabulary space. This is a great opportunity for operator fusion: performing multiple operations inside a single kernel, leading to large speedups and potential memory gains.
In the following sections, we’ll look at how to derive and efficiently fuse the forward and backward passes into a kernel combining a linear layer with cross-entropy.

As mentioned in the last article, Triton kernels don’t natively register with PyTorch’s autograd. Therefore, we need to derive the gradients ourselves, a great occasion to brush up on some calculus 😉
The maths behind Fused Linear Cross-Entropy
Definition and Forward Pass
In this section, we derive the mathematical expression for our Fused Linear Cross-Entropy layer to see how it naturally lends itself to tiling.
For two discrete probability distributions p and q, cross-entropy is defined as:
$$H(p, q) = -\sum_{i} p_i \log q_i$$
In our context, p is the one-hot vector representing the target token, while q is the model’s distribution over the vocabulary. We obtain q by applying a softmax to the logits l, themselves the outputs of the preceding linear layer.
Since p is non-zero for a single target token y, the summation collapses. We can then substitute the numerically stable softmax (as discussed in the last article) to derive the final expression:
$$L = -\log \frac{e^{\,l_y - m}}{\sum_j e^{\,l_j - m}} = \log \sum_j e^{\,l_j - m} - (l_y - m), \qquad m = \max_j l_j$$
By substituting the logits l with the output of the linear layer x . w, we see that the forward pass boils down to three primary quantities:
- The target logit x . w_y.
- The log-sum-exp (LSE) of all dot products.
- The global maximum logit, used for numerical stability.
Thanks to the online softmax algorithm, we can compute these quantities without ever materialising the full vocabulary in memory. Instead of an O(V) memory bottleneck, we iterate over the hidden dimension D and the vocabulary V in small tiles (D_block and V_block). This transforms the calculation into an O(1) register problem.
To parallelise this effectively, we launch one GPU program per row of the input matrix. Each program independently executes the following steps (see the sketch after this list):
- Pre-compute the target logit: Perform a tiled dot product between the current row of X and the column of W associated with token Y.
- Online reduction: Iterate through the hidden and vocabulary blocks to:
1. Track the running maximum (m)
2. Update the running sum of exponentials (d) using the online softmax formula:

$$m_{\text{new}} = \max\!\left(m, \max_j l_j\right), \qquad d_{\text{new}} = d \cdot e^{\,m - m_{\text{new}}} + \sum_j e^{\,l_j - m_{\text{new}}}$$
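Before moving to Triton, here is a minimal NumPy sketch of this per-row reduction (tile size and function name are illustrative, not from the article’s code). Comparing its output against a naive log-sum-exp on small inputs is a quick sanity check.

```python
import numpy as np

def row_loss_online(x_row, W, y, V_BLOCK=64):
    """Per-row forward pass: online softmax over vocabulary tiles,
    never materialising the full logit row. x_row is [D], W is [D, V]."""
    target_logit = float(x_row @ W[:, y])        # tiled along D in the real kernel
    m, d = -np.inf, 0.0                          # running max and sum of exponentials
    V = W.shape[1]
    for v0 in range(0, V, V_BLOCK):
        logits = x_row @ W[:, v0:v0 + V_BLOCK]   # one [V_BLOCK] tile of logits
        m_new = max(m, logits.max())
        d = d * np.exp(m - m_new) + np.exp(logits - m_new).sum()
        m = m_new
    lse = m + np.log(d)                          # log-sum-exp of the full row
    return lse - target_logit                    # per-row cross-entropy loss
```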
Now that we have a better understanding of the forward pass, let’s look at the derivation of the backward pass.
Backward Pass
Notation
To derive our gradients efficiently, we’ll use Einstein notation and the Kronecker delta.
In Einstein notation, repeated indices are implicitly summed over. For example, a typical matrix multiplication Y = X@W simplifies from a verbose summation to a clean index pairing:
$$Y_{ij} = \sum_k X_{ik} W_{kj} \quad\longrightarrow\quad Y_{ij} = X_{ik} W_{kj}$$
The Kronecker delta (δ_ij) is used alongside this notation to handle identity logic. It is equal to 1 if i=j and 0 otherwise. As we’ll see, this is particularly useful for collapsing indices during differentiation.
Matrix Multiplication
In this section, we derive the back-propagated gradients for matrix multiplication. We assume the existence of an upstream gradient ℓ.
To determine how it back-propagates through matrix multiplication, we apply the chain rule to the inputs x and the weight matrix w. Here y represents the multiplication’s outputs:
$$\frac{\partial \ell}{\partial x_{mn}} = \frac{\partial \ell}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial x_{mn}}, \qquad \frac{\partial \ell}{\partial w_{mn}} = \frac{\partial \ell}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial w_{mn}}$$
We start by deriving the partial derivatives of y with respect to x, following these steps:
- Express y in terms of x and w.
- Notice that w is a constant with respect to the derivative of x, so we can pull it out of the derivative.
- Express the fact that the partial derivative of x_ik with respect to x_mn is 1 only when i=m and k=n using the Kronecker delta.
- Notice that δ_kn enforces k=n, therefore w_kj * δ_kn reduces to w_nj.
$$\frac{\partial y_{ij}}{\partial x_{mn}} = \frac{\partial (x_{ik} w_{kj})}{\partial x_{mn}} = w_{kj} \frac{\partial x_{ik}}{\partial x_{mn}} = w_{kj}\, \delta_{im} \delta_{kn} = \delta_{im}\, w_{nj}$$
Then, we consider the full expression and obtain the gradient. We derive the last step by noticing once more that ∂ℓ/∂y_ij * δ_im reduces to ∂ℓ/∂y_mj.
$$\frac{\partial \ell}{\partial x_{mn}} = \frac{\partial \ell}{\partial y_{ij}}\, \delta_{im}\, w_{nj} = \frac{\partial \ell}{\partial y_{mj}}\, w_{nj}$$
However, matrix notation is conceptually closer to our Triton kernel; therefore, we rewrite this expression as a matrix multiplication using the identity X_ij = [X^T]_ji:
$$\frac{\partial \ell}{\partial X} = \frac{\partial \ell}{\partial Y}\, W^{T}$$
We follow the exact same steps to derive the gradient with respect to W:
$$\frac{\partial y_{ij}}{\partial w_{mn}} = x_{ik} \frac{\partial w_{kj}}{\partial w_{mn}} = x_{ik}\, \delta_{km} \delta_{jn} = x_{im}\, \delta_{jn}$$
Then, the back-propagated gradient follows:
$$\frac{\partial \ell}{\partial w_{mn}} = \frac{\partial \ell}{\partial y_{ij}}\, x_{im}\, \delta_{jn} = x_{im} \frac{\partial \ell}{\partial y_{in}}$$
Which is equivalent to the matrix notation:
$$\frac{\partial \ell}{\partial W} = X^{T} \frac{\partial \ell}{\partial Y}$$
Cross-Entropy
In this section, we’ll focus on cross-entropy applied to discrete probability distributions. Considering a vector of logits x indexed by j and a label y, the cross-entropy is computed as follows:
$$L = -\log \frac{e^{x_y}}{\sum_j e^{x_j}} = \log \sum_j e^{x_j} - x_y$$
Where x_y corresponds to the logit associated with the label.
Once again, we are interested in the partial derivative of the loss with respect to any input logit x_k. Because of the normalising factor, every logit affects the loss; the partial derivative is therefore defined piecewise, depending on whether the index k corresponds to the label:
$$\frac{\partial L}{\partial x_k} = \begin{cases} \dfrac{e^{x_k}}{\sum_j e^{x_j}} - 1 & \text{if } k = y \\[6pt] \dfrac{e^{x_k}}{\sum_j e^{x_j}} & \text{otherwise} \end{cases}$$
Combining both cases, we obtain the gradient:
$$\frac{\partial L}{\partial x_k} = \frac{e^{x_k}}{\sum_j e^{x_j}} - \delta_{ky} = \mathrm{softmax}(x)_k - \delta_{ky}$$
And in matrix notation:
$$\frac{\partial L}{\partial x} = \mathrm{softmax}(x) - y_{\text{one hot}}$$
Where y_{one hot} is a vector of zeros with the entry corresponding to the label set to one. This result tells us that the gradient is simply the difference between the prediction and the ground truth.
Fused Linear Cross-Entropy
Combining the linear projection with cross-entropy in a single expression, we get:
$$L = -\log \frac{e^{(x \cdot w)_y}}{\sum_j e^{(x \cdot w)_j}}$$
Thanks to the chain rule, deriving the gradient of this expression boils down to multiplying the gradients we computed previously:
$$\frac{\partial L}{\partial x} = \big(\mathrm{softmax}(x \cdot w) - y_{\text{one hot}}\big)\, w^{T}, \qquad \frac{\partial L}{\partial w} = x^{T} \big(\mathrm{softmax}(x \cdot w) - y_{\text{one hot}}\big)$$
Where x and y refer to the inputs and outputs of the linear layer respectively, and w to the associated weight matrix.
Note: in a batched setting, we’ll need to reduce the W gradients over the batch dimension. Generally, we use a sum or mean reduction.
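As a quick sanity check of these formulas, here is a small PyTorch sketch (shapes and names are illustrative) comparing the manual gradients against autograd:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D, V = 4, 8, 16
x = torch.randn(N, D, requires_grad=True)
w = torch.randn(D, V, requires_grad=True)
y = torch.randint(V, (N,))

# Reference gradients via autograd (sum reduction keeps the per-row formula exact)
F.cross_entropy(x @ w, y, reduction="sum").backward()

# Manual gradients: P = softmax(xw) - one_hot(y), then dX = P @ W^T and dW = X^T @ P
p = torch.softmax(x.detach() @ w.detach(), dim=-1)
p[torch.arange(N), y] -= 1.0
print(torch.allclose(x.grad, p @ w.detach().T, atol=1e-5))
print(torch.allclose(w.grad, x.detach().T @ p, atol=1e-5))
```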
Kernel Implementation
With the theory established, we can implement the fused kernel in Triton. Since cross-entropy is usually the final layer in a language model, we can combine the forward and backward passes into a single kernel. This fusion offers two benefits: it minimises the overhead of multiple kernel launches and significantly improves data locality by keeping intermediate values on-chip.
We’ll analyse the kernel step by step from the perspective of a single program instance which, in our parallelisation strategy, handles one specific row of the input matrix.
1. Setup and Target Logit Pre-computation
The initial phase involves standard Triton setup:
- Program Identification: We use tl.program_id to determine which row of the input matrix the current program is responsible for.
- Parameter Initialisation: We define tiles using D_BLOCK and V_BLOCK and initialise the running maximum (m) and sum (d) required for the online softmax algorithm.
- Pointer Arithmetic: We calculate the base memory addresses for our tensors. Pointers for X (input) and dX (gradient) are offset using the row stride so each program accesses its unique token vector. Conversely, the W (weight) pointer stays at the base address because every program must eventually iterate through the whole vocabulary space.
- Masking and Early Exit: We define an ignore_index (defaulting to -100). If a program encounters this label (e.g. for padding tokens), it terminates early with a loss of 0 to save cycles.
2. Computing the Target Logit
Before the main loop, we must isolate the target logit x . w_y. We iterate over the hidden dimension D in D_BLOCK chunks, performing a dot product between the input row X and the specific column of W corresponding to the ground-truth label Y.
Because W is a 2D matrix, calculating the pointers for these specific column tiles requires precise stride manipulation. The illustration below helps visualise how we “jump” through memory to extract only the necessary weights for the target token.

Once the tiles are loaded, we cast them to float32 to ensure numerical stability and add their dot product to an accumulator variable before moving to the next iteration.
Here’s the code so far:
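Here is a minimal sketch of what this stage could look like in Triton, rather than the article’s actual listing. The kernel signature, pointer and stride names, and the assumption that W is stored row-major as [D, V] are all mine:

```python
import triton
import triton.language as tl

@triton.jit
def fused_lce_kernel(X_ptr, W_ptr, Y_ptr, loss_ptr, dX_ptr, dW_ptr,
                     stride_x,                 # row stride of X (and dX)
                     stride_w,                 # row stride of W, assumed [D, V] row-major
                     D, V, ignore_index,
                     D_BLOCK: tl.constexpr, V_BLOCK: tl.constexpr):
    row = tl.program_id(0)                     # one program per token / input row
    y = tl.load(Y_ptr + row)                   # ground-truth token id for this row

    # Early exit for padding tokens
    if y == ignore_index:
        tl.store(loss_ptr + row, 0.0)
        return

    # Tiled dot product between row `row` of X and column `y` of W -> target logit
    target_logit = 0.0
    for d0 in range(0, D, D_BLOCK):
        d_offs = d0 + tl.arange(0, D_BLOCK)
        d_mask = d_offs < D
        x_tile = tl.load(X_ptr + row * stride_x + d_offs, mask=d_mask, other=0.0).to(tl.float32)
        w_tile = tl.load(W_ptr + d_offs * stride_w + y, mask=d_mask, other=0.0).to(tl.float32)
        target_logit += tl.sum(x_tile * w_tile, axis=0)
    # ... the forward-pass loop (next snippet) continues here
```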
Next, we execute the forward pass, which processes the vocabulary space in two nested stages:
- Tiled Logit Computation: We compute the logits for a
V_BLOCKat a time. That is achieved by iterating over vocabulary dimensionV(outer loop) and the hidden dimensionD(inner loop). Inside the inner loop, we load a tile ofXand a block ofW, accumulating their partial dot products right into a high-precision register. - Online Softmax Update: Once the complete dot product for a logit tile is finalised, we don’t store it to VRAM. As an alternative, we immediately update our running statistics: the utmost value
mand the running sum of exponentialsdusing the web softmax formula. By doing this “on the fly”, we be certain that we only ever hold a smallV_BLOCKof logits within the GPU’s registers at any given moment.
Following these iterations, the final values of m and d are used to reconstruct the LSE. The final scalar loss for the row is then computed by subtracting the target logit (x . w_y) from this LSE value.
Here’s a visual representation of the forward pass:

Here’s the code for the forward pass:
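In place of the original listing, here is a hedged sketch of the forward loop, continuing the kernel from the previous snippet (same assumed names and layout):

```python
    # Forward pass: online softmax over vocabulary tiles (continues the kernel above)
    m = -float("inf")                                        # running maximum
    d = 0.0                                                  # running sum of exponentials
    for v0 in range(0, V, V_BLOCK):
        v_offs = v0 + tl.arange(0, V_BLOCK)
        v_mask = v_offs < V
        # Inner loop over D: accumulate this tile's logits in float32
        logits = tl.zeros((V_BLOCK,), dtype=tl.float32)
        for d0 in range(0, D, D_BLOCK):
            d_offs = d0 + tl.arange(0, D_BLOCK)
            d_mask = d_offs < D
            x_tile = tl.load(X_ptr + row * stride_x + d_offs, mask=d_mask, other=0.0).to(tl.float32)
            w_tile = tl.load(W_ptr + d_offs[:, None] * stride_w + v_offs[None, :],
                             mask=d_mask[:, None] & v_mask[None, :], other=0.0).to(tl.float32)
            logits += tl.sum(x_tile[:, None] * w_tile, axis=0)
        logits = tl.where(v_mask, logits, -float("inf"))     # ignore out-of-range columns
        # Online softmax update: the logits are never written to VRAM
        m_new = tl.maximum(m, tl.max(logits, axis=0))
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(logits - m_new), axis=0)
        m = m_new

    lse = m + tl.log(d)                                      # log-sum-exp of the full row
    tl.store(loss_ptr + row, lse - target_logit)             # per-row cross-entropy loss
```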
We are now down to the last part of the kernel: the backward pass. Our goal is to compute the gradients with respect to X and W using the expressions we derived earlier:
$$\frac{\partial L}{\partial X} = P\, W^{T}, \qquad \frac{\partial L}{\partial W} = X^{T} P, \qquad \text{where } P = \mathrm{softmax}(XW) - Y_{\text{one hot}}$$
To remain memory-efficient, we once again process the vocabulary in tiles using a two-stage approach:
1. Recomputing Normalised Probabilities (P): Because we didn’t store the full logit matrix during the forward pass, we must recompute the activations for each tile. By reusing the log-sum-exp calculated in the forward pass, we can normalise these activations on-the-fly. Subtracting 1 at the position of the ground-truth label Y (when it falls inside this tile) gives us a local chunk of the logit gradient, P.
2. Gradient Accumulation: With a tile of P in hand, we calculate the partial gradients. For dX, we perform a dot product with blocks of W^T; for dW, we multiply by tiles of X^T. To safely aggregate these values across the whole batch, we use Triton’s tl.atomic_add. This operation acts as a thread-safe +=, ensuring that different programs updating the same weight gradient don’t overwrite each other.
Here are some additional details on the implementation:
- The Stride Swap: When computing P . W^T, we don’t actually need to physically transpose the huge W matrix in memory. Instead, we invert the shapes and strides in W’s block pointer to read the rows of W as columns of W^T. This results in a “free” transpose that saves both time and VRAM.
- Numerical Precision: It’s worth noting that while X and W might be in bfloat16, the accumulation of dW and dX via atomic_add is usually performed in float32 to prevent the build-up of tiny rounding errors across thousands of rows.
- Contention Note: While atomic_add is essential for dW (because every program updates the same weights), dX is private to each program, meaning there is zero contention between program IDs for that particular tensor.
- Atomic Add Masking: atomic_add doesn’t support block pointers. Therefore, we implement the pointer and mask logic for dW explicitly.
The following figure shows the backward pass for one iteration of the outer loop (i.e. one block along V and all blocks along D):

Here’s the complete code for the backward pass:
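Again, instead of the article’s listing, here is a sketch of how the backward stage could look, continuing the same kernel. Both dX and dW are accumulated with tl.atomic_add, as described above; names and layout remain assumptions:

```python
    # Backward pass: P = softmax(logits) - one_hot(y), accumulated tile by tile.
    # dX += P @ W^T and dW += X^T @ P (dW is assumed to share W's [D, V] layout).
    for v0 in range(0, V, V_BLOCK):
        v_offs = v0 + tl.arange(0, V_BLOCK)
        v_mask = v_offs < V
        # Recompute this tile's logits (same nested loop as in the forward pass)
        logits = tl.zeros((V_BLOCK,), dtype=tl.float32)
        for d0 in range(0, D, D_BLOCK):
            d_offs = d0 + tl.arange(0, D_BLOCK)
            d_mask = d_offs < D
            x_tile = tl.load(X_ptr + row * stride_x + d_offs, mask=d_mask, other=0.0).to(tl.float32)
            w_tile = tl.load(W_ptr + d_offs[:, None] * stride_w + v_offs[None, :],
                             mask=d_mask[:, None] & v_mask[None, :], other=0.0).to(tl.float32)
            logits += tl.sum(x_tile[:, None] * w_tile, axis=0)
        # Normalise with the saved LSE and subtract 1 at the target position
        p = tl.exp(logits - lse)
        p = tl.where(v_offs == y, p - 1.0, p)
        p = tl.where(v_mask, p, 0.0)
        # Accumulate the partial gradients for this vocabulary tile
        for d0 in range(0, D, D_BLOCK):
            d_offs = d0 + tl.arange(0, D_BLOCK)
            d_mask = d_offs < D
            x_tile = tl.load(X_ptr + row * stride_x + d_offs, mask=d_mask, other=0.0).to(tl.float32)
            w_tile = tl.load(W_ptr + d_offs[:, None] * stride_w + v_offs[None, :],
                             mask=d_mask[:, None] & v_mask[None, :], other=0.0).to(tl.float32)
            dx_tile = tl.sum(w_tile * p[None, :], axis=1)        # row slice of P @ W^T
            tl.atomic_add(dX_ptr + row * stride_x + d_offs, dx_tile, mask=d_mask)
            dw_tile = x_tile[:, None] * p[None, :]               # tile of X^T @ P
            tl.atomic_add(dW_ptr + d_offs[:, None] * stride_w + v_offs[None, :],
                          dw_tile, mask=d_mask[:, None] & v_mask[None, :])
```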
This concludes the implementation of our kernel! The full code, including the kernel and benchmark script, is available here.
Memory Benchmark
Finally, we compare our kernel with the PyTorch baseline using hyperparameters inspired by Llama 3 on an A100 GPU. Specifically, we consider a sequence length of S=16,384, a batch size of B=1 and an embedding dimension of D=4096; the vocabulary size is set to V=128,256.
As expected, the PyTorch baseline allocates a massive intermediate tensor to store the activations, resulting in a peak memory usage of 36.02GB. In comparison, our Triton kernel reduces peak memory usage by 84%, allocating only 5.04GB with D_BLOCK=64 and V_BLOCK=64!
Using even smaller block sizes would allow for further memory gains at the cost of efficiency.
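For reference, peak memory figures like these can be measured with PyTorch’s CUDA allocator statistics. A minimal sketch of the baseline measurement (shapes from above, variable names illustrative; it needs a GPU with enough memory to hold the peak):

```python
import torch
import torch.nn.functional as F

S, D, V = 16_384, 4096, 128_256
x = torch.randn(S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(D, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = torch.randint(V, (S,), device="cuda")

torch.cuda.reset_peak_memory_stats()
loss = F.cross_entropy((x @ w).float(), y)     # baseline materialises the [S, V] logits
loss.backward()
torch.cuda.synchronize()
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```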

Atomic Limitations and Production Scaling
In this article, we focused on the technical and mathematical intuition behind fused Linear Cross-Entropy kernels. We used atomic operations like tl.atomic_add to keep the code minimal and readable. However, while our kernel successfully slashed memory usage by 84%, it is significantly slower than native PyTorch.
Unfortunately, the same atomic operations that make this kernel easier to write and understand come at the cost of a massive traffic jam, since thousands of threads try to modify the same memory address at once. Generally, tl.atomic_add is only performant when contention is low. In our current implementation, we have:
- High Contention: For the weight gradient, each program in the batch (up to 16,384 in our test) is attempting to update the same memory tiles concurrently. The hardware must serialise these updates, forcing thousands of threads to wait in line.
- Numerical Non-associativity: Floating-point addition is non-associative. Rounding errors can accumulate differently depending on the order of operations, which is why correctness tests might pass on a T4 but fail on an A100: the latter has more streaming multiprocessors (SMs) performing more concurrent, non-deterministic additions.
Note on Precision: On Ampere and newer architectures, the TF32 format can further contribute to these discrepancies. For strict numerical parity, one should set allow_tf32=False or use higher-precision types during the accumulation steps.
Path to Production
To move beyond this educational implementation and toward a production-ready kernel (I recommend looking at the Liger-Kernel implementation), one could apply several optimisations:
- Replacing dX Atomics: Since each program “owns” its row of X, we can use simple register accumulation followed by a tl.store, eliminating atomics for the input gradients entirely.
- A dedicated dW Kernel: To optimise the computation of dW, production kernels generally use a different grid strategy where each program handles a block of W and iterates over the batch dimension, accumulating gradients locally before a single global write.
- Micro-batching: Advanced implementations, such as those in the Liger-Kernel library, process the sequence in blocks along the N dimension, making memory scaling constant in the sequence length rather than linear. This enables much larger batch sizes at a reduced memory cost.
Conclusion
This concludes our deep dive into fused linear cross-entropy kernels. Thanks for reading all the way through, and I hope this article gave you both the intuition and the practical understanding needed to build on these ideas and explore them further.
If you found this helpful, consider sharing the article; it genuinely helps support the time and effort that goes into producing this work. And as always, feel free to contact me if you have questions, thoughts, or ideas for follow-ups.
Until next time! 👋
