If you’ve ever trained or fine-tuned an LLM, you’ve likely hit a wall at the very last step: the Cross-Entropy Loss.
The culprit is the logit bottleneck. To predict the next token, we project a hidden state into a massive vocabulary space. For Llama 3 (128,256 tokens), the weight matrix alone is over 525 million parameters. While that’s only ~1GB in bfloat16, the intermediate logit tensor is the real issue. For large batches, it can easily exceed 80GB of VRAM just to compute a single scalar loss.
Optimising this layer is how libraries like Unsloth and Liger-Kernel achieve such massive memory reductions. In this article, we’ll build a fused Linear + Cross Entropy kernel from scratch in Triton. We will derive the math and implement a tiled forward and backward pass that slashes peak memory usage by 84%.
Note on Performance: This implementation is primarily educational. We prioritise mathematical clarity and readable Triton code by using global atomic operations. While it solves the memory bottleneck, matching production-grade speeds would require significantly more complex implementations which are out of scope for this article.
This post is part of my Triton series. We’ll be using concepts like tiling and online softmax that we’ve covered previously. If those sound unfamiliar, I recommend catching up there first!
The Logit Bottleneck
To get us started, let’s put some more numbers on the logit bottleneck. We consider an input matrix X with shape [NxD], a weight matrix W with shape [DxV] and a logit matrix Y=X@W with shape [NxV]. In the context of an LLM, N would be the sequence length multiplied by the batch size (i.e. the total number of tokens in the batch), D the size of the hidden state and V the vocabulary size.
For a Llama3 8B model, we would have a context window of 8192 tokens, a hidden state with 4096 dimensions and a vocabulary size of 128,256 tokens. Using a modest batch size of 8, we get N = 8192x8 = 65,536.
This results in the Y matrix having shape [NxV]=[65,536x128,256], or roughly 8.4 billion elements. In bfloat16, this would take up 16.8GB of memory. However, if we follow best practices and use float32 for the loss calculation to ensure numerical stability, the requirements double to 33.6GB.
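These figures are easy to verify with a few lines of Python (sizes in decimal gigabytes):

```python
# Memory footprint of the [N, V] logit matrix for Llama3-style hyperparameters.
N = 8192 * 8        # tokens per batch (sequence length x batch size)
V = 128_256         # vocabulary size

elements = N * V                  # entries in the logit matrix Y = X @ W
gb_bf16 = elements * 2 / 1e9      # bfloat16: 2 bytes per element
gb_fp32 = elements * 4 / 1e9      # float32: 4 bytes per element

print(f"{elements / 1e9:.1f}B elements: {gb_bf16:.1f}GB in bf16, {gb_fp32:.1f}GB in fp32")
```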
To put this number in perspective, the weights of Llama3 8B themselves take around 16GB in bfloat16. On most GPUs, this leaves no space for the massive overhead of the optimiser states (e.g. Adam’s moments) and other activations, resulting in the infamous PyTorch OOM error.
Generally, this problem is dealt with by using:
- Gradient accumulation: Use a smaller batch size and accumulate gradients over multiple batches between each optimiser step, emulating a larger batch size while holding less data in memory.
- Activation checkpointing: PyTorch stores all intermediate activations for reuse in the backward pass; checkpointing clears these activations and recomputes them on-the-fly during the backward pass. This leads to large memory savings but increases training time, since the number of required forward passes is doubled.
- Micro-batching the loss: Instead of computing the loss over the `N` dimension at once, we can slice it and accumulate the loss over smaller chunks of size `n < N`. Now, we only hold a slice of size `[n, V]` in memory at a time.
- Mixed precision training: Using half precision during training provides 2x memory reduction and significant speedups on Tensor Cores.
While these solutions seem attractive, they all have significant drawbacks: gradient accumulation and activation checkpointing slow down training, mixed precision can be unstable, and micro-batching requires (slow) PyTorch-level iteration; even with n chosen smaller than N, the vocabulary size remains huge in comparison.
More importantly, these solutions do not address the problem we have dealt with repeatedly throughout this series: data movement. Indeed, we are still wasting time by writing billions of logits to VRAM only to read them back milliseconds later.
The Kernel Solution
As we’ll see in a minute, the forward and backward pass of the cross-entropy loss involve dot products, matrix multiplication and a softmax. As we learned in this series, these are all operations that can be tiled efficiently. In other words, we can perform them iteratively while only holding a small piece of the inputs in memory at any time.
Furthermore, cross-entropy is generally preceded by a matrix multiplication: the linear projection from the hidden state into the vocabulary space. This is a great opportunity for operator fusion: fusing multiple operations within a single kernel, resulting in large speedups and potential memory gains.
In the following sections, we’ll take a look at how to derive and efficiently fuse the forward and backward passes through a kernel combining a linear layer with cross-entropy.

As mentioned in the last article, Triton kernels do not natively register in PyTorch’s autograd. Therefore we need to derive the gradient ourselves, a wonderful occasion to brush up on some calculus 😉
The math behind Fused Linear Cross-Entropy
Definition and Forward Pass
In this section, we derive the mathematical expression for our Fused Linear Cross-Entropy layer to see how it naturally lends itself to tiling.
For two discrete probability distributions p and q, cross-entropy is defined as:
$$
H(p, q) = -\sum_i p_i \log q_i
$$
In our context, p is the one-hot vector representing the target token, while q is the model’s distribution over the vocabulary. We obtain q by applying a softmax to the logits l, themselves the outputs of the preceding linear layer.
Since p is non-zero only at the target token y, the summation collapses. We can then substitute the numerically stable softmax (as discussed in the last article) to derive the final expression:
$$
\mathcal{L} = -\log \frac{e^{l_y - m}}{\sum_j e^{l_j - m}} = \log \sum_j e^{l_j - m} + m - l_y, \qquad m = \max_j l_j
$$
By substituting the logits l with the linear layer x . w, we see that the forward pass boils down to three primary quantities:
- The target logit `x . w_y`.
- The log-sum-exp (LSE) of all dot products.
- The global maximum logit used for numerical stability.
Thanks to the online softmax algorithm, we can compute these quantities without ever materialising the full vocabulary in memory. Instead of an O(V) memory bottleneck, we iterate over the hidden dimension D and the vocabulary V in small tiles (D_block and V_block). This transforms the calculation into an O(1) register problem.
To parallelise this effectively, we launch one GPU program per row of the input matrix. Each program independently executes the following steps:
- Pre-compute the target logit: Perform a tiled dot product between the current row of `X` and the column of `W` associated with token `Y`.
- Online reduction: Iterate through the hidden and vocabulary blocks to:
  1. Track the running maximum (`m`)
  2. Update the running sum of exponentials (`d`) using the online softmax formula:
$$
m_{\text{new}} = \max\left(m, \max_j l_j\right), \qquad d_{\text{new}} = d \, e^{m - m_{\text{new}}} + \sum_j e^{l_j - m_{\text{new}}}
$$

where the maximum and sum run over the logits l_j of the current tile.
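The online reduction described above can be sketched in a few lines of NumPy (the block size and function name are illustrative, not part of the kernel):

```python
import numpy as np

def online_lse(logits, block=4):
    """Streaming log-sum-exp: process `logits` in tiles, keeping only a
    running maximum m and a running sum of exponentials d in memory."""
    m, d = -np.inf, 0.0
    for start in range(0, len(logits), block):
        tile = logits[start:start + block]
        m_new = max(m, tile.max())
        # Rescale the old sum to the new maximum, then add the tile's terms.
        d = d * np.exp(m - m_new) + np.exp(tile - m_new).sum()
        m = m_new
    return m + np.log(d)

x = np.array([1.0, -2.0, 3.5, 0.1, 2.2, -1.0, 4.0])
assert np.isclose(online_lse(x), np.log(np.exp(x).sum()))
```

At no point does the function hold more than one tile of logits, which is exactly the property the kernel exploits to avoid materialising the [N, V] matrix.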
Now that we have a better understanding of the forward pass, let’s take a look at the derivation of the backward pass.
Backward Pass
Notation
To derive our gradients efficiently, we’ll use Einstein notation and the Kronecker delta.
In Einstein notation, repeated indices are implicitly summed over. For example, a standard matrix multiplication Y = X@W simplifies from a verbose summation to a clean index pairing:
$$
Y_{ij} = \sum_k X_{ik} W_{kj} \quad \longrightarrow \quad Y_{ij} = X_{ik} W_{kj}
$$
The Kronecker delta (δ_ij) is used alongside this notation to handle identity logic. It is equal to 1 if i=j and 0 otherwise. As we’ll see, this is particularly useful for collapsing indices during differentiation.
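To make the notation concrete, here is a small NumPy check that the index pairing Y_ij = X_ik W_kj is ordinary matrix multiplication, and that contracting against the Kronecker delta (an identity matrix) collapses an index:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))
W = rng.standard_normal((4, 5))

# Repeated index k is implicitly summed over: Y_ij = X_ik W_kj.
Y = np.einsum("ik,kj->ij", X, W)
assert np.allclose(Y, X @ W)

# The Kronecker delta as an identity matrix: delta_kn collapses w_kj to w_nj.
delta = np.eye(4)
assert np.allclose(np.einsum("kj,kn->nj", W, delta), W)
```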
Matrix Multiplication
In this section, we derive the back-propagated gradients for matrix multiplication. We assume the existence of a scalar loss ℓ whose upstream gradient with respect to the multiplication’s output is known.
To determine how it back-propagates through matrix multiplication, we apply the chain rule to the inputs x and the weight matrix w. Here y represents the multiplication’s outputs:
$$
\frac{\partial \ell}{\partial x_{mn}} = \frac{\partial \ell}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial x_{mn}}, \qquad \frac{\partial \ell}{\partial w_{mn}} = \frac{\partial \ell}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial w_{mn}}
$$
We start by deriving the partial derivatives of y with respect to x, following these steps:
- Express `y` in terms of `x` and `w`.
- Notice that `w` is a constant with respect to the derivative of `x`, so we can pull it out of the derivative.
- Express the fact that the partial derivative of `x_ik` with respect to `x_mn` is 1 only when `i=m` and `k=n` using the Kronecker delta.
- Notice that `ẟ_kn` enforces `k=n`, therefore `w_kj * ẟ_kn` reduces to `w_nj`.
$$
\frac{\partial y_{ij}}{\partial x_{mn}} = \frac{\partial}{\partial x_{mn}} \left( x_{ik} w_{kj} \right) = w_{kj} \frac{\partial x_{ik}}{\partial x_{mn}} = w_{kj} \, \delta_{im} \delta_{kn} = \delta_{im} \, w_{nj}
$$
Then, we consider the full expression and obtain the gradient. We derive the last step by noticing once again that ∂ℓ/∂y_ij · ẟ_im reduces to ∂ℓ/∂y_mj.
$$
\frac{\partial \ell}{\partial x_{mn}} = \frac{\partial \ell}{\partial y_{ij}} \, \delta_{im} \, w_{nj} = \frac{\partial \ell}{\partial y_{mj}} \, w_{nj}
$$
However, matrix notation is conceptually closer to our Triton kernel, therefore, we rewrite this expression as a matrix multiplication by using the identity X_ij = [X^T]_ji:
$$
\frac{\partial \ell}{\partial X} = \frac{\partial \ell}{\partial Y} \, W^T
$$
We follow the exact same steps to derive the gradient with respect to W:
$$
\frac{\partial y_{ij}}{\partial w_{mn}} = \frac{\partial}{\partial w_{mn}} \left( x_{ik} w_{kj} \right) = x_{ik} \, \delta_{km} \delta_{jn} = x_{im} \, \delta_{jn}
$$
Then, the back-propagated gradient follows:
$$
\frac{\partial \ell}{\partial w_{mn}} = \frac{\partial \ell}{\partial y_{ij}} \, x_{im} \, \delta_{jn} = x_{im} \, \frac{\partial \ell}{\partial y_{in}} = [X^T]_{mi} \, \frac{\partial \ell}{\partial y_{in}}
$$
Which is equivalent to the matrix notation:
$$
\frac{\partial \ell}{\partial W} = X^T \, \frac{\partial \ell}{\partial Y}
$$
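Both matrix-notation results are easy to sanity-check numerically. Picking an arbitrary upstream gradient G = ∂ℓ/∂Y, the formulas predict ∂ℓ/∂X = G W^T and ∂ℓ/∂W = X^T G; a NumPy sketch with a central-difference check on one entry:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((3, 4))
W = rng.standard_normal((4, 5))
G = rng.standard_normal((3, 5))              # upstream gradient dL/dY

loss = lambda X_, W_: ((X_ @ W_) * G).sum()  # scalar loss whose dL/dY is exactly G

dX = G @ W.T    # derived gradient w.r.t. X
dW = X.T @ G    # derived gradient w.r.t. W

# Central-difference check on one entry of X.
eps = 1e-6
Xp, Xm = X.copy(), X.copy()
Xp[1, 2] += eps
Xm[1, 2] -= eps
numeric = (loss(Xp, W) - loss(Xm, W)) / (2 * eps)
assert np.isclose(dX[1, 2], numeric, atol=1e-5)
```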
Cross-Entropy
In this section, we’ll focus on cross-entropy applied to discrete probability distributions. Considering a vector of logits x with a label y, the cross-entropy is computed as follows:
$$
\mathcal{L} = -\log \left( \frac{e^{x_y}}{\sum_j e^{x_j}} \right) = \log \sum_j e^{x_j} - x_y
$$
Where x_y corresponds to the logit associated with the label.
Once again, we are interested in the partial derivative of the loss with respect to any input logit x_k. Because of the normalising factor, every logit affects the loss, and the derivative takes a piecewise form depending on whether k equals the label y:
$$
\frac{\partial \mathcal{L}}{\partial x_k} =
\begin{cases}
p_k - 1 & \text{if } k = y \\
p_k & \text{if } k \neq y
\end{cases}
\qquad \text{where } p_k = \frac{e^{x_k}}{\sum_j e^{x_j}}
$$
Combining both cases, we obtain the gradient:
$$
\frac{\partial \mathcal{L}}{\partial x_k} = p_k - \delta_{ky}
$$
And in matrix notation:
$$
\nabla_x \mathcal{L} = \operatorname{softmax}(x) - y_{\text{one hot}}
$$
Where y_{one hot} is a vector of zeros with the entry corresponding to the label set to one. This result tells us that the gradient is simply the difference between the prediction and the ground truth.
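This “prediction minus target” result can be verified with finite differences (a NumPy sketch; the helper name is illustrative):

```python
import numpy as np

def ce_loss(logits, y):
    """Numerically stable cross-entropy: LSE minus the target logit."""
    m = logits.max()
    return np.log(np.exp(logits - m).sum()) + m - logits[y]

rng = np.random.default_rng(2)
logits, y = rng.standard_normal(7), 3

softmax = np.exp(logits - logits.max())
softmax /= softmax.sum()
one_hot = np.eye(7)[y]
grad = softmax - one_hot          # the derived gradient: prediction minus target

# Central-difference check on every coordinate.
eps = 1e-6
for k in range(7):
    lp, lm = logits.copy(), logits.copy()
    lp[k] += eps
    lm[k] -= eps
    numeric = (ce_loss(lp, y) - ce_loss(lm, y)) / (2 * eps)
    assert np.isclose(grad[k], numeric, atol=1e-5)
```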
Fused Linear Cross-Entropy
Combining the linear projection with cross-entropy in a single expression, we get:
$$
\mathcal{L} = \log \sum_j e^{x \cdot w_j} - x \cdot w_y
$$
Thanks to the chain rule, deriving the gradient of this expression boils down to multiplying the gradients we computed previously:
$$
\frac{\partial \mathcal{L}}{\partial x} = \left(\operatorname{softmax}(x \cdot w) - y_{\text{one hot}}\right) w^T, \qquad \frac{\partial \mathcal{L}}{\partial w} = x^T \left(\operatorname{softmax}(x \cdot w) - y_{\text{one hot}}\right)
$$
Where x and y refer to the inputs and outputs to the linear layer respectively and w to the associated weight matrix.
Note: in a batched setting, we’ll need to reduce the `W` gradients over the batch dimension. Generally, we use a sum or mean reduction.
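Putting the chain rule together, a NumPy check of the fused gradients for a single token (variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(3)
D, V = 4, 6
x = rng.standard_normal(D)          # one token's hidden state
W = rng.standard_normal((D, V))
y = 2                               # target token

def fused_loss(x_, W_):
    """Linear projection followed by stable cross-entropy."""
    l = x_ @ W_
    m = l.max()
    return np.log(np.exp(l - m).sum()) + m - l[y]

l = x @ W
p = np.exp(l - l.max())
p /= p.sum()
delta = p - np.eye(V)[y]            # softmax minus one-hot

dx = delta @ W.T                    # chain rule: dL/dx
dW = np.outer(x, delta)             # chain rule: dL/dW

# Central-difference check on one entry of x.
eps = 1e-6
xp, xm = x.copy(), x.copy()
xp[1] += eps
xm[1] -= eps
assert np.isclose(dx[1], (fused_loss(xp, W) - fused_loss(xm, W)) / (2 * eps), atol=1e-5)
```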
Kernel Implementation
With the theory established, we can implement the fused kernel in Triton. Since cross-entropy is typically the final layer in a language model, we can combine the forward and backward passes into a single kernel. This fusion offers two advantages: it minimises the overhead of multiple kernel launches and significantly improves data locality by keeping intermediate values on-chip.
We will analyse the kernel step-by-step from the perspective of a single program instance, which, in our parallelisation strategy, handles one specific row of the input matrix.
1. Setup and Target Logit Pre-computation
The initial phase involves standard Triton setup:
- Program Identification: We use `tl.program_id` to determine which row of the input matrix the current program is responsible for.
- Parameter Initialisation: We define tiles using `D_BLOCK` and `V_BLOCK` and initialise the running maximum (`m`) and sum (`d`) required for the online softmax algorithm.
- Pointer Arithmetic: We calculate the base memory addresses for our tensors. Pointers for `X` (input) and `dX` (gradient) are offset using the row stride so each program accesses its unique token vector. Conversely, the `W` (weight) pointer remains at the base address because every program must eventually iterate through the entire vocabulary space.
- Masking and Early Exit: We define an `ignore_index` (defaulting to `-100`). If a program encounters this label (e.g. for padding tokens), it terminates early with a loss of 0 to save cycles.
2. Computing the Target Logit
Before the main loop, we must isolate the target logit `x . w_y`. We iterate over the hidden dimension `D` in `D_BLOCK` chunks, performing a dot product between the input row `X` and the specific column of `W` corresponding to the ground-truth label `Y`.
Because `W` is a 2D matrix, calculating the pointers for these specific column tiles requires precise stride manipulation. The illustration below helps visualise how we “jump” through memory to extract only the necessary weights for the target token.

Once the tiles are loaded, we cast them to float32 to ensure numerical stability and add their dot product to an accumulator variable before moving to the next iteration.
Here’s the code so far:
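To pin down the logic of this stage, here is a NumPy emulation of the tiled target-logit dot product (block size and names are illustrative; the real kernel works on Triton block pointers):

```python
import numpy as np

def target_logit_tiled(x, W, y, d_block=4):
    """Tiled dot product x . W[:, y]: iterate over the hidden dimension in
    D_BLOCK chunks, accumulating partial products in float32 for stability."""
    acc = np.float32(0.0)
    D = x.shape[0]
    for start in range(0, D, d_block):
        x_tile = x[start:start + d_block].astype(np.float32)
        w_tile = W[start:start + d_block, y].astype(np.float32)
        acc += x_tile @ w_tile
    return acc

rng = np.random.default_rng(4)
x, W, y = rng.standard_normal(16), rng.standard_normal((16, 8)), 5
assert np.isclose(target_logit_tiled(x, W, y), x @ W[:, y])
```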
Next, we execute the forward pass, which processes the vocabulary space in two nested stages:
- Tiled Logit Computation: We compute the logits one `V_BLOCK` at a time. This is achieved by iterating over the vocabulary dimension `V` (outer loop) and the hidden dimension `D` (inner loop). Within the inner loop, we load a tile of `X` and a block of `W`, accumulating their partial dot products into a high-precision register.
- Online Softmax Update: Once the full dot product for a logit tile is finalised, we don’t store it to VRAM. Instead, we immediately update our running statistics: the maximum value `m` and the running sum of exponentials `d` using the online softmax formula. By doing this “on the fly”, we ensure that we only ever hold a small `V_BLOCK` of logits in the GPU’s registers at any given moment.
Following these iterations, the final values of m and d are used to reconstruct the LSE. The final scalar loss for the row is then computed by subtracting the target logit (x . w_y) from this LSE value.
Here’s a visual representation of the forward pass:

Here’s the code for the forward pass:
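The structure of the forward pass can be emulated in plain NumPy, one row at a time (block sizes and names are illustrative; the real kernel keeps all of this in registers):

```python
import numpy as np

def fused_ce_forward(x, W, y, d_block=4, v_block=4):
    """Per-row forward pass: tiled logits plus online softmax, never holding
    more than one V_BLOCK of logits at a time. Returns (loss, lse)."""
    D, V = W.shape
    # Tiled target logit x . w_y.
    target = sum(x[i:i + d_block] @ W[i:i + d_block, y] for i in range(0, D, d_block))
    m, d = -np.inf, 0.0
    for v0 in range(0, V, v_block):                 # outer loop: vocabulary
        logits = np.zeros(W[:, v0:v0 + v_block].shape[1])
        for d0 in range(0, D, d_block):             # inner loop: hidden dim
            logits += x[d0:d0 + d_block] @ W[d0:d0 + d_block, v0:v0 + v_block]
        # Online softmax update with the finished logit tile.
        m_new = max(m, logits.max())
        d = d * np.exp(m - m_new) + np.exp(logits - m_new).sum()
        m = m_new
    lse = m + np.log(d)
    return lse - target, lse

rng = np.random.default_rng(5)
x, W, y = rng.standard_normal(8), rng.standard_normal((8, 10)), 3
loss, _ = fused_ce_forward(x, W, y)
l = x @ W
assert np.isclose(loss, np.log(np.exp(l - l.max()).sum()) + l.max() - l[y])
```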
We are now down to the last part of the kernel: the backward pass. Our goal is to compute the gradients with respect to X and W using the expression we derived earlier:

To remain memory-efficient, we once again process the vocabulary in tiles using a two-staged approach:
1. Recomputing Normalised Probabilities (`P`): Because we didn’t store the full logit matrix during the forward pass, we must recompute the activations for each tile. By reusing the log-sum-exp calculated in the forward pass, we can normalise these activations on-the-fly. Subtracting the one-hot ground-truth label `Y` from the probabilities within this tile gives us a local chunk of the logit gradient, `P`.
2. Gradient Accumulation: With a tile of `P` in hand, we calculate the partial gradients. For `dX`, we perform a dot product with blocks of `W^T`; for `dW`, we multiply by tiles of `X^T`. To safely aggregate these values across the entire batch, we use Triton’s `tl.atomic_add`. This operation acts as a thread-safe `+=`, ensuring that different programs updating the same weight gradient do not overwrite one another.
Here are some additional details on the implementation:
- The Stride Swap: When computing `P . W^T`, we don’t actually need to physically transpose the massive `W` matrix in memory. Instead, we invert the shapes and strides in `W`’s block pointer to read the rows of `W` as columns of `W^T`. This results in a “free” transpose that saves both time and VRAM.
- Numerical Precision: While `X` and `W` might be in `bfloat16`, the accumulation of `dW` and `dX` via `atomic_add` is usually performed in `float32` to prevent the accumulation of tiny rounding errors across thousands of rows.
- Contention Note: While `atomic_add` is necessary for `dW` (because every program updates the same weights), `dX` is private to each program, meaning there is zero contention between program IDs for that specific tensor.
- Atomic Add Masking: `tl.atomic_add` doesn’t support block pointers, therefore we implement the pointer and mask logic for `dW` explicitly.
The following figure is a representation of the backward pass for one iteration of the outer loop (i.e. one block along V and all blocks along D):

Here’s the full code for the backward pass:
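The backward logic can again be emulated per-row in NumPy, with an ordinary `+=` standing in for `tl.atomic_add` (block size and names are illustrative):

```python
import numpy as np

def fused_ce_backward(x, W, y, lse, dX, dW, v_block=4):
    """Per-row backward pass: recompute each logit tile, normalise it with
    the forward-pass LSE, subtract the one-hot target, then accumulate the
    dX and dW contributions (+= plays the role of tl.atomic_add)."""
    D, V = W.shape
    for v0 in range(0, V, v_block):
        cols = slice(v0, v0 + v_block)
        p = np.exp(x @ W[:, cols] - lse)     # recomputed probabilities for the tile
        idx = np.arange(v0, min(v0 + v_block, V))
        p -= (idx == y)                      # subtract the one-hot target if it falls here
        dX += p @ W[:, cols].T               # dL/dx contribution: P . W^T
        dW[:, cols] += np.outer(x, p)        # dL/dW contribution: x^T . P (atomic in the kernel)

rng = np.random.default_rng(6)
x, W, y = rng.standard_normal(6), rng.standard_normal((6, 10)), 7
l = x @ W
lse = np.log(np.exp(l).sum())
dX, dW = np.zeros_like(x), np.zeros_like(W)
fused_ce_backward(x, W, y, lse, dX, dW)

# The tiled result matches the closed-form gradient (softmax - one-hot).
p_full = np.exp(l - lse)
assert np.allclose(dX, (p_full - np.eye(10)[y]) @ W.T)
assert np.allclose(dW, np.outer(x, p_full - np.eye(10)[y]))
```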
This concludes the implementation of our kernel! The full code including the kernel and benchmark script is available here.
Memory Benchmark
Finally, we compare our kernel with the PyTorch baseline using hyperparameters inspired by Llama3 and an A100 GPU. Specifically, we consider a sequence length of S=16,384, a batch size of B=1 and an embedding dimension of D=4096; the vocabulary size is set to V=128,256.
As expected, the PyTorch baseline allocates a massive intermediate tensor to store the activations, resulting in a peak memory usage of 36.02GB. In comparison, our Triton kernel reduces the peak memory usage by 84% by allocating only 5.04GB using D_BLOCK=64 and V_BLOCK=64!
Using even smaller block sizes would allow for further memory gains at the cost of efficiency.

Atomic Limitations and Production Scaling
In this article, we focused on the technical and mathematical intuition behind fused Linear Cross-Entropy kernels. We used atomic operations like tl.atomic_add to keep the code minimal and readable. However, while our kernel successfully slashed memory usage by a staggering 84%, it remains significantly slower than native PyTorch.
Unfortunately, the same atomic operations which make this kernel easier to write and comprehend come at the cost of a massive traffic jam since thousands of threads try to modify the same memory address at once. Generally, tl.atomic_add is performant when contention is low. In our current implementation, we have:
- High Contention: For the weight gradient, every single program in the batch (up to `16,384` in our test) is trying to update the same memory tiles simultaneously. The hardware must serialise these updates, forcing thousands of threads to wait in line.
- Numerical Non-associativity: Floating-point addition is non-associative: rounding errors can accumulate differently depending on the order of operations, which is why correctness tests might pass on a T4 but fail on an A100; the latter has more streaming multiprocessors (SMs) performing more concurrent, non-deterministic additions.
Note on Precision: On Ampere and newer architectures, the `TF32` format can further contribute to these discrepancies. For strict numerical parity, one should set `allow_tf32=False` or use higher precision types during the accumulation steps.
Path to Production
To move beyond this educational implementation and toward a production-ready kernel (I recommend looking at the Liger-Kernel implementation), one could implement several optimisations:
- Replacing `dX` Atomics: Since each program “owns” its row of `X`, we can use simple register accumulation followed by a `tl.store`, eliminating atomics for the input gradients entirely.
- A Dedicated `dW` Kernel: To optimise the computation of `dW`, production kernels generally use a different grid strategy where each program handles a block of `W` and iterates through the batch dimension, accumulating gradients locally before a single global write.
- Micro-batching: Advanced implementations, such as those in the Liger-Kernel library, process the sequence in blocks along the `N` dimension, making the memory scaling constant in the sequence length rather than linear. This enables the use of much larger batch sizes at a reduced memory cost.
Conclusion
This concludes our deep dive into fused linear cross-entropy kernels. Thanks for reading all the way through, and I hope this article gave you both the intuition and the practical understanding needed to build on these ideas and explore them further.
If you found this useful, consider sharing the article; it genuinely helps support the time and effort that goes into producing this work. And as always, feel free to contact me if you have questions, thoughts, or ideas for follow-ups.
Until next time! 👋