Continuous batching from first principles


TL;DR: in this blog post, starting from attention mechanisms and KV caching, we derive continuous batching by optimizing for throughput.

If you’ve ever used Qwen, Claude, or another AI chatbot, you’ve probably noticed something: it takes some time for the first word of the response to appear, and then words show up one by one on your screen at (hopefully) a regular and fast-paced frequency. That’s because, at the heart of it, all LLMs are just fancy next-token predictors. An LLM first processes your entire prompt to produce one new token. Then it keeps adding tokens one by one, each time reading everything that came before, until it decides generation is over.

This generation process is computationally expensive: it requires passing the input through billions of parameters for each token generated. To make these models practical for real-world applications, particularly when serving many users concurrently, researchers and engineers have developed a range of efficient inference techniques.
One of the most impactful optimizations is continuous batching, which maximizes throughput by processing multiple conversations in parallel and swapping them out as they finish.

To understand how continuous batching works and why it is so effective in high-load serving scenarios, we’ll build up from the basics of how LLMs process tokens.



Attention

The attention mechanism is the central piece of how LLMs work. A language model processes text by breaking it down into pieces that we call tokens. We can conceptually think of “tokens” as “words”, but sometimes a word may be composed of several tokens. For each token sequence, the network computes a prediction of what the next token should be.

Many operations in the network are token-wise: each token is processed independently, and the output for a given token depends only on that token’s content, not on any other tokens in the sequence. Operations like this include layer normalization and matrix multiplication. However, to create connections between words in a sentence, we need operations where tokens can influence one another.

This is where attention comes in. Attention layers are the only place where different tokens interact with one another. Understanding how a network connects tokens together means understanding attention.

Let’s look at how this works in practice, in the case where there is a single input prompt.

Consider the initial prompt “I am sure this project”, tokenized as 7 tokens: [&lt;BOS&gt;, I, am, sure, this, pro, ject]. The &lt;BOS&gt;, or “Beginning of Sequence”, is a special token we add at the start of the prompt to tell the language model that a new conversation starts here.

Each token is represented inside the network by a vector of length $d$ (the hidden dimension). Therefore, the seven incoming tokens form a tensor $x$ of shape $[1, 7, d]$, where the leading 1 is the batch dimension.

The input tensor $x$ is then projected by three matrices: the query projection $W_q$, the key projection $W_k$ and the value projection $W_v$, each of shape $[d, A]$, where $A$ is the head dimension. This produces the query, key and value tensors $Q$, $K$ and $V$, each of shape $[1, n, A]$, with $n = 7$ tokens here.

proj_and_mul.png

Next, tensors $Q$ and $K$ are multiplied together to measure similarity between tokens, producing a tensor $QK^T$ of shape $[1, n, n]$.

We then apply a boolean attention mask to $QK^T$, so that each token can only attend to the tokens that come before it:

masking_and_softmax.png

Finally, after applying the attention mask, we take a token-wise softmax (which is the same as saying a row-wise softmax) and multiply the result by the value tensor $V$ to get the output of one attention head, of shape $[1, n, A]$.

attention.png
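
To make the steps above concrete, here is a minimal sketch of one attention head during prefill, in PyTorch. The dimensions ($d = 16$, $A = 8$) and the random weights are made up for illustration, and the $1/\sqrt{A}$ scaling is the standard scaled dot-product form, even though it isn’t drawn in the figures:

```python
import torch

n, d, A = 7, 16, 8                        # 7 prompt tokens, hidden dim d, head dim A
x = torch.randn(1, n, d)                  # input tensor of shape [1, n, d]

W_q, W_k, W_v = (torch.randn(d, A) for _ in range(3))    # query, key, value projections

Q, K, V = x @ W_q, x @ W_k, x @ W_v       # each of shape [1, n, A]

scores = Q @ K.transpose(-2, -1) / A**0.5                # [1, n, n] token-to-token similarity
causal_mask = torch.ones(n, n).tril().bool()             # True at (i, j) when j <= i
scores = scores.masked_fill(~causal_mask, float("-inf")) # future tokens cannot influence past ones
weights = torch.softmax(scores, dim=-1)                  # row-wise softmax
out = weights @ V                                        # [1, n, A]: output of one attention head
```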

We’re going to use a lot of attention visualizations in this post, so to simplify things, we’ll condense the figure above a little.

Why this matters: in continuous batching, $Q$, $K$ and $V$ can have different numbers of tokens because, as we’ll see, we’ll be processing different stages (prefill and decode) at the same time. To make things more general, let’s say $Q$ has shape $[1, n_Q, A]$ while $K$ and $V$ have shape $[1, n_K, A]$.

The attention scores $QK^T$ then have shape $[1, n_Q, n_K]$, and the attention mask must have the same shape.

After applying the attention mask and the row-wise softmax, we multiply by $V$. Since we’re multiplying a matrix of shape $[1, n_Q, n_K]$ by one of shape $[1, n_K, A]$, the output has shape $[1, n_Q, A]$: one output per query token.
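
To make these shapes concrete, here is the same computation written as a function of an arbitrary boolean mask of shape $[n_Q, n_K]$, assuming the projections have already been applied (the values below are made up):

```python
import torch

def masked_attention(Q, K, V, mask):
    """Q: [1, n_Q, A]; K, V: [1, n_K, A]; mask: [n_Q, n_K] boolean.
    True at (i, j) means key token j may influence query token i."""
    A = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / A**0.5           # [1, n_Q, n_K]
    scores = scores.masked_fill(~mask, float("-inf"))   # white squares in the figures
    return torch.softmax(scores, dim=-1) @ V            # [1, n_Q, A]

# One decoding step: a single query token attending to 8 keys/values (n_Q=1, n_K=8).
Q = torch.randn(1, 1, 8)
K, V = torch.randn(1, 8, 8), torch.randn(1, 8, 8)
out = masked_attention(Q, K, V, torch.ones(1, 8, dtype=torch.bool))   # [1, 1, 8]
```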

Moreover, since we know that the attention mask is applied to $QK^T$, we can represent the whole attention computation with just the mask itself, labelling its rows with the query tokens and its columns with the key tokens:

simple_attention.png

This representation also shows how to read an attention mask.

We read the mask row by row, which is the same as reading it token by token: each row corresponds to one token’s attention computation. A green square at position (row i, column j) means True: token j can influence token i. A white square means False: no interaction allowed.

For instance, look at the third row, for the token “am“. The “I” column is green, so “I” influences the computation of “am“. The “pro” column is white, so “pro” doesn’t influence “am“. This is causal masking at work: future tokens cannot affect past ones.

The last layer of the model outputs a token prediction for every input token. In our context, generating the continuation of a single prompt, we only care about the next-token prediction from the last token. The last token is “ject” in the figure above, and the associated prediction is “will“.

The process we just described, where we take an entire input sequence, pass it through multiple attention layers and compute a score for the next token, is called prefill. That’s because, as we’ll see in a moment, much of the computation we performed can be cached and reused – hence, we’re prefilling the cache. Thanks to this cache, sequence generation can proceed using much less compute in a phase called decoding. In the decoding phase, generating one new token will be much faster than the initial full-sequence computation. Let’s see why.

To continue generation, we start a new forward pass, which could naively look like this:

naive_generate.png

To compute the attention scores of the new token, we still need the key and value projections of the previous tokens. So we would have to repeat the matrix multiplications of the old tokens (in grey in the figure above) with $W_k$ and $W_v$, even though we already computed them in the previous forward pass.



KV-cache

Right off the bat, we notice that the last token doesn’t impact the attention calculation of the other tokens:

cant_see_me.png

This follows from the causal mask: since “will” comes after all previous tokens, it doesn’t change their attention calculation.
For text generation, causal attention is by far the most common, so we’ll focus on that case from now on. Keep in mind that non-causal attention schemes can also be used, especially when dealing with images.
Since we only need the next-token prediction from the “will” token, we can simplify the attention mechanism by only computing the output for this token.

Furthermore, we already computed the $K$ and $V$ states for the tokens “&lt;BOS&gt;“, … , “ject” during the previous forward pass: if they have been stored, we don’t need to recompute them. This is the KV cache: the list of key and value states created during generation. It essentially reduces the compute needed to generate token $n+1$: instead of recomputing the key and value projections of all previous tokens, we only compute them for the newest one, trading compute for memory.

kv_cache.png

In the figure above, only the tokens in white are computed: instead of computing the keys and values for 8 tokens, we compute them for just 1. You can see that a lot of compute is saved through KV caching.
You can check this post for more visualizations of KV caching, or this one for a practical implementation example.
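
Here is a minimal single-head sketch of this idea (dimensions made up as before): the prompt’s key and value states are computed once during prefill, then only extended one token at a time during decoding.

```python
import torch

d, A = 16, 8
W_q, W_k, W_v = (torch.randn(d, A) for _ in range(3))

# Prefill: process the 7 prompt tokens once and keep their key/value states around.
prompt = torch.randn(1, 7, d)
k_cache, v_cache = prompt @ W_k, prompt @ W_v            # each [1, 7, A]

# Decode: only the newest token ("will" in our example) goes through the projections.
new_token = torch.randn(1, 1, d)
q = new_token @ W_q                                      # [1, 1, A]
k_cache = torch.cat([k_cache, new_token @ W_k], dim=1)   # now [1, 8, A]
v_cache = torch.cat([v_cache, new_token @ W_v], dim=1)   # now [1, 8, A]

# A single query attending to everything before it needs no mask at all.
scores = q @ k_cache.transpose(-2, -1) / A**0.5          # [1, 1, 8]: one row of attention
out = torch.softmax(scores, dim=-1) @ v_cache            # [1, 1, A], used to predict the next token
```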

Let’s be a bit more specific about the cache size, since it’s a good opportunity to look at the shapes present in our model. For a model with $\mathcal{L}$ attention layers and $H$ attention heads of head dimension $A$, the total cache needed to store one token is $2 \cdot \mathcal{L} \cdot A \cdot H$ values: one key vector and one value vector of size $A$ per head and per layer.
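
As a back-of-the-envelope example, here is the formula applied to a hypothetical configuration stored in fp16 (these numbers are made up and don’t correspond to any particular model):

```python
L, H, A = 32, 8, 128              # hypothetical: layers, heads, head dimension
bytes_per_value = 2               # fp16 storage

per_token = 2 * L * H * A * bytes_per_value    # keys + values, across all layers and heads
print(per_token)                  # 131072 bytes = 128 KiB of cache per token
print(per_token * 4096 / 2**30)   # 0.5 GiB for a 4096-token context
```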

KV caching is useful when we want to generate the next token, which is the stage we call decoding. But it can also be useful in the prefill stage, when we process the initial prompt and have many input tokens – especially when large initial prompts don’t fit in GPU memory all at once.



Chunked prefill

Up until now, we have looked at an example of prefill where all $n = 7$ prompt tokens are processed in a single forward pass.

Let’s pretend that the available memory is very constrained, and that we can only pass $m = 4$ tokens through the model at a time.

chunked prefill.png

We can do this thanks to the KV cache. We store the KV states during the first prefill chunk, and during the second chunk we prepend the stored KV states to the new ones. We also adapt the attention mask accordingly. Visually, it looks like we split the non-chunked prefill down the middle.

The key insight: cached KV states allow us to process the prompt incrementally without losing information.

Although we showed an example where the prefill is split into 2 chunks, chunked prefill can split the prompt any way we like, adapting flexibly to memory constraints.
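
Sticking with the single-head sketch from earlier (dimensions still made up), chunked prefill is just a loop over chunks of at most $m$ tokens that keeps extending the same KV cache and builds the matching slice of the causal mask:

```python
import torch

d, A, m = 16, 8, 4
W_q, W_k, W_v = (torch.randn(d, A) for _ in range(3))
prompt = torch.randn(1, 7, d)                        # the 7-token prompt

k_cache, v_cache = torch.empty(1, 0, A), torch.empty(1, 0, A)

for start in range(0, prompt.shape[1], m):           # chunks [0:4], then [4:7]
    chunk = prompt[:, start:start + m]               # [1, c, d] with c <= m
    q = chunk @ W_q
    k_cache = torch.cat([k_cache, chunk @ W_k], dim=1)
    v_cache = torch.cat([v_cache, chunk @ W_v], dim=1)

    c, n_k = q.shape[1], k_cache.shape[1]
    # Mask of shape [c, n_k]: chunk token i sees every cached token plus the
    # chunk tokens up to and including itself (the "split down the middle" mask).
    mask = torch.arange(n_k) <= (torch.arange(c) + n_k - c).unsqueeze(1)

    scores = (q @ k_cache.transpose(-2, -1) / A**0.5).masked_fill(~mask, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v_cache    # [1, c, A] for the current chunk
```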

We are now equipped with all the tools we need to understand continuous batching.



Continuous batching

In our previous examples, we have only considered the case of batch size one, i.e. we only generate tokens for one prompt at a time. In the context of evaluation or model serving, we want to generate tokens for many prompts. To increase throughput, which is the number of tokens generated per second, the best course of action is to generate tokens in parallel for a batch of several prompts.

To batch prompts together, the naive way is to add a batch axis to both input tensors: the token sequence and the attention mask. However, this comes with a constraint on the shape of the inputs: all prompts need to have the same length, because tensors must be rectangular. To achieve this, we usually add padding on the left so the new-token prediction always comes from the rightmost token. We also modify the attention mask of each prompt accordingly, as shown below:

padding.png

where the padding tokens are colored in orange. Then we can perform the forward pass as before, with the added batch dimension. This is called batched generation: efficient for same-length prompts, but wasteful when lengths vary.
It’s illustrated below through 4 steps of generation: one prefill step (at the top) and three decoding steps (one below each “Forward pass” line).

batched_generation.png

where &lt;EOS&gt; means “End Of Sequence”: it is a special token indicating that the model has reached the end of generation for the corresponding sequence.
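
For reference, here is a minimal sketch of the left-padding construction used above, for two prompts of different lengths (the token ids are made up, and 0 stands in for the padding token):

```python
import torch

prompts = [[11, 12, 13, 14, 15], [21, 22, 23]]            # lengths 5 and 3
max_len = max(len(p) for p in prompts)

input_ids = torch.tensor([[0] * (max_len - len(p)) + p for p in prompts])   # [2, 5], left-padded
is_real = input_ids != 0                                  # False on the orange padding positions

# Token i of a sequence attends to token j only if j <= i AND j is not padding.
causal = torch.ones(max_len, max_len).tril().bool()
attn_mask = causal.unsqueeze(0) & is_real.unsqueeze(1)    # [2, max_len, max_len]
```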

The drawback of batched generation is that if one prompt finishes before the others by generating an &lt;EOS&gt; token, all tokens generated for it afterwards are useless. And this goes on until the longest request in the batch finishes. Of course, we can remove prompts that have reached an &lt;EOS&gt; token from the batch to save some compute and memory, but saving resources is not the goal here: throughput is.

Instead of just removing the finished prompt from the batch, we can replace it with a prompt that is waiting for generation. We will call this dynamic scheduling, or dynamic batching. Dynamic scheduling is great for maintaining throughput while ensuring every token generated by a forward pass is relevant. But because of the way we batched prompts together, it has a major drawback: we need a lot of padding when swapping prompts. That’s because the newly-inserted prompt must go through prefill while the other prompts are decoding one token at a time. So there is almost as much padding as there are tokens in the newly-inserted prompt.

dynamic_batching.png

The problem becomes even worse as batch size increases and initial prompts get long: the padding cost grows with the product of batch size and prompt length. If we have a batch of $B$ prompts in the decoding phase and one finishes, dynamically introducing a prompt of $n$ initial tokens into the batch requires $(n-1)(B-1)$ padding tokens. For example, swapping in a 512-token prompt while 31 other sequences are decoding wastes $511 \times 31 \approx 16{,}000$ padding slots in that forward pass.

Moreover, practical optimizations like CUDA graphs or torch.compile require static tensor shapes. This forces us to pad all prompts to a fixed maximum length, dramatically increasing the padding waste.

At this point, our main problem is padding, which is a consequence of the axis we added to batch sequences together. So the ideal would be to get rid of this axis entirely, a radical rethinking of batching. If we do so, the only way to batch prompts together is to concatenate them:

concatenate.png

But we don’t want tokens from prompt 0 to interact with the tokens of prompt 1! Luckily for us, we have a way to control how tokens interact with each other: the attention mask. How we do this is shown below:

ragged_batching.png

Although we use different shades of green to illustrate different parts of the attention mask, this is still a boolean mask, with green for True and white for False.
This way of batching prompts together is called ragged batching (because the sequence lengths are ‘ragged’, or uneven), and it offers the benefit of added throughput without the need for padding tokens.

In the figure above, we use ragged batching to combine two full prompts, but we can batch as many as memory allows. The only limit is $m$, the number of tokens we can fit in a batch, with $m$ depending on the memory available on the GPU.
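
Here is a sketch of how such a ragged attention mask can be built for two concatenated prompts (lengths made up): it is simply the causal mask restricted to pairs of tokens that belong to the same prompt.

```python
import torch

lengths = [7, 4]                                     # prompt 0 and prompt 1
total = sum(lengths)                                 # 11 tokens, concatenated with no padding

# Which prompt does each position belong to? -> [0,0,0,0,0,0,0,1,1,1,1]
seq_ids = torch.repeat_interleave(torch.arange(len(lengths)), torch.tensor(lengths))

causal = torch.ones(total, total).tril().bool()
same_prompt = seq_ids.unsqueeze(0) == seq_ids.unsqueeze(1)   # block-diagonal structure
mask = causal & same_prompt                          # token j influences token i only if it comes
                                                     # earlier AND belongs to the same prompt
```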

Ragged batching is one of the key components of continuous batching. To maximize throughput, we can mix prefill and decoding sequences following an algorithm like this:

  • We always try to reach our budget of m m tokens per batch
  • We first add all prompts in the decoding phase to the batch, each accounting for 1 token
  • We fill the remaining space with prefill-phase prompts, relying on the flexibility of chunked prefill to split inputs as needed

Dynamic scheduling is the final piece of the continuous batching technique: we remove finished prompts from the batch as soon as they are done, and replace them with new chunked prompts corresponding to incoming requests.
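
Putting these scheduling rules together, here is a toy sketch of the batch-building step (the `Request` class and its field are invented for illustration; a real server tracks far more state, including each request’s KV cache):

```python
from dataclasses import dataclass

@dataclass
class Request:
    remaining_prompt_tokens: int = 0    # > 0 means the request is still in prefill

def build_batch(decoding, prefill_queue, m):
    """Return (request, num_tokens) pairs totalling at most m tokens."""
    batch = [(req, 1) for req in decoding]                 # decoding tokens go in first
    budget = m - len(batch)
    for req in prefill_queue:
        if budget <= 0:
            break
        chunk = min(req.remaining_prompt_tokens, budget)   # chunked prefill: take what fits
        batch.append((req, chunk))
        budget -= chunk
    return batch

# Example: 3 requests decoding, 2 waiting for prefill, token budget m = 8.
decoding = [Request() for _ in range(3)]
waiting = [Request(remaining_prompt_tokens=9), Request(remaining_prompt_tokens=2)]
print(build_batch(decoding, waiting, m=8))   # 3 decode tokens + a 5-token prefill chunk
```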

This combination of ragged batching and dynamic scheduling is called continuous batching, and it is the technique that powers modern LLM serving systems.

continuous_batching.png



Conclusion

Continuous batching combines three key techniques to maximize throughput in LLM serving:

  1. KV caching to avoid recomputing past token representations
  2. Chunked prefill to handle variable-length prompts within memory constraints
  3. Ragged batching with dynamic scheduling to eliminate padding waste and keep the GPU fully utilized

By removing the batch dimension and using attention masks to control token interactions, continuous batching allows mixing prefill and decode phases in the same batch, dramatically improving efficiency when serving multiple requests. This is why services like ChatGPT can handle thousands of concurrent users efficiently.

In the next article in this series, we’ll explore efficient KV cache management through paged attention. If you’d like to see a deep dive on other continuous batching topics, please let us know in the comments!

Acknowledgement: thanks to Arthur Zucker for producing the initial concept for the figures used in this article. And thanks to Arthur Zucker, Luc Georges, Lysandre Debut, Merve Noyan and Pedro Cuenca for providing helpful reviews.


