TL;DR: in this blog post, starting from attention mechanisms and KV caching, we derive continuous batching by optimizing for throughput.
If you've ever used Qwen, Claude, or another AI chatbot, you've probably noticed something: it takes some time for the first word of the response to appear, and then words appear one by one on your screen with (hopefully) a regular and fast-paced frequency. That is because, at the heart of it, all LLMs are just fancy next-token predictors. An LLM first processes your entire prompt to produce one new token. Then it keeps adding tokens one at a time, each time reading everything that came before, until it decides generation is over.
This generation process is computationally expensive: it requires passing the input through billions of parameters for each token generated. To make these models practical for real-world applications, particularly when serving many users concurrently, researchers and engineers have developed a range of efficient inference techniques.
One of the most impactful optimizations is continuous batching, which maximizes performance by processing multiple conversations in parallel and swapping them out as they finish.
To understand how continuous batching works and why it is so effective in high-load serving scenarios, we'll build up from the basics of how LLMs process tokens.
Attention
The attention mechanism is the central piece of how LLMs work. A language model processes text by breaking it down into pieces that we call tokens. We can conceptually think of "tokens" as "words", but sometimes a word might be composed of several tokens. For each token sequence, the network computes a prediction of what the next token should be.
Many operations in the network are token-wise: each token is processed independently, and the output for a given token depends only on that token's content, not on any other tokens in the sequence. Operations like this include layer normalization or matrix multiplication. However, to create connections between words in a sentence, we need operations where tokens can influence one another.
This is where attention comes in. Attention layers are the only place where different tokens interact with one another. Understanding how a network connects tokens together means understanding attention.
Let's look at how this works in practice, in the case where there is just one input prompt.
Consider the initial prompt "I am sure this project", tokenized as 7 tokens: [<BOS>, I, am, sure, this, pro, ject]. The <BOS>, or "Beginning of Sequence", is a special token we add at the start of the prompt to tell the language model that a new conversation starts here.
Each token is represented inside the network with a vector of length d (the hidden dimension). Therefore, the seven incoming tokens form a tensor with shape (1, 7, d). 1 is the number of sequences, or batch size, which is just one in our case. 7 is the sequence length, and d is the hidden dimension, or the size of each token representation. Going forward, we'll use seq_len instead of 7 as the sequence length.
The input tensor X is then projected by three matrices: the query projection W_Q, the key projection W_K and the value projection W_V. This produces three tensors Q, K and V, all of shape (1, seq_len, d_head), where d_head is the dimension of the attention head. We call them the query, key and value states, respectively. This is represented on the left in the figure below.
Next, tensors Q and K are multiplied together to measure similarity between tokens, producing the attention scores Q · K^T of shape (seq_len, seq_len). This is why we say that attention has quadratic complexity in sequence length: computing Q · K^T requires on the order of seq_len² operations, so the cost grows with the square of the sequence length. It's represented on the right in the figure above.
We then apply a boolean attention mask to the attention scores to control which tokens can interact, as represented in the figure below. In this figure, the attention mask is a causal mask, meaning each token only interacts with tokens that came before it. This follows the intuition that a cause must come before its consequence, hence the name causal mask. The attention mask is crucial because it dictates all token interactions in the network: set all attention mask values to False and no token will ever interact with another in the whole network. We'll examine attention masks more closely in a few paragraphs.
Finally, after applying the attention mask, we take a token-wise softmax (which is the same as saying a row-wise softmax) and multiply the result by the value states V to get the output of one attention head, of shape (1, seq_len, d_head). We provide a visual summary of the whole process in the following figure.
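To make this concrete, here is a minimal single-head attention sketch in PyTorch. The shapes (seq_len = 7, d = 16, d_head = 8) are illustrative placeholders, and we include the usual 1/√d_head scaling of the scores, which the description above glosses over:

```python
import torch

seq_len, d, d_head = 7, 16, 8
x = torch.randn(1, seq_len, d)                     # input tokens, shape (1, seq_len, d)

W_q = torch.randn(d, d_head)                       # query projection
W_k = torch.randn(d, d_head)                       # key projection
W_v = torch.randn(d, d_head)                       # value projection

q, k, v = x @ W_q, x @ W_k, x @ W_v                # query/key/value states, (1, seq_len, d_head)

scores = q @ k.transpose(-2, -1) / d_head**0.5     # (1, seq_len, seq_len): quadratic in seq_len
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, float("-inf"))  # future tokens cannot influence past ones

weights = torch.softmax(scores, dim=-1)            # row-wise softmax
out = weights @ v                                  # output of one attention head, (1, seq_len, d_head)
```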
We're going to use a lot of attention visualizations in this post, so to simplify things, we're going to condense the figure above a little bit.
Why this matters: in continuous batching, Q, K, and V can have different numbers of tokens because, as we'll see, we'll be processing different stages (prefill and decode) at the same time. To make it more general, let's say Q has shape (seq_len_q, d_head), K has shape (seq_len_k, d_head), and V has shape (seq_len_v, d_head).
The attention scores then have shape (seq_len_q, seq_len_k), and the attention mask has the same shape since it's applied point-wise to the scores.
After applying the attention mask and the row-wise softmax, we multiply the result by V. Since we're multiplying a matrix of shape (seq_len_q, seq_len_k) by one of shape (seq_len_v, d_head), the inner dimensions must match: seq_len_k = seq_len_v. This means K and V always have the same length, so we can simplify our visualizations by showing only K.
Don't worry if this seems abstract: the figures will make it concrete.
Moreover, since we know that the attention mask is applied to the attention scores Q · K^T, we know they have the same shape. Instead of representing the attention scores, we'll represent the attention mask in their place.
Finally, since Q, K and V are direct projections of the input X, there is no need to represent X. This gives the simplified figure where we only represent Q, K and the attention mask:
This representation also underlines how we can read an attention mask.
We read the mask row by row, which is the same as reading it token by token: each row corresponds to one token's attention computation. A green square at position (row i, column j) means True: token j can influence token i. A white square means False: no interaction allowed.
For instance, look at the third row, for the token "am". The "I" column is green, so "I" influences the computation of "am". The "pro" column is white, so "pro" doesn't influence "am". This is causal masking at work: future tokens cannot affect past ones.
The last layer of the model outputs a token prediction for each input token. In our context, generating the continuation of a single prompt, we only care about the next-token prediction from the last token. The last token is "ject" in the figure above, and the associated prediction is "will".
The process we just described, where we take an entire input sequence, pass it through multiple attention layers and compute a prediction for the next token, is called prefill. It gets this name because, as we'll see in a moment, much of the computation we performed can be cached and reused: we are prefilling the cache. Thanks to this cache, sequence generation can proceed using much less compute in a phase called decoding. In the decoding phase, generating one new token is much faster than the initial full-sequence computation. Let's see why.
To continue generation, we start a new forward pass, which might naively look like this:
To compute the attention scores of the new token, we still need the key and value projections of the previous tokens. So we need to repeat the matrix multiplication of the old tokens (in grey in the figure above) with W_K and W_V to retrieve a result that was already computed once before. In other terms, we're wasting compute. Let's see how we can avoid that.
KV-cache
Right off the bat, we notice that the last token doesn't impact the attention calculation of the other tokens:
This follows from the causal mask: since "will" comes after all previous tokens, it doesn't change their attention calculation.
For text generation, causal attention is by far the most common, so we'll focus on that case from now on. Keep in mind that non-causal attention schemes can also be used, especially when dealing with images.
Considering we only need the next-token prediction for the "will" token, we can simplify the attention mechanism by only computing the output for this token.
Furthermore, we already computed the key and value states for the tokens "<BOS> I am sure this pro ject" during the previous forward pass. If we store them in a cache, called the KV cache, we can reuse them and only compute the key and value states for the new token "will":
In the figure above, only the tokens in white are computed: instead of computing the keys and values for 8 tokens, we compute them for 1. You can see that through KV caching, a lot of compute is saved.
You can check this post for more visualizations of KV caching, or this one for a practical implementation example.
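To see how the cache fits into generation, here is a minimal sketch of prefill followed by one decoding step with a KV cache, again for a single attention head. The projection matrices and shapes are illustrative placeholders, not a real model:

```python
import torch

d, d_head = 16, 8
W_q, W_k, W_v = (torch.randn(d, d_head) for _ in range(3))

def prefill(x):
    """Process the full prompt once and return its key/value states."""
    return x @ W_k, x @ W_v                        # each of shape (seq_len, d_head)

def decode_step(x_new, k_cache, v_cache):
    """Compute attention for one new token, reusing cached keys/values."""
    q = x_new @ W_q                                # (1, d_head): only the new token's query
    k_cache = torch.cat([k_cache, x_new @ W_k])    # append the new key to the cache...
    v_cache = torch.cat([v_cache, x_new @ W_v])    # ...and the new value
    scores = q @ k_cache.T / d_head**0.5           # (1, cache_len): no mask needed, the new
                                                   # token may attend to everything before it
    out = torch.softmax(scores, dim=-1) @ v_cache  # (1, d_head)
    return out, k_cache, v_cache

prompt = torch.randn(7, d)                         # the 7 prompt tokens
k_cache, v_cache = prefill(prompt)
out, k_cache, v_cache = decode_step(torch.randn(1, d), k_cache, v_cache)  # one decoding step
```

Each decoding step only projects a single token through W_Q, W_K and W_V, while the cache grows by one entry per generated token.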
Let's be a bit more specific about the cache size, since it's a good opportunity to look at the shapes present in our model. For a model with n_layers attention layers and n_heads attention heads of head dimension d_head, the total cache size needed to store one token will be 2 × n_layers × n_heads × d_head values, with the factor of 2 there to account for both K and V.
As an example, Llama-2-7B with n_layers = 32 layers, n_heads = 32 heads, and d_head = 128 requires 2 × 32 × 128 = 8,192 values per token per layer. With float16 precision, storing one token across all 32 layers takes 2 × 32 × 8,192 = 524,288 bytes ≈ 512 KB in memory.
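As a quick sanity check of these numbers, here is the same arithmetic in a few lines of Python, assuming the standard Llama-2-7B dimensions:

```python
# KV-cache footprint per token for Llama-2-7B-like dimensions.
n_layers, n_heads, d_head = 32, 32, 128
bytes_per_value = 2                                   # float16

values_per_token_per_layer = 2 * n_heads * d_head     # factor 2: one key and one value state
bytes_per_token = n_layers * values_per_token_per_layer * bytes_per_value

print(values_per_token_per_layer)                     # 8192 values per token per layer
print(bytes_per_token / 1024)                         # 512.0 KB per token across all layers
print(bytes_per_token * 4096 / 2**30)                 # 2.0 GiB of cache for a 4096-token context
```

At roughly half a megabyte per token, it is easy to see why KV-cache memory, rather than compute, often becomes the limiting factor when serving many long sequences.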
KV caching is useful when we want to generate the next token, which is a stage we call decoding. But it can also be useful in the prefill stage, when we process the initial prompt and have many input tokens, especially when the initial prompt is too large to fit in GPU memory all at once.
Chunked prefill
Up till now, we have looked at an example of prefill with 7 tokens, but in practice initial prompts can be much longer. For instance, when using Cursor, you can add your repository to the prompt, where it acts as context: this significantly increases the prompt size. In such cases, the memory needed to store the activations for all the tokens can be larger than the available memory on the GPU. Thus we cannot perform prefill in a single forward pass: we have to split the prefill into chunks. This is called chunked prefill, and it is one of the components needed to enable efficient inference.
Let's pretend that the available memory is very constrained, and that we can only pass 4 tokens per forward pass. If we have an initial prompt with 7 tokens, we need to split it into 2 chunks (rounding up 7/4 = 1.75 to 2). We illustrate the example below using the same Q and K notations as before:
We can do this thanks to the KV cache. We store the KV states during the first prefill chunk, and during the second prefill chunk, we prepend the stored KV states to the new KV states. We also adapt the attention mask accordingly. Visually, it looks like we split the non-chunked prefill down the middle.
The key insight: cached KV states allow us to process the prompt incrementally without losing information.
Although we showed an example here where we split the prefill into 2 chunks, chunked prefill can be used to split the prefill in any way we want, adapting flexibly to memory constraints.
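Here is a minimal sketch of chunked prefill for our single attention head, reusing the KV-cache idea from above. The chunk size of 4 tokens and the projection matrices are illustrative placeholders:

```python
import torch

d, d_head, chunk_size = 16, 8, 4
W_q, W_k, W_v = (torch.randn(d, d_head) for _ in range(3))

def prefill_chunked(prompt):
    """Prefill the KV cache chunk by chunk instead of in one big forward pass."""
    k_cache = torch.empty(0, d_head)
    v_cache = torch.empty(0, d_head)
    for start in range(0, prompt.shape[0], chunk_size):
        chunk = prompt[start:start + chunk_size]           # at most chunk_size tokens
        q = chunk @ W_q                                    # queries for the new chunk only
        k_cache = torch.cat([k_cache, chunk @ W_k])        # append the chunk's keys to the cache
        v_cache = torch.cat([v_cache, chunk @ W_v])        # same for the values
        # causal mask: the new token at global position start + i sees every cached token
        # plus the tokens up to itself within the chunk
        n_new, n_total = chunk.shape[0], k_cache.shape[0]
        mask = torch.arange(n_total) <= (start + torch.arange(n_new)).unsqueeze(1)
        scores = q @ k_cache.T / d_head**0.5
        scores = scores.masked_fill(~mask, float("-inf"))
        out = torch.softmax(scores, dim=-1) @ v_cache      # (n_new, d_head), fed to the next layer
    return k_cache, v_cache

k_cache, v_cache = prefill_chunked(torch.randn(7, d))      # a 7-token prompt, processed 4 tokens at a time
```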
We are now finally equipped with all the tools we need to understand continuous batching.
Continuous batching
In our previous examples we have only considered the case of batch size one, i.e. we only generate tokens for one prompt at a time. In the context of evaluation or model serving, we want to generate tokens for many prompts. To increase throughput, which is the number of tokens generated per second, the best course of action is to generate tokens in parallel for a batch of several prompts.
To batch prompts together, the naive way is to add a batch axis to both input tensors: the token sequence and the attention mask. However, this comes with a constraint on the shape of the inputs: we need all prompts to have the same length, because tensors have to be rectangular. To achieve this, we usually add padding on the left so the new token prediction always comes from the rightmost token. We also modify the attention mask of each prompt accordingly, as shown below:
where the padding tokens are colored in orange. Then we can perform the forward pass as we used to, with the added dimension of the batch size. This is called batched generation: efficient for same-length prompts, but wasteful when lengths vary.
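In code, the left padding and the per-prompt attention masks could look like this minimal sketch, where the token ids and the pad id of 0 are made up for illustration:

```python
import torch

prompt_a = torch.tensor([11, 12, 13, 14, 15])        # 5 tokens
prompt_b = torch.tensor([21, 22, 23])                 # 3 tokens
max_len = max(len(prompt_a), len(prompt_b))

def left_pad(prompt, pad_id=0):
    """Pad on the left so the rightmost token is always a real one."""
    return torch.cat([torch.full((max_len - len(prompt),), pad_id), prompt])

input_ids = torch.stack([left_pad(prompt_a), left_pad(prompt_b)])    # (2, max_len)
is_real = input_ids != 0                                             # False on padding positions

# Per-sequence mask: causal AND both tokens are real (padding never interacts).
causal = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
attn_mask = causal & is_real[:, None, :] & is_real[:, :, None]       # (2, max_len, max_len)
```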
Batched generation is illustrated below, through 4 steps of generation: one prefill step (at the top) and three decoding steps (below each "Forward pass" line).
Here <EOS> means "End Of Sequence": it is a special token indicating that the model has reached the end of generation for the corresponding sequence.
The disadvantage of batched generation is that if one prompt finishes generation before the other by producing an <EOS> token, all of its further generated tokens are useless. And this goes on until the longest request of the batch finishes. Of course, we can remove the prompts that have reached an <EOS> token from the batch to save some compute and memory, but saving resources is not the goal here: throughput is.
Instead of just removing the finished prompt from the batch, we can replace it with a prompt that is waiting for generation. We will call this dynamic scheduling, or dynamic batching. Dynamic scheduling is great for maintaining throughput while ensuring any token generated by a forward pass is relevant. But because of the way we batched prompts together, it has a major drawback: we need a lot of padding when swapping prompts. That is because the newly-inserted prompt must go through prefill while the other prompts are decoding one token at a time, so there is almost as much padding as there are tokens in the newly-inserted prompt.
The issue becomes even worse as the batch size increases and the initial prompts get longer, since the padding cost grows with the product of the two. If we have a batch of N prompts that are in the decoding phase and one finishes, dynamically introducing a prompt of p initial tokens into the batch requires on the order of (N - 1) × (p - 1) padding tokens, because every decoding sequence must be padded to the length of the new prompt. With a few dozen decoding prompts and a new prompt of a few hundred tokens, that is already thousands of padding tokens!
Moreover, practical optimizations like CUDA graphs or torch.compile require static tensor shapes. This forces us to pad all prompts to a fixed maximum length, dramatically increasing the padding waste.
At this point, our main problem is padding, which is a consequence of the batch axis we added to batch sequences together. Thus, the best thing would be to get rid of this axis entirely, a radical rethinking of batching. If we do so, the only way to batch prompts together is to concatenate them along the sequence axis:
But we don't want tokens from prompt 0 to interact with the tokens of prompt 1! Luckily for us, we have a way to control how tokens interact with each other: the attention mask. How we do this is displayed below:
Although we use different tints of green to illustrate different parts of the attention mask, this is still a boolean mask with green for True and white for False.
This way of batching prompts together is called ragged batching (because sequence lengths are 'ragged' or uneven), and it offers the advantage of added throughput without introducing the need for padding tokens.
In the figure above, we use ragged batching to combine two full prompts, but we can batch as many as memory allows. The only limit is the total number of tokens we can fit in a batch, which depends on the available memory of the GPU.
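Here is a minimal sketch of how such a ragged attention mask could be built for two concatenated prompts. The lengths are illustrative, and a real implementation would typically use a fused variable-length attention kernel instead of materializing the full mask:

```python
import torch

lengths = [5, 3]                                    # tokens per prompt in the batch
total = sum(lengths)

# seq_ids[i] tells which prompt token i belongs to: [0, 0, 0, 0, 0, 1, 1, 1]
seq_ids = torch.repeat_interleave(torch.arange(len(lengths)), torch.tensor(lengths))

causal = torch.tril(torch.ones(total, total, dtype=torch.bool))
same_prompt = seq_ids[:, None] == seq_ids[None, :]  # block-diagonal structure
attn_mask = causal & same_prompt                    # (total, total), without a single padding token
```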
Ragged batching is one of the key components of continuous batching. To maximize throughput, we can mix prefill and decoding sequences following an algorithm like this (a rough sketch of the scheduling loop follows the list):
- We try to always reach our memory budget, i.e. the maximum number of tokens per batch
- We first add all the prompts in the decoding phase to the batch, each accounting for 1 token
- We fill the remaining space with prefill-phase prompts, relying on the flexibility of chunked prefill to split inputs as needed
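Here is a rough sketch of that scheduling loop in Python. The Request class and the token budget are hypothetical simplifications; a real server (vLLM, TGI, and others) tracks far more state, such as KV-cache blocks, priorities and maximum lengths:

```python
from dataclasses import dataclass, field

TOKEN_BUDGET = 256                       # max tokens per forward pass, set by GPU memory

@dataclass
class Request:
    prompt_tokens: list                  # prompt tokens still waiting to be prefilled
    generated: list = field(default_factory=list)
    done: bool = False

    @property
    def decoding(self):
        return not self.prompt_tokens and not self.done

def build_batch(active, waiting):
    """Assemble one ragged batch: all decoding requests first, then prefill chunks."""
    batch, budget = [], TOKEN_BUDGET
    # 1) every decoding request contributes exactly one token
    for req in active:
        if req.decoding and budget > 0:
            batch.append((req, 1))
            budget -= 1
    # 2) fill the remaining budget with (possibly chunked) prefill tokens
    for req in active + waiting:
        if req.prompt_tokens and budget > 0:
            chunk = min(len(req.prompt_tokens), budget)
            batch.append((req, chunk))
            budget -= chunk
    return batch                         # list of (request, number of tokens in this forward pass)
```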
Dynamic scheduling is the final piece that contributes to the continuous batching technique: we remove finished prompts from the batch as soon as they are done, and replace them with new chunked prompts that correspond to incoming requests.
This combination of ragged batching and dynamic scheduling is called continuous batching, and it is the technique that powers modern LLM serving systems.
Conclusion
Continuous batching combines three key techniques to maximize throughput in LLM serving:
- KV caching to avoid recomputing past token representations
- Chunked prefill to handle variable-length prompts within memory constraints
- Ragged batching with dynamic scheduling to eliminate padding waste and keep the GPU fully utilized
By removing the batch dimension and using attention masks to control token interactions, continuous batching allows mixing prefill and decode phases in the same batch, dramatically improving efficiency for serving multiple requests. This is why services like ChatGPT can handle thousands of concurrent users efficiently.
In the next article in this series, we'll explore efficient KV cache management through paged attention. If you'd like to see a deep dive on other continuous batching topics, please let us know in the comments!
Acknowledgement: thanks to Arthur Zucker for producing the initial concept for the figures used in this article. And thanks to Arthur Zucker, Luc Georges, Lysandre Debut, Merve Noyan and Pedro Cuenca for providing helpful reviews.















