Why We’ve Been Optimizing the Wrong Thing in LLMs for Years


Standard Large Language Models (LLMs) are trained on a straightforward objective: Next-Token Prediction (NTP). By maximizing the probability of the immediate next token $x_{t+1}$ given the preceding context, models have achieved remarkable fluency and reasoning capabilities.

However, this approach is fundamentally inefficient, because the model spends the same amount of compute predicting filler words (e.g., “the”, “and”, “have”) as it does predicting information-carrying words (e.g., “red”, “apple”, “lazy”). This is exacerbated by the fact that more than 50% of the words in English text are filler (Nordquist, 2024) [3]. This raises a practical question: do all words need a full inference cycle to be predicted, or do models already have the filler words in their hidden states long before they are emitted?

Motivation For MTP

The idea that transformers are capable of representing more than just the immediate next step is supported by recent empirical research. Pal et al. (2023) [1] demonstrated that the internal representations of transformer models often encode trajectories of future text long before it is generated.

To illustrate, the researchers performed a “transplantation” experiment. They extracted the hidden states from a model processing the sentence “Madison Square Garden is located in…”, just before it was about to predict the next word as “New.” They then placed this vector into a model processing a completely unrelated context, such as “Tell me something about…” Despite the unrelated prompt, the model autoregressively completed the sentence as “Tell me something about New York City.” This confirmed that the hidden state encodes information not only about the next token but about the entire future sequence.

To capitalize on this latent capability of LLMs, researchers at Meta FAIR (Gloeckle et al., 2024) [2] propose a novel approach. Instead of treating this foresight as an emergent byproduct, they explicitly use it as a training objective. By tasking the model with predicting n future tokens simultaneously at each position instead of only one, they effectively force the model to look ahead. The authors demonstrate that the Multi-Token Prediction (MTP) paradigm yields significantly stronger performance on various benchmarks while making inference up to 3 times faster than the baseline.

The MTP Architecture: Parallelizing Prediction

If the information about the next few tokens is already embedded in the current hidden states of LLMs, the question becomes architectural: how do we extract this information up front, without increasing the compute requirements compared to standard NTP?

The architecture proposed by the authors modifies the existing transformer backbone to predict future tokens concurrently. Unlike the standard NTP paradigm, where the cross-entropy loss is minimized for the immediate next token $x_{t+1}$ only, Multi-Token Prediction (MTP) minimizes the average loss over n different output heads:

$$\mathcal{L}_{\text{MTP}} = -\frac{1}{n}\sum_{t} \sum_{i=1}^{n} \log P_\theta\!\left(x_{t+i} \mid x_{1:t}\right)$$

$x_{t+i}$: the i-th future token after position t
$x_{1:t}$: the prompt context up to position t
$P_\theta$: the full model as a function
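To make the objective concrete, here is a minimal sketch of how this loss could be computed, assuming a PyTorch-style layout in which each head has already produced its logits (the function and variable names are illustrative, not the authors’ code):

```python
import torch
import torch.nn.functional as F

def mtp_loss(logits_per_head, tokens):
    """Average cross-entropy over n offset heads.

    logits_per_head: list of n tensors of shape (batch, seq_len, vocab),
                     where head i predicts the token i steps ahead.
    tokens:          (batch, seq_len) input token ids.
    """
    losses = []
    for i, logits in enumerate(logits_per_head, start=1):
        # Head i at position t is trained to predict x_{t+i}, so its
        # targets are simply the input ids shifted left by i positions.
        preds = logits[:, :-i]          # positions that have a valid target
        targets = tokens[:, i:]         # the corresponding x_{t+i}
        losses.append(F.cross_entropy(
            preds.reshape(-1, preds.size(-1)), targets.reshape(-1)))
    return torch.stack(losses).mean()
```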

To implement this, the authors divide the model into two components:

  1. A Shared Trunk ($f_s$): The bulk of the model is a standard transformer backbone, whose job is to process the prompt context into an information-dense latent representation $z_{1:t}$, which will be used for all subsequent predictions.
  2. Independent Heads ($f_{h_1}, \dots, f_{h_n}$): The output of the trunk is fed to n independent heads. Each head has its own transformer layer and is responsible for predicting a token at a fixed future offset (e.g., head 1 predicts $x_{t+1}$, head 2 predicts $x_{t+2}$, etc.).

Finally, the output of each individual head is passed to a shared unembedding layer, implemented as a simple linear projection from the model’s hidden dimension to the vocabulary size. The diagram below summarizes the most important aspects of the MTP architecture:

(Source: Author)
The model runs the shared trunk only once, then activates each head sequentially. In steps 4-6 it activates the first head and computes its logits, then backpropagates the updates in steps 6-8. Head 2 is processed the same way, followed by heads 3 and 4.
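A minimal PyTorch-style sketch of this layout is shown below. It is an illustration under simplifying assumptions (class and parameter names are hypothetical, and causal attention masking is omitted for brevity), not the authors’ implementation:

```python
import torch.nn as nn

class MTPModel(nn.Module):
    """Shared trunk + n independent heads + one shared unembedding layer."""

    def __init__(self, vocab_size, d_model, trunk_layers, n_future=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # Shared trunk: the bulk of the transformer layers.
        self.trunk = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
             for _ in range(trunk_layers)])
        # One extra transformer layer per future offset.
        self.heads = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
             for _ in range(n_future)])
        # Shared unembedding: hidden dimension -> vocabulary size.
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens):
        z = self.embed(tokens)
        for layer in self.trunk:
            z = layer(z)                 # z_{1:t}, the shared latent state
        # Head i reads the same z and produces logits for x_{t+i}.
        return [self.unembed(head(z)) for head in self.heads]
```

Note that the unembedding matrix is shared across heads, so the marginal cost of each additional head is a single transformer layer.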

Overcoming the Memory Bottleneck

The architecture described above presents a major engineering hurdle: GPU memory utilization.

The vocabulary size (V) of Large Language Models is typically in the range of 32k-256k, which is enormous. This makes the raw prediction scores over the vocabulary, a.k.a. the output logits, equally large. In a standard NTP setup, the model must materialize these logits only once per step, which keeps memory tractable. In the MTP setup, however, n different sets of these massive logits are produced concurrently, which can easily overwhelm GPU memory. This makes the naive MTP method impractical unless researchers drastically reduce batch sizes, slowing down the entire training process.
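To put rough numbers on this (illustrative values, not the paper’s exact configuration): with a batch of 4 sequences of length 4,096, a 128k vocabulary, and fp32 logits, a single head alone materializes

$$4 \times 4096 \times 128{,}000 \times 4\ \text{bytes} \approx 8.4\ \text{GB of logits},$$

so keeping n = 4 heads’ worth of logits alive simultaneously would cost over 33 GB before counting weights, activations, or optimizer state.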

The authors circumvent this bottleneck with a sequential forward/backward pass strategy. Rather than computing the loss for all n heads at once, the training loop iterates through them sequentially:

  1. The shared trunk computes the latent state $z_{1:t}$.
  2. The model computes the logits for head 1, calculates the loss, backpropagates gradients through the entire model, and immediately discards the logits from memory.
  3. It then repeats this process for head 2, head 3, and so forth.

By freeing these massive logit tensors after each head’s computation, the peak memory usage of the training process stays O(V) instead of O(nV). This allows MTP models to be trained with batch sizes similar to those of standard models, as sketched below.
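A sketch of this loop, building on the hypothetical MTPModel above (again an illustration rather than the authors’ exact implementation), might look like:

```python
import torch.nn.functional as F

def train_step(model, tokens, optimizer):
    """One MTP training step with O(V) peak logit memory instead of O(nV)."""
    optimizer.zero_grad()

    # 1. Run the shared trunk once; keep its output in the autograd graph.
    z = model.embed(tokens)
    for layer in model.trunk:
        z = layer(z)

    n = len(model.heads)
    for i, head in enumerate(model.heads, start=1):
        # 2. Materialize the huge (batch, seq, vocab) logits for one head only.
        logits = model.unembed(head(z))
        loss = F.cross_entropy(
            logits[:, :-i].reshape(-1, logits.size(-1)),
            tokens[:, i:].reshape(-1)) / n
        # 3. Backpropagate now (gradients accumulate in the shared trunk),
        #    keeping the trunk's graph alive for the remaining heads,
        #    then free this head's logits before touching the next one.
        loss.backward(retain_graph=(i < n))
        del logits, loss

    optimizer.step()
```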

Critical Design Choices

Beyond the memory optimization, the authors made two specific design decisions that are essential for interpreting MTP’s performance metrics and scientific validity.

1. The Parameter Parity Constraint
In an MTP model with n=4 heads, the 4 additional transformer head layers increase the parameter count. To compensate, the authors removed an equivalent number of layers from the model’s trunk, making it shallower. This ensures that any performance difference relative to the baseline can be credited solely to the MTP architecture itself, and not to an increase in the model’s parameters.

The fact that MTP still outperforms standard NTP-based models despite having a shallower trunk only underscores the merits of the architecture.

2. Head Topology: Parallel vs. Causal
The authors also experimented with the arrangement of the heads themselves, specifically comparing two approaches:

  • Parallel Heads: This is the standard MTP design described above. Every head predicts its specific future token based only on the shared state $z_{1:t}$, without seeing the predictions of the other heads.
  • Causal Heads: In this setup, head 2 (predicting $x_{t+2}$) receives the output of head 1 as input. This creates a “mini-autoregressive” chain at the end of the model, which allows each head to look at the state of the head before it. The architecture of MTP with n=4 causal heads is shown below:
(Source: Author)
In the causal design, the heads are arranged sequentially, so that each head knows what the head preceding it predicted.

Surprisingly, the parallel design performed better. The authors hypothesize that with causal heads, the shared trunk “got lazy,” relying on the heads to work out the sequential information. By forcing the heads to act independently, the trunk was effectively coerced into learning a global representation that could satisfy all heads at once. This is exactly the property that manifests as the model’s ability to plan into the future, which is essential for reasoning tasks.
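The wiring difference between the two topologies is small. Here is a hedged sketch, reusing the hypothetical components of the MTPModel above:

```python
def parallel_heads(model, z):
    # Parallel: every head reads the same shared state z_{1:t}.
    return [model.unembed(head(z)) for head in model.heads]

def causal_heads(model, z):
    # Causal: head i also consumes the hidden state produced by head i-1,
    # forming a mini-autoregressive chain on top of the trunk.
    logits, h = [], z
    for head in model.heads:
        h = head(h)
        logits.append(model.unembed(h))
    return logits
```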

Experimental Results: The Scale of Improvement

The authors conducted extensive evaluations comparing MTP models against standard Next-Token Prediction (NTP) baselines across model sizes ranging from 300M to 13B parameters.

1. The “Scaling Law” of Multi-Token Prediction
Arguably, the most interesting finding is that MTP’s advantage scales with model size. For smaller models of 300M-1.3B parameters, the difference between MTP and NTP is negligible (and MTP often performs worse). But as the scale increases, MTP starts to perform significantly better than the baseline. As illustrated below, MTP outperforms NTP by 17% on the MBPP benchmark and 12% on the HumanEval benchmark.

(Source: Adapted from Gloeckle et al. (2024), Figure 3)
Note: These graphs depict the absolute percentage-point changes relative to the baseline. For example, in the top-left graph, the 13B NTP model scored 26% on the MBPP benchmark while MTP scored 30.5%, a 4.5-point increase in absolute terms and a 17% increase in relative terms.

A possible reason for this disparity is that larger models, with their greater parameter counts, can afford to allocate more capacity to future planning than smaller models can. This allows the larger models to take advantage of the multi-token objective and develop stronger reasoning.

2. Three-Fold Inference Speedup via Self-Speculation
Apart from benchmark performance, MTP also addresses one of the most persistent bottlenecks in LLM operations: inference latency.

To fully appreciate this contribution, we must first understand what speculative decoding is. In standard inference, the model generates tokens iteratively: it has to wait for $x_{t+1}$ to be generated before it can compute $x_{t+2}$. Speculative decoding speeds this process up by using a smaller, faster draft model (often from the same family as the main model but with far fewer parameters) to propose the next few tokens. The main model then verifies all of these tokens in a single forward pass, keeping only the predictions it agrees with. Since a single forward pass is faster than generating the same tokens over many iterations, this results in a net speedup.

Speculative decoding normally requires a separate draft model to be loaded into memory, which can be costly. The authors propose instead that the extra MTP heads, usually discarded after training, can serve as a built-in draft model. Because these heads share the same trunk, they are highly accurate drafters. By using up to 4 heads to draft a short subsequence and then verifying it in parallel, MTP achieves a 3x inference speedup with no loss in output quality.
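A simplified greedy sketch of this self-speculative loop, reusing the hypothetical MTPModel from above (batch size 1, no KV cache, and a naive acceptance rule, so treat it as an illustration of the idea rather than a production decoder):

```python
import torch

@torch.no_grad()
def self_speculative_generate(model, tokens, max_new_tokens=64):
    """Draft with the extra MTP heads, verify with the next-token head."""
    prompt_len = tokens.size(1)
    while tokens.size(1) < prompt_len + max_new_tokens:
        # 1. One forward pass: head i drafts x_{t+i} at the last position.
        logits = model(tokens)                               # list of (1, T, V)
        draft = torch.stack([l[:, -1].argmax(-1) for l in logits], dim=1)
        candidate = torch.cat([tokens, draft], dim=1)        # (1, T + n)

        # 2. Verification pass: the standard next-token head (head 1)
        #    re-scores the whole candidate sequence in parallel.
        verify = model(candidate)[0].argmax(-1)              # (1, T + n)
        accepted = 0
        for i in range(draft.size(1)):
            pos = tokens.size(1) - 1 + i   # position whose prediction is draft[:, i]
            if torch.equal(verify[:, pos], draft[:, i]):
                accepted += 1
            else:
                break

        # 3. Keep the verified prefix plus one token from the verifier itself,
        #    so every iteration makes progress even if nothing is accepted.
        next_tok = verify[:, tokens.size(1) - 1 + accepted].unsqueeze(1)
        tokens = torch.cat([tokens, draft[:, :accepted], next_tok], dim=1)
    return tokens
```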

3. Faster Formation of “Induction Heads”
The authors also analyze the emergence of induction capabilities in MTP. Induction heads are circuits in transformers that are largely responsible for pattern-matching abilities (e.g., recognizing that [A]…[B]…[A] is likely followed by [B]). The graph below shows that at smaller model sizes, MTP exhibits better induction ability than similarly sized NTP models. This suggests that forcing the model to predict the consequences of the immediate next token creates a gradient signal conducive to the emergence of pattern recognition and in-context learning.

(Source: Adapted from Gloeckle et al. (2024), Figure 7)
The authors took 100 children’s stories and replaced the characters’ names with names that span two tokens. The induction success plotted on the y-axis is the accuracy with which the model correctly predicts the second token of these two-token names, given that the name has appeared at least once before.

4. Unlocking Byte-Level Training
In a more radical experiment, the authors applied MTP to byte-level models, which predict a sequence of raw bytes instead of tokens. Historically, byte-level models have performed poorly because the contextual information carried by individual bytes is weak and byte sequences become very long. However, as demonstrated in the table below, with n=8 heads (predicting 8 bytes at once), the MTP model significantly outperforms the baseline NTP model with a single head, consistently across all three benchmarks. This suggests that MTP can navigate the byte realm efficiently, allowing models to process raw data natively without compromising performance.

(Source: Adapted from Gloeckle et al. (2024), Table 1)
This table presents the Pass@k accuracies of the MTP and NTP models on different benchmarks. For example, the @10 column measures the probability that at least one of the top 10 solutions generated by the model is correct.

The Price of Foresight: Shortcomings and Trade-offs

While Multi-Token Prediction offers a compelling alternative to the standard paradigm, the paper’s results make clear that it is not a universal “silver bullet.” The architecture introduces specific trade-offs that engineers must consider.

1. Regression on Knowledge-Intensive Tasks
While MTP improves reasoning (how to structure an answer), it appears to harm retrieval (knowing a specific fact).
As shown below, MTP models dominate in code generation and reasoning benchmarks, but actually underperform the baseline on standard NLP tasks, including benchmarks like MMLU, TriviaQA, and ARC Challenge, which test fact retrieval and world knowledge.

(Source: Adapted from Gloeckle et al. (2024), Figure 7)
The average accuracy across 7 benchmarks (ARC Challenge, COPA, HellaSwag, NQ, PIQA, SIQA, and TQA) is plotted on the y-axis against training steps on the x-axis.

A possible explanation is that answering recall-based questions like “What is the capital of France?” requires a precise focus on the single word “Paris.” Forcing the model to also predict multiple tokens at once, as in “Paris is a city in…,” may dilute the signal from the most critical token, dragging down performance on these benchmarks. If your aim is to build a RAG (Retrieval-Augmented Generation) system or a trivia bot, MTP might actually be detrimental.

2. The “Goldilocks” Sensitivity of n
There is no “more is better” rule here. The authors found that performance is highly sensitive to the number of heads (n).

Performance does not scale linearly with n; instead, there is a “sweet spot” where the model can most efficiently exploit the MTP paradigm:

  • Too few (n = 2): Negligible gain, as the model does not receive enough incentive to develop any foresight.
  • Too many (n = 8): Performance degrades rapidly, as the information for all 8 heads starts to overcrowd the hidden state of the shared trunk.
  • Just right (n = 4): Best performance.

This introduces a new hyperparameter that must be tuned. Unlike Next-Token Prediction, which just “works,” MTP requires finding a prediction horizon that matches the complexity of your data.

Conclusion

With its demonstrated ability to improve coding performance and speed up inference, one obvious question remains: has any production-scale model actually adopted MTP?

The answer is DeepSeek-V3.

In their technical report (Liu et al., 2024) [4], the DeepSeek team revealed that MTP was a core component of the model’s training. Much like Meta, they performed rigorous ablation studies comparing standard NTP models against MTP at both the 15.7B and 228.7B parameter scales. Using a configuration of n = 2 during training (predicting one extra future token), they found that MTP-trained models consistently outperformed their NTP counterparts across datasets such as MMLU, Pile-test, HumanEval, and MBPP. Furthermore, by keeping that second prediction head at inference time for speculative decoding, as described earlier, DeepSeek achieved an inference speedup of up to 1.8x.

This successful deployment by DeepSeek serves as practical validation for MTP as a training objective for Large Language Models, demonstrating a clear path to improving reasoning capabilities and inference efficiency with minimal drawbacks.

References

[1] Pal, Koyena, et al. “Future lens: Anticipating subsequent tokens from a single hidden state.”  (2023).
[2] Gloeckle, Fabian, et al. “Better & faster large language models via multi-token prediction.” (2024).
[3] Nordquist, R. (2024, July 20). . ThoughtCo.
[4] Liu, Aixin, et al. “DeepSeek-V3 technical report.” (2024).
