KV Cache from scratch in nanoVLM



We have implemented KV Caching from scratch in our nanoVLM repository (a small codebase to train your own Vision Language Model in pure PyTorch). This gave us a 38% speedup in generation. In this blog post we cover KV Caching and the lessons we learned while implementing it. The lessons are general and can be applied to any autoregressive language model generation. Implementing it from scratch in a small codebase is a great learning experience, so come along for the ride!

(Figure: bar plot showcasing the improvement in generation speed)



Introduction

Autoregressive language models generate text by sampling one token at a time. During inference, the model processes a given input sequence, predicts the next token, appends it to the sequence, and repeats this process until some stopping criterion is met:

(Figure: diagram of autoregressive generation)

This step-by-step generation is inherently sequential:

  • To generate token $t_{i+1}$, the model must attend to all previously generated tokens $t_1, \dots, t_i$.
  • Although transformers are internally parallel, each new prediction requires a full forward pass through all transformer layers, which incurs quadratic memory/compute in terms of the sequence length.

This repetition also results in computational redundancy. In this post, we explore KV Caching, an optimisation technique that mitigates this inefficiency.
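To make the loop concrete, here is a minimal sketch of naive greedy decoding without any caching. It assumes a hypothetical `model` callable that maps token ids of shape `(1, seq_len)` to logits of shape `(1, seq_len, vocab_size)`; it is an illustration, not the nanoVLM code.

import torch

def naive_generate(model, input_ids, max_new_tokens, eos_token_id=None):
    # input_ids: (1, prompt_len) tensor of token ids
    for _ in range(max_new_tokens):
        logits = model(input_ids)                                 # full forward pass over the whole sequence
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy choice of the next token
        input_ids = torch.cat([input_ids, next_token], dim=1)     # append it and repeat
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break                                                 # stopping criterion
    return input_ids

Every iteration re-runs the entire (growing) sequence through the model, which is exactly the redundancy discussed next.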




Revisiting the Transformer Architecture

Before diving into caching, let’s revisit how attention operates in transformer models. A Transformer language model consists of stacked layers, each composed of:

  • Multi-head self-attention
  • Feed-forward network (MLP)
  • Residual connections and layer normalisation

To understand where KV Caching helps, we focus on the self-attention mechanism, specifically within a single attention head.

Let’s walk through a simple PyTorch implementation to visualise the key computations.

import torch

# Toy dimensions: a 5-token sequence with 10-dimensional embeddings
input_seq_length = 5
dim_model = 10

# Random input embeddings and projection matrices for a single attention head
input_ids_emb = torch.randn(input_seq_length, dim_model)
W_q = torch.randn(dim_model, dim_model)
W_k = torch.randn(dim_model, dim_model)
W_v = torch.randn(dim_model, dim_model)

# Project the embeddings into queries, keys and values
Q = input_ids_emb @ W_q
K = input_ids_emb @ W_k
V = input_ids_emb @ W_v



Self-Attention Computation

For a sequence of $T$ input embeddings represented as $X \in \mathbb{R}^{T \times D}$, each attention head computes:

  • Queries: $Q = XW_Q$
  • Keys: $K = XW_K$
  • Values: $V = XW_V$
  • A causal mask $M$ to prevent access to future tokens

The final output is:

$$\text{Attention}(X; Q, K, V) = \text{softmax}\left( \frac{QK^\top \cdot M}{\sqrt{d_k}} \right)V$$

Here’s a minimal PyTorch equivalent using a causal mask:

import torch.nn.functional as F
import math

# Scaled dot-product scores between every pair of positions
d_k = K.shape[-1]
attention_scores = (Q @ K.T) / math.sqrt(d_k)

# Lower-triangular causal mask: position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length))
masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf'))

attention_weights = F.softmax(masked_scores, dim=-1)
output = attention_weights @ V



Where Redundancy Creeps In

In autoregressive generation, the model generates one token at a time. With each step, it recomputes $Q$, $K$, and $V$ for the entire sequence, even though the earlier tokens haven’t changed.

# Append a new token's embedding and recompute Q, K, V for the whole extended sequence
new_token_emb = torch.randn(1, dim_model)
extended_input = torch.cat([input_ids_emb, new_token_emb], dim=0)

Q_ext = extended_input @ W_q
K_ext = extended_input @ W_k
V_ext = extended_input @ W_v


To verify the redundancy:

torch.testing.assert_close(K, K_ext[:input_seq_length]) 
torch.testing.assert_close(V, V_ext[:input_seq_length]) 

These checks show that, for all but the latest token, $K$ and $V$ are identical to the previously computed values.

Original (5×5):         Extended (6×6):
■ ■ ■ ■ ■               ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■               ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■      →        ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■               ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■               ■ ■ ■ ■ ■ □
                        □ □ □ □ □ □
  • ■ = Already computed and reused
  • □ = Recomputed unnecessarily

Most of the attention computation is repeated needlessly, and this gets more expensive as sequences grow.



How KV Caching Fixes It

To eliminate this inefficiency, we use KV Caching:

  • After processing the initial prompt, we cache the computed keys ($K$) and values ($V$) for every layer.
  • During generation, we only compute $K$ and $V$ for the new token and append them to the cache.
  • We compute $Q$ for the current token only, and use it together with the cached $K$ and $V$ to get the output.

This changes generation from full-sequence re-computation to a lightweight, incremental update.
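To make the incremental update concrete, here is a minimal single-head sketch that continues the toy example from above, reusing `W_q`, `W_k`, `W_v`, the previously computed `K`/`V`, and `new_token_emb`. It is an illustration of the idea, not the nanoVLM code.

# Cache the prompt's keys and values once
K_cache, V_cache = K, V                       # shape (5, dim_model) each

# Project only the new token ...
q_new = new_token_emb @ W_q                   # shape (1, dim_model)
k_new = new_token_emb @ W_k
v_new = new_token_emb @ W_v

# ... and append its key/value to the cache instead of recomputing everything
K_cache = torch.cat([K_cache, k_new], dim=0)  # shape (6, dim_model)
V_cache = torch.cat([V_cache, v_new], dim=0)

# The new token's query attends over all cached keys/values.
# No causal mask is needed here: everything in the cache lies in the past.
scores = (q_new @ K_cache.T) / math.sqrt(d_k)   # shape (1, 6)
weights = F.softmax(scores, dim=-1)
out_new = weights @ V_cache                     # shape (1, dim_model)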

✅ In practice, this cache is a per-layer dictionary with keys “key” and “value”, each of shape (batch_size, num_heads, seq_len_cached, head_dim).

This is the foundation of how modern LLMs can generate long outputs efficiently.



KV Caching in nanoVLM: From Theory to Practice

Now that we understand the theory behind KV Caching, let’s see how it’s implemented in practice in our nanoVLM repository. It’s an ideal testbed, as the codebase is concise and self-contained.

KV caching is enabled across three key components in our model:

  1. The Attention block that uses and updates the KV cache
  2. The Language model that tracks cache per layer
  3. The Generation loop that separates prefill (the initial pass with the input prompt) and sequential decode phases



1. Updating the KV Cache in the Attention Block

In the LanguageModelGroupedAttention class, we modify the forward function to accept and update a cache of keys and values (block_kv_cache).

Previously, the model recomputed $K$ and $V$ at every generation step. Now we only compute $K_{\text{new}}$ and $V_{\text{new}}$ for the incoming token(s) and append them to the cache:

def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
    is_prefill = block_kv_cache is None
    B, T_curr, C = x.size()

    # Project only the current input tokens to queries, keys and values
    q_curr, k_curr, v_curr = project_current_tokens(x)
    q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

    if not is_prefill and block_kv_cache['key'] is not None:
        # Decode phase: append the new keys/values to the cached ones
        k = torch.cat([block_kv_cache['key'], k_rotated], dim=2)
        v = torch.cat([block_kv_cache['value'], v_curr], dim=2)
    else:
        # Prefill phase: the cache starts out as the full prompt's keys/values
        k, v = k_rotated, v_curr

    block_kv_cache = {'key': k, 'value': v}
    return attention_output, block_kv_cache



2. Tracking Cache Across Layers

In the LanguageModel class, we introduce layer-wise cache tracking. The start_pos argument helps the model compute correct rotary positional encodings for newly generated tokens.

def forward(self, x, kv_cache=None, start_pos=0):
    T_curr = x.size(1)
    position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device)
    cos, sin = self.rotary_embd(position_ids)

    for i, block in enumerate(self.blocks):
        # Each block receives and returns its own cache entry
        x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])

    return x, kv_cache

  • kv_cache: A list of dictionaries, one per transformer layer, holding the previous keys and values.
  • start_pos: Ensures that rotary embeddings are aligned with the current generation index.
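For completeness, here is one way the per-layer cache could be initialised on the first (prefill) call, as a minimal sketch of the list-of-dictionaries layout described above (not necessarily the exact nanoVLM code):

# Inside LanguageModel.forward, before the loop over blocks:
if kv_cache is None:
    # One cache slot per transformer block; each slot will later hold
    # a {'key': ..., 'value': ...} dictionary produced by that block.
    kv_cache = [None] * len(self.blocks)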



3. Prefill vs Decode in the Generation Loop

The biggest architectural change is in the generate() method of the VisionLanguageModel.

We split generation into two stages:

  • PREFILL PHASE: Encode the full prompt and build the initial cache.
  • DECODE PHASE: Generate tokens one by one using the cached keys/values.

PREFILL PHASE (cache construction)
[Prompt: "What is"] → [Transformer] → [Cache: K, V for all layers]

DECODE PHASE (token-by-token)
[Token: "the"] → [Q("the") + cached K/V] → [next token: "?"] → ...

Here’s the corresponding code:


# PREFILL: process the full prompt once and build the cache
prompt_output, kv_cache_list = self.forward(
    inputs,
    kv_cache=None,
    start_pos=0
)

# DECODE: generate one token at a time, reusing the cache
for i in range(max_new_tokens):
    next_token = sample_from(prompt_output)

    decode_output, kv_cache_list = self.forward(
        next_token,
        kv_cache=kv_cache_list,
        start_pos=current_position  # position of the new token in the full sequence
    )

    prompt_output = decode_output

By separating these phases, we avoid redundant computation and dramatically speed up inference, especially for long prompts.



Summary of Changes

| Module | Original Behaviour | New Behaviour |
| --- | --- | --- |
| LanguageModelGroupedAttention.forward | Recomputes $Q$, $K$, $V$ on every step | Uses and updates the KV cache |
| LanguageModel.forward | No memory of previous state | Tracks per-layer KV cache, handles start_pos |
| VisionLanguageModel.generate | One-phase generation loop | Split into prefill and decode phases |



Summary: Why KV Caching Matters

| Benefit | Explanation |
| --- | --- |
| Incremental growth | Cache grows by one row per new token |
| Position-aware decoding | start_pos ensures the correctness of positional encoding calculations |
| Efficiency | Reduces per-token inference to linear in sequence length instead of quadratic |

KV caching eliminates unnecessary computation during autoregressive generation, enabling faster and more efficient inference, especially for long sequences and real-time applications. It is a trade-off between speed and memory, and its drawbacks can be more complex code and restrictions on fancier inference schemes like beam search. KV caching is a popular method for speeding up LLM inference, making it possible to run LLMs on consumer hardware, and now you know how it works too!


