Coconut: A Framework for Latent Reasoning in LLMs


Paper link: https://arxiv.org/abs/2412.06769

Released: December 9, 2024

Figure 1. The two reasoning modes of Coconut. In Language Mode (left), the model uses its output text tokens as inputs for the subsequent reasoning step. In Latent Mode (right), the model instead feeds its previous hidden state (the output of the last hidden layer) back into itself as input. Figure taken from [1].

There has recently been a strong focus on LLMs with reasoning capabilities, and for good reason. Reasoning enhances LLMs’ ability to tackle complex problems, fosters stronger generalization, and introduces an interpretable layer that sheds light on a model’s internal thought process.

A major milestone in LLM reasoning is the introduction of Chain-of-Thought reasoning (CoT) [2], which showed that guiding models to reason step by step yields significant improvements on arithmetic and symbolic reasoning tasks.

Despite their power, reasoning models still operate primarily within the confines of natural language, which may limit their effectiveness. Much of the token space is dedicated to maintaining linguistic coherence rather than facilitating abstract reasoning. Addressing this limitation, an intriguing paper from Meta, Training Large Language Models to Reason in a Continuous Latent Space [1], proposes lifting the chain of thought out of natural language entirely, translating back to language only when necessary.

Their contribution can be summarized in three key points:

  1. Chain of Continuous Thought (Coconut): An enhanced reasoning paradigm that builds on CoT. Instead of relying on the final text output, Coconut uses the latent representations from the model’s last hidden layer.
  2. An exploration of Coconut’s capabilities: showing how multiple next reasoning steps can be encoded simultaneously in the latent space.
  3. A deeper analysis of the latent reasoning process itself, so that we can understand Coconut’s internal representation of knowledge.

Coconut, Simplified 

Before delving into the implementation details of Chain of Continuous Thought, it’s important to first establish some foundations.

Given an input sequence x = [x(1), x(2), x(3), …, x(T)], a Chain-of-Thought LLM M, which predicts the next token x(t+1) based on the sequence of previous tokens x(≤t), can be formally described as:

$$M_{\mathrm{CoT}}(x_{t+1} \mid x_{\le t}) = \mathrm{softmax}(W x_{t})$$

where W is the weight matrix of our LLM, and x(t) is the input token at step t.

Coconut extends this formulation by removing the dependency on textual input tokens and instead using the model’s last hidden state h(t) as input. This change modifies the LLM’s predictive function into:

$$M_{\mathrm{Coconut}}(x_{t+1} \mid x_{\le t}) = \mathrm{softmax}(W h_{t})$$

$$H_{t} = \mathrm{Transformer}(E_{t})$$

where E(t) = [e(x1), e(x2), …, e(xt)] represents the sequence of token embeddings, with e(⋅) denoting the embedding function, and H(t) captures the sequence of hidden states for all tokens up to position t.

This new formulation allows Coconut to operate in two distinct modes: Language Mode and Latent Mode, as illustrated in Figure 1 (left and right, respectively). In Language Mode, the model functions like a standard LLM, processing textual tokens as input, while in Latent Mode, it operates on its internal hidden states instead.

Mode switching plays a critical role in Coconut’s training process. It not only enables the model to learn how to generate meaningful latent representations but also facilitates the decoding of those latent thoughts. Mode transitions are controlled using two special placeholder tokens: <bot> (begin-of-thought) and <eot> (end-of-thought). Inserting <bot> at position i and <eot> at position j signals the model to operate in Latent Mode for the positions between i and j (i.e., e(x_i) = <bot> and e(x_j) = <eot>), where the input token embeddings are replaced by the model’s preceding hidden states:

$$E_{t}=[\,e(x_{1}),\, e(x_{2}),\, \ldots,\, e(x_{i}),\, h_{i},\, h_{i+1},\, \ldots,\, h_{j-1},\, e(x_{j}),\, e(x_{j+1}),\, \ldots,\, e(x_{t})\,]$$
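To make this formulation concrete, below is a minimal sketch of one latent step using an off-the-shelf Hugging Face GPT-2 model. This is purely illustrative and not the authors’ implementation (the prompt is made up, and the base GPT-2 model has of course not been trained with Coconut); it only shows the mechanics of feeding the last hidden state back in place of a token embedding.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Minimal sketch of one latent reasoning step (illustrative, not the authors' code).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

prompt = "Betty needs $100 and has only half of it."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Language Mode: embed the prompt tokens as usual.
embeds = model.transformer.wte(input_ids)                 # shape (1, T, hidden_size)

with torch.no_grad():
    out = model(inputs_embeds=embeds, output_hidden_states=True)
h_t = out.hidden_states[-1][:, -1:, :]                    # last hidden state h_t

# Latent Mode: feed h_t back as the next "input embedding" instead of a token.
embeds = torch.cat([embeds, h_t], dim=1)
with torch.no_grad():
    out = model(inputs_embeds=embeds, output_hidden_states=True)

# Switching back to Language Mode: the final-position logits correspond to
# softmax(W h_t) over the vocabulary, i.e. the next-token distribution.
next_token_id = out.logits[:, -1, :].argmax(dim=-1)
print(tokenizer.decode(int(next_token_id[0])))
```

Because the hidden state is reused directly as an input embedding, no sampling or discretization happens inside the latent segment, which is exactly what makes the chain of thought continuous.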

Figure 2. The training process of Coconut, where at each training stage one language reasoning step is removed and replaced with c latent reasoning steps. Here, c equals 1. Figure taken from [1].

Inspired by [3], Coconut employs a multi-stage training curriculum. At each stage k, k language-based reasoning steps are replaced with L latent steps, where L = k⋅c and c is a hyperparameter determining how many latent steps substitute for a single language reasoning step. This progression is visualized in Figure 2, where at stage k = 0 the model trains purely on standard CoT examples.

The authors’ decision to use multi-stage training is meant to decompose the training process into easier objectives, leading to better results. This pattern was already suggested and backed up in [3], which showed that gradually removing intermediate reasoning tokens enables deeper internalization of reasoning.

Using latent thoughts enables end-to-end gradient-based training by replacing token-level transitions between reasoning steps with continuous hidden representations; with this change, the network is fully differentiable. Beyond that, it also allows the model to encode multiple possible next steps simultaneously, refining the reasoning path as it advances. A deeper exploration of this mechanism is provided in the Understanding Latent Reasoning section.

As an illustration, let’s examine a simple example drawn from GSM8K [4], one of the datasets used to train Coconut.

Question:

“Betty is saving money for a new wallet, which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?”

Reasoning steps:

1. Betty has only 100 / 2 = $<<100/2=50>>50.

2. Betty’s grandparents gave her 15 * 2 = $<<15*2=30>>30.

3. This means Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more.

Answer: 5

This question is then incorporated into the training dataset and used across the successive training stages:

Figure 3. An example of the training process of Coconut. Figure by author, based on an example taken from GSM8K [4].

As shown in Figure 3, at stage 0 no latent thoughts are present, only language-based reasoning steps followed by the final answer. In the subsequent stages 1 and 2, one language reasoning step at a time is replaced by a latent thought (since c = 1), until stage 3, where all reasoning steps are latent. This procedure is applied to every training example in the dataset.
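As a rough sketch of how such staged examples could be assembled, the snippet below builds the text for each stage of the Betty example with c = 1. The `<latent>` placeholder and the helper function are assumptions made for illustration: during actual Coconut training these positions are not text tokens at all, but slots filled with the model’s own hidden states as shown earlier.

```python
# Illustrative sketch of building stage-k training text for one example (c = 1).
BOT, EOT = "<bot>", "<eot>"

def build_stage_example(question, steps, answer, k, c=1):
    """Replace the first k language reasoning steps with k*c latent placeholders."""
    k = min(k, len(steps))
    latent = (BOT + " " + "<latent> " * (k * c) + EOT) if k else ""
    parts = [question, latent] + steps[k:] + [f"Answer: {answer}"]
    return "\n".join(p for p in parts if p)

question = "Betty is saving money for a new wallet, which costs $100. ..."
steps = [
    "Betty has only 100 / 2 = $50.",
    "Betty's grandparents gave her 15 * 2 = $30.",
    "This means Betty needs 100 - 50 - 30 - 15 = $5 more.",
]

for k in range(4):  # stage 0 (full CoT) up to stage 3 (fully latent)
    print(f"--- stage {k} ---")
    print(build_stage_example(question, steps, 5, k))
```

Printing the four stages reproduces the progression of Figure 3: the language steps disappear one by one while the latent slots between <bot> and <eot> grow.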


Key Findings & Evaluation

Three datasets were used to evaluate Coconut’s effectiveness: one focused on mathematical reasoning (GSM8K [4]) and two on logical reasoning (ProntoQA [5] and ProsQA). ProsQA (Proof with Search Question-Answering) is a modified version of ProntoQA, featuring randomly generated directed acyclic graphs (DAGs) of reasoning steps, designed to challenge the model with more complex planning tasks. All models were fine-tuned from GPT-2 as the base model, with c = 1 for most datasets, except for GSM8K, where two latent thoughts per removed language step were used (c = 2).

Below is a simplified summary of the results reported in the paper:

Table 1. Accuracy results on three datasets. Results taken from [1].

The models used for comparison with the Coconut architecture are:

  • CoT: Model trained with Chain-of-Thought reasoning, utilizing full reasoning chains during training.
  • No-CoT: Model trained without any reasoning chains; standard language modeling with no intermediate reasoning steps.
  • Coconut: The full implementation proposed in this paper.
  • w/o curriculum: The Coconut model trained without the multi-stage curriculum, i.e., no gradual introduction of latent thoughts.
  • w/o thought: Coconut with multi-stage training retained, but without introducing latent thoughts. Language reasoning steps are simply removed over the stages instead.
  • Pause as thought [6]: Model trained without latent thoughts entirely, but special pause tokens are inserted in place of each removed thought. These tokens give the model additional computation steps before generating an answer. Prior studies [7] have reported improved performance with this approach.

A detailed examination of the previous table reveals three key insights into the Coconut training paradigm.

First, latent reasoning demonstrates superior performance over Chain-of-Thought on logical reasoning tasks, outperforming it on benchmarks such as ProntoQA [5] and ProsQA. The substantial accuracy gain on ProsQA (97.0% vs. 77.5%) highlights Coconut’s effectiveness on more complex reasoning challenges. Unfortunately, the authors did not explain the accuracy drop from CoT to Coconut on GSM8K (42.9% vs. 34.9%). This could be attributable to the mathematical nature of GSM8K, which, unlike ProsQA, demands less planning-heavy reasoning.

Second, comparing Coconut with its non-multi-stage counterpart, we reach the same conclusion suggested by [3]: breaking the training process down into simpler, more manageable objectives significantly enhances model performance. Moreover, comparing the “w/o curriculum” and “w/o thought” implementations makes it clear that the effect of gradual multi-stage training is actually more prominent than that of simply replacing language steps with latent thoughts in a single step. This is an interesting finding that shows how crucial gradual training is to the final results.

Lastly, even when the model is given multi-stage training and additional computation via pause tokens, it still falls short of the full Coconut implementation. This is most apparent in the GSM8K results, reinforcing the hypothesis that incorporating latent thoughts boosts training effectiveness.


Understanding Latent Reasoning

One of the benefits of Coconut is that, unlike language-based thoughts, latent thoughts can keep several candidate directions or outcomes under consideration at once. This leads to a different reasoning process from ordinary chaining, allowing us to interpret it as a hypothetical tree search: each depth layer is the result of a respective latent step k, and each node carries a calculated probability of a particular option. This will be covered in more detail in Example #2.

Two main examples of this phenomenon are presented in the paper. We’ll cover each of them briefly to illustrate the latent reasoning power of this new thought paradigm.

Example #1:

The first example demonstrates how a latent thought can contain multiple possible outcomes within its reasoning tree. To explore this, the continuous thought generated by the model was decoded using the language-model head, a process done solely for testing purposes, allowing us to probe the continuous thought and confirm whether these latent thoughts were being learned appropriately.
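A rough sketch of such probing is shown below, assuming a GPT-2-style Hugging Face model where the continuous thought `h` is a last-layer hidden state. The helper name and usage are illustrative only, not the authors’ code.

```python
import torch

# Illustrative probe: decode a continuous thought `h` (a last-layer hidden state)
# with the language-model head of a GPT-2-style model and list the top-k tokens.
def decode_latent_thought(model, tokenizer, h, top_k=3):
    logits = model.lm_head(h)                     # project h onto the vocabulary (W h)
    probs = torch.softmax(logits, dim=-1)
    top = torch.topk(probs, top_k)
    return [(tokenizer.decode([int(idx)]), float(p))
            for idx, p in zip(top.indices, top.values)]

# Hypothetical usage: decode_latent_thought(model, tokenizer, h_t.squeeze())
# could surface candidates such as "180" or "90" together with their probabilities.
```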

Question:

James decides to run 3 sprints 3 times per week. He runs 60 meters each sprint. How many meters does he run per week?

Reasoning Steps:

1. He runs 3*3=9 sprints per week

2. So he runs 9*60=540

Answer: 540

Alternative Solution:

1. He runs 3*60=180 meters each time

2. So he runs 3*180=540

When we decode the first latent thought generated by the model, we find that the top three candidate outputs are:

1. “180” with a probability of 0.22

2. “ 180” (with a leading space) with a probability of 0.20

3. “90” with a probability of 0.13

This shows that the model is indeed considering the first steps of the two viable solutions mentioned above.

Example #2:

The second example gives a clearer illustration of how the search tree is constructed as the number of thoughts increases, pruning older branches that are no longer relevant to the reasoning process and prioritizing more “sound” nodes.

Figure 4. Latent search tree for Example #2. On the left are the results of decoding the first latent reasoning step; on the right, the results of the second latent step. Figure taken from [1].

Question:

“Every grimpus is a yimpus. Every worpus is a jelpus. Every zhorpus is a sterpus. Every impus is a hilpus. Every jompus is a …grimpus is a gwompus. Every rempus is a gorpus. Alex is a sterpus. Every zhorpus is a rompus. Is Alex a gorpus or bompus?”

Reasoning Steps:

1. “Alex is a grimpus.”

2. “Every grimpus is a rorpus.”

3. “Every rorpus is a bompus.”

Answer: “Alex is a bompus.”

The probability of each option can be obtained by multiplying the probabilities of its tokens, as depicted in Figure 4. Here we show the state of the search tree after one latent thought (left) and after two (right).
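A minimal sketch of that calculation is given below; the per-token probabilities are made-up placeholders, not the values reported in the paper.

```python
import math

# Hypothetical per-token probabilities for two candidate nodes at the same depth
# (placeholder numbers, not the values from the paper).
candidates = {
    "Alex is a grimpus": [0.8, 0.4],
    "Alex is a sterpus": [0.5, 0.02],
}

# The probability assigned to each node is the product of its token probabilities.
for option, token_probs in candidates.items():
    print(option, "->", round(math.prod(token_probs), 3))
```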

We can see from the total calculated probabilities that in the first step, the least probable option (0.01) is sterpus, while the second most probable option is grimpus (0.32), which is the correct first step of reasoning in this case. When the search tree is updated with information from the second thought, the sterpus node is discarded entirely, and the new node with the highest probability is rorpus, which is the correct second reasoning step.

This demonstrates that Coconut can include multiple candidate next steps in its reasoning process, prioritizing more promising steps as it goes (such as grimpus in the first step) and disregarding less relevant ones (sterpus in the first step). It shows that Coconut is able to navigate several thoughts in a tree-like manner until it reaches its final conclusion.


Conclusion

In this post, we have discussed Coconut, a new reasoning paradigm that frees LLMs from the necessity of “thinking” in language space, utilizing the latent space instead. We have covered Coconut’s strong performance compared to other reasoning methods, the importance of multi-stage training, and examples that help us understand how the latent reasoning process works under the hood.

In my view, Coconut addresses an interesting research topic, sparking new exploration into latent reasoning approaches and paving the way for more sophisticated machine reasoning models that are not bound by language syntax.


References

[1] S. Hao, S. Sukhbaatar, D. Su, X. Li, Z. Hu, J. Weston and Y. Tian, Training Large Language Models to Reason in a Continuous Latent Space (2024), arXiv preprint arXiv:2412.06769

[2] J. Wei, X. Wang, D. Schuurmans, M. Bosma, B. Ichter, F. Xia, E. Chi, Q. Le and D. Zhou, Chain-of-Thought Prompting Elicits Reasoning in Large Language Models (2022), arXiv preprint arXiv:2201.11903

[3] Y. Deng, Y. Choi and S. Shieber, From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step (2024), arXiv preprint arXiv:2405.14838

[4] K. Cobbe, V. Kosaraju, M. Bavarian, M. Chen, H. Jun, L. Kaiser, M. Plappert, J. Tworek, J. Hilton, R. Nakano, C. Hesse and J. Schulman, Training Verifiers to Solve Math Word Problems (2021), arXiv preprint arXiv:2110.14168

[5] A. Saparov and H. He, Language Models Are Greedy Reasoners: A Systematic Formal Analysis of Chain-of-Thought (2022), arXiv preprint arXiv:2210.01240

[6] S. Goyal, Z. Ji, A. S. Rawat, A. K. Menon, S. Kumar and V. Nagarajan, Think Before You Speak: Training Language Models With Pause Tokens (2024), arXiv preprint arXiv:2310.02226

[7] J. Pfau, W. Merrill and S. R. Bowman, Let’s Think Dot by Dot: Hidden Computation in Transformer Language Models (2024), arXiv preprint arXiv:2404.15758
