Optimizing Token Generation in PyTorch Decoder Models


Among the generative AI models that have pervaded nearly every facet of our daily lives are autoregressive decoder models. These models apply compute-heavy kernel operations to churn out tokens one after the other in a way that, at first glance, seems extremely inefficient. Given the large demand for generative AI, it is no surprise that extraordinary engineering effort is being invested into its optimization. Whether it’s through custom CUDA kernels, CUDA Graphs, dedicated AI accelerators, or speculative sampling — any technique that reduces latency and/or cost by even a fraction of a percent is a win.

In this post, we demonstrate a method for optimizing token generation in PyTorch using CUDA stream interleaving. While easy to implement, the technique addresses a specific, often overlooked bottleneck and can lead to meaningful performance boosts. While pipelining model execution using CUDA streams is common in AI systems engineering, we didn’t find any tutorial documenting the precise PyTorch-level application we describe here. If you find the technique useful, please be so kind as to reference this post.

To facilitate our discussion, we’ll use a straightforward GPT-2 PyTorch decoder model from HuggingFace’s transformers (v5.1.0) library. We will run our experiments on an NVIDIA L40S GPU with PyTorch (2.10.0).

Disclaimer: The code we’ll share is intended for demonstration purposes. Please don’t rely on its accuracy or optimality. Please don’t interpret our mention of any library, platform, or service as an endorsement of its use.

Importantly, the value of the CUDA stream-based method we’ll discuss can vary greatly based on the details of your model and runtime environment. Please be sure to run your own benchmarks before integrating it.

Our focus in this post is on PyTorch-native inference workloads, which remain extremely prevalent in development and test settings. However, it is important to note that for production environments, dedicated LLM inference libraries such as vLLM or NVIDIA TensorRT-LLM are likely to deliver greater performance and should be used whenever relevant.

A Toy GPT-2 Model

To simplify our discussion, we’ll use a GPT-2 decoder model from the HuggingFace transformers library and have it run autoregressively on a batch of empty prompts.

In the following code block, we initialize the model and define a naive token generation function that generates a batch of token sequences up to a given length.

import torch
from transformers import GPT2LMHeadModel, GPT2Config

torch.set_float32_matmul_precision('high')

DEVICE = "cuda"

# define the decoder model
config = GPT2Config.from_pretrained("gpt2")
model = GPT2LMHeadModel(config).to(DEVICE).eval()


@torch.inference_mode()
def generate_sequence(model, max_seqlen, batch_size):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)
    
    for i in range(max_seqlen):
        outputs = model(all_tokens)
        # extract the new token
        logits = outputs.logits[:, -1, :]
        new_tokens = torch.argmax(logits, dim=-1)
        # append new token to sequence
        all_tokens = torch.cat(
            [all_tokens, new_tokens.unsqueeze(-1)],
            dim=-1
        )
        finished |= (new_tokens == config.eos_token_id)
        stop_gpu = torch.all(finished)
        
        # checking stop condition
        if stop_gpu.item():
            print(f"All sequences finished at step {i+1}")
            break
    
    return all_tokens

Next, we define a straightforward benchmarking function which we use to measure the runtime performance and memory utilization of our token generator in several scenarios.

import time, statistics


def benchmark(func, num_runs=10):
    # Warmup
    func()
    torch.cuda.synchronize()
    
    runtimes = []
    
    for _ in range(num_runs):
        # reset memory stats before each run
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        
        start = time.perf_counter()
        _ = func()
        torch.cuda.synchronize()
        end = time.perf_counter()
        
        runtimes.append(end - start)
    
    # Get memory allocator stats from last run
    mem_stats = torch.cuda.memory_stats()
    allocated_peak = mem_stats.get('allocated_bytes.all.peak', 0)
    reserved_peak = mem_stats.get('reserved_bytes.all.peak', 0)
    f_peak = reserved_peak - allocated_peak
    f_pct = (
        100 * f_peak / reserved_peak
        if reserved_peak > 0 else 0
    )
    
    print(f"\n{'='*60}")
    print(f"Runtime Results:")
    print(f" Mean:               {statistics.mean(runtimes):.4f}s")
    print(f" Std:                {statistics.stdev(runtimes):.4f}s")
    print(f" Min:                {min(runtimes):.4f}s")
    print(f" Max:                {max(runtimes):.4f}s")

    print(f"\nMemory Stats:")
    print(f" Allocated bytes (peak): {allocated_peak / 1e9:.3f} GB")
    print(f" Reserved bytes (peak):  {reserved_peak / 1e9:.3f} GB")
    print(f" Fragmentation (peak):   {f_peak / 1e9:.3f} GB ({f_pct:.1f}%)")
    print(f"{'='*60}\n")


batch_size = 32
for max_seqlen in [100, 200, 400]:
    print(
        f"Benchmarking generation with batch size {batch_size} "
        f"and max sequence length {max_seqlen}..."
    )
    benchmark(
        lambda: generate_sequence(
            model, max_seqlen=max_seqlen, batch_size=batch_size
        )
    )

In the table below we capture the results for a batch size of 32 and several different sequence lengths:

Baseline Results (By Creator)

As the sequence length doubles, the runtime quadruples — appearing to follow a classic quadratic, O(n²), scaling pattern. Moreover, high memory fragmentation points to severe strain on the CUDA memory allocator, which can lead to frequent memory faults and degrade runtime performance. The fragmentation results from each step requesting slightly larger tensor allocations, a pattern which ends up leaving multiple pockets of unusable memory.
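To see why doubling the length quadruples the runtime, consider a simple back-of-the-envelope cost model (framework-free and purely illustrative, not a measurement): without caching, generation step i re-processes all i tokens already in the sequence.

```python
# Illustrative cost model: without KV caching, step i re-processes
# all i tokens already in the sequence.
def naive_total_work(n):
    # one unit of work per token processed at each generation step
    return sum(i for i in range(1, n + 1))

ratio = naive_total_work(200) / naive_total_work(100)
print(ratio)  # ~3.98, i.e. close to the 4x slowdown observed when length doubles
```

The total work grows as roughly n²/2, matching the observed quadrupling.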

Our first optimization, KV caching, addresses the runtime complexity of our decoder model.

KV Caching

Our naive generator is extremely inefficient — rather than storing and reusing the intermediate tensors from previous tokens, it recalculates the entire sequence at every step.

We address the computational inefficiency by using KV caching: we store and reuse the intermediate key and value tensors for previous tokens. KV caching reduces the runtime complexity of token generation from O(n²) to O(n).
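To make the caching pattern concrete, here is a framework-free sketch of a toy single-head attention loop with a KV cache. The names attend, generate, and project are illustrative, not part of any library: each step computes Q/K/V for the new token only and appends to the cache, so per-step work grows with the cache length rather than re-running the whole network over the sequence.

```python
import math

# Toy single-head attention over plain Python lists, illustrating the
# KV-caching pattern; all names here are illustrative.
def attend(q, K, V):
    # one query vector against all cached key/value vectors
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, V)) / z for d in range(len(q))]

def generate(project, x0, steps):
    K, V, x, outs = [], [], x0, []
    for _ in range(steps):
        q, k, v = project(x)       # compute Q/K/V for the new token only
        K.append(k); V.append(v)   # extend the cache instead of recomputing
        x = attend(q, K, V)        # attend over all cached tokens
        outs.append(x)
    return outs

# With a single cached key, attention returns that key's value exactly
print(attend([1.0], [[1.0]], [[2.0]]))  # [2.0]
```

The point of the sketch is the shape of the loop, not the arithmetic: previous tokens contribute only through their cached K/V entries.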

In the following code block, we use the transformers library’s built-in support for KV caching to reprogram our token generation function to compute a single new token per sequence in each step.

@torch.inference_mode()
def generate_sequence(model, max_seqlen, batch_size, use_cache=False):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)

    # past_key_values is used to store the cached key/values for each layer
    past_key_values = None

    for i in range(max_seqlen):
        current_input = (
            all_tokens if past_key_values is None
            else all_tokens[:, -1:]
        )
        outputs = model(
            current_input,
            past_key_values=past_key_values,
            use_cache=use_cache
        )
        # update cache for next step
        past_key_values = outputs.past_key_values
        logits = outputs.logits[:, -1, :]
        new_tokens = torch.argmax(logits, dim=-1)
        # append new token to sequence
        all_tokens = torch.cat(
            [all_tokens, new_tokens.unsqueeze(-1)],
            dim=-1
        )
        finished |= (new_tokens == config.eos_token_id)
        stop_gpu = torch.all(finished)
        
        # checking stop condition
        if stop_gpu.item():
            print(f"All sequences finished at step {i+1}")
            break
    
    return all_tokens

The resulting performance numbers are captured in the following table:

Token Generation With KV Caching (By Creator)

The performance improvement is profound and, as expected, increases as a function of the sequence length.

Although somewhat higher than in our baseline experiment, the degree of memory fragmentation remains a concern. To address this, we explore two methods: expandable memory allocations and static KV caching.

Expandable CUDA Memory Allocations

To reduce CUDA memory fragmentation, we configure PyTorch to use expandable memory segments. As of this writing, this memory optimization is an experimental feature and should be used with caution. Please see the PyTorch documentation for details. To use the feature, we set the following environment variable:

export PYTORCH_ALLOC_CONF="expandable_segments:True"
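The setting can also be applied from Python, with the caveat (an assumption based on how the allocator reads its configuration) that it must happen before PyTorch initializes CUDA:

```python
import os

# Must run before torch initializes CUDA for the setting to take effect;
# older PyTorch releases read PYTORCH_CUDA_ALLOC_CONF instead.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
```

Setting it in the shell, as above, avoids any ordering concerns.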

Rerunning our benchmark produces the results in the following table:

KV Caching With Expandable Memory Segments (By Creator)

Not only do we see a marked improvement in fragmentation, but we also get an additional (marginal) improvement in runtime performance.

KV Caching With StaticCache

The default cache in HuggingFace is dynamic — it grows as the number of keys and values increases while generation progresses. HuggingFace also supports a fixed-size cache, StaticCache, which pre-allocates a maximum cache size for the KV pairs and reduces strain on the CUDA memory allocator. The drawback of using StaticCache is that the full length of the cache participates in the attention computation at each token generation step, with irrelevant tokens masked out. This results in wasted computation that grows with the sequence length. For example, when generating a sequence of 400 tokens, the attention computation for each token will run on full 400-length tensors.
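The extra work can be quantified with a simple count (illustrative only, ignoring non-attention compute): with a dynamic cache, the attention at step i visits i+1 key positions, while with a static cache of size max_seqlen it visits all max_seqlen positions at every step.

```python
# Count key positions touched by attention across all generation steps
def attention_positions_visited(n, static):
    # dynamic cache: step i attends over i+1 keys; static cache: always n
    return sum((n if static else i + 1) for i in range(n))

n = 400
ratio = attention_positions_visited(n, True) / attention_positions_visited(n, False)
print(ratio)  # just under 2: the static cache roughly doubles attention work
```

For long sequences, the static cache approaches twice the attention work of a dynamic cache, which is the trade-off against its allocator friendliness.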

In the code block below we enhance our sequence generator to support the use of a StaticCache:

from transformers import StaticCache

@torch.inference_mode()
def generate_sequence(
    model, max_seqlen, batch_size, use_cache=False, use_static_cache=False
):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)
    
    # Initialize static cache if requested
    if use_cache and use_static_cache:
        past_key_values = StaticCache(
            config=config,
            max_batch_size=batch_size,
            max_cache_len=max_seqlen,
            device=DEVICE,
            dtype=model.dtype
        )
    else:
        past_key_values = None
    
    # Initialize cache position tracking for static cache
    cache_positions = torch.arange(max_seqlen, device=DEVICE)
    
    for i in range(max_seqlen):
        current_input = (
            all_tokens if past_key_values is None
            else all_tokens[:, -1:]
        )
        cache_position = (
            cache_positions[i:i+1] if use_static_cache else None
        )
        outputs = model(
            current_input,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=use_cache
        )
        # update cache for next step
        past_key_values = outputs.past_key_values
        logits = outputs.logits[:, -1, :]
        new_tokens = torch.argmax(logits, dim=-1)
        # append new token to sequence
        all_tokens = torch.cat(
            [all_tokens, new_tokens.unsqueeze(-1)],
            dim=-1
        )
        finished |= (new_tokens == config.eos_token_id)
        stop_gpu = torch.all(finished)
        
        # checking stop condition
        if stop_gpu.item():
            print(f"All sequences finished at step {i+1}")
            break
    
    return all_tokens

The updated results are captured below:

Token Generation With Static KV Cache (By Creator)

Using a fixed-size cache greatly improves memory utilization, as indicated by the decrease in memory fragmentation. However, its impact on runtime performance is mixed — for 100 tokens it reduces performance compared with a dynamic cache, whereas for 200 and 400 tokens it boosts performance by 9% and 10%, respectively.

There are more advanced methods of implementing attention that optimize memory utilization without the cost of wasted computation. In a previous post, Optimizing Transformer Models for Variable-Length Input Sequences, we covered some PyTorch techniques for computing attention that reduce computation waste. For production settings, libraries such as vLLM use PagedAttention to maximize memory utilization. These methods are outside the scope of this post.

For more details on caching in HuggingFace, please see the caching strategies overview.

Model Compilation

One of the documented benefits of using a fixed-size cache is that it allows taking advantage of many just-in-time (JIT) compilation optimizations.

In the next code block we apply our benchmark to a PyTorch-compiled version of our decoder model:

batch_size = 32
max_seqlen = 100

model = torch.compile(model)

benchmark(
    lambda: generate_sequence(
        model,
        max_seqlen=max_seqlen,
        batch_size=batch_size,
        use_cache=True,
        use_static_cache=True
    )
)

Model compilation results in an additional boost to runtime performance, as shown in the table below:

Token Generation With torch.compile (By Creator)

Note that we can apply model compilation when using dynamic caching as well. However, torch.compile provides the best results when the computation graph consists of fixed-size tensors (e.g., see here for more details).

The Performance Penalty of Early Stopping

An integral part of common token generators is checking for the end-of-sequence (EOS) token at the end of every step. Without this test, token generators would always run for the maximum sequence length even when all of the sequences in the batch have already ended. This can result in considerable computation waste and unnecessary latency — especially when typical sequence lengths are much shorter than the maximum length. In the case of our toy experiment, we wait for all of the sequences in the batch to finish before discontinuing token generation. Production-grade implementations will commonly perform continuous batching — replacing completed sequences with new prompts from the input queue.
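The continuous-batching idea can be sketched with a toy scheduler (pure Python, all names hypothetical): each slot holds a request's remaining generation length, and a finished slot is refilled from a pending queue on the same step.

```python
from collections import deque

# Toy continuous-batching scheduler: count the batched forward passes
# needed when finished slots are refilled immediately from the queue.
def batched_steps(lengths, batch_size):
    pending = deque(lengths)
    slots = [pending.popleft() for _ in range(min(batch_size, len(pending)))]
    steps = 0
    while slots:
        steps += 1  # one batched forward pass generates one token per slot
        survivors = [r - 1 for r in slots if r > 1]
        refills = [pending.popleft() for r in slots if r == 1 and pending]
        slots = survivors + refills
    return steps

print(batched_steps([3, 1, 3, 1], batch_size=2))  # 4: 8 tokens / 2 slots, no idle slots
```

In this idealized model no slot ever idles, so throughput is governed purely by total tokens rather than by the longest sequence in any one batch.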

        finished |= (new_tokens == config.eos_token_id)
        stop_gpu = torch.all(finished)
        
        # checking stop condition
        if stop_gpu.item():
            print(f"All sequences finished at step {i+1}")
            break

Importantly, the .item() call on the stop_gpu tensor triggers a blocking host-device synchronization event. More specifically, in order to evaluate the conditional if statement, the CPU must wait for the GPU to finish its computation and copy the contents of the tensor to host memory. While the CPU waits, it is blocked from executing the next step of the token generation loop — or, more accurately, it is blocked from loading the next computation kernels onto the GPU.
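A toy latency model (illustrative numbers, no CUDA required) shows why this matters: when the host synchronizes every step, the kernel-launch overhead is paid serially on every iteration instead of hiding behind GPU compute.

```python
# Toy latency model: 'launch' is host-side time to queue a step's
# kernels, 'compute' is device-side time to execute them.
def loop_time(n_steps, launch, compute, sync_every_step):
    if sync_every_step:
        # the host waits for the device each step before queuing the next
        return n_steps * (launch + compute)
    # the host queues ahead; all but the first launch hide behind compute
    return launch + n_steps * compute

print(loop_time(100, 1, 10, True))   # 1100
print(loop_time(100, 1, 10, False))  # 1001
```

The gap between the two regimes grows as the launch-to-compute ratio grows, which foreshadows the batch-size sensitivity we measure later.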

To measure the impact of the stopping condition on runtime performance, we add instrumentation for performance profiling with NVIDIA Nsight™ Systems (nsys) using the torch.cuda.profiler and nvtx (v0.2.14) APIs. (See our recent post for more details on performance profiling with nsys.)

import nvtx
from torch.cuda import profiler

@torch.inference_mode()
def generate_sequence(
    model, max_seqlen, batch_size, use_cache=False, use_static_cache=False
):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)
    
    # Initialize static cache if requested
    if use_cache and use_static_cache:
        past_key_values = StaticCache(
            config=config,
            max_batch_size=batch_size,
            max_cache_len=max_seqlen,
            device=DEVICE,
            dtype=model.dtype
        )
    else:
        past_key_values = None
    
    # Initialize cache position tracking for static cache
    cache_positions = torch.arange(max_seqlen, device=DEVICE)
    
    for i in range(max_seqlen):
        if i == 30:
            # start nsys profiler
            torch.cuda.synchronize()
            profiler.start()
        elif i == 50:
            # stop nsys profiler
            torch.cuda.synchronize()
            profiler.stop()
        with nvtx.annotate(f"Step {i+1}", color="blue"):
            with nvtx.annotate("Model Forward", color="green"):
                current_input = (
                    all_tokens if past_key_values is None
                    else all_tokens[:, -1:]
                )
                cache_position = (
                    cache_positions[i:i+1] if use_static_cache else None
                )
                outputs = model(
                    current_input,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    use_cache=use_cache
                )
                past_key_values = outputs.past_key_values
                logits = outputs.logits[:, -1, :]
                new_tokens = torch.argmax(logits, dim=-1)
                # append new token to sequence
                all_tokens = torch.cat(
                    [all_tokens, new_tokens.unsqueeze(-1)],
                    dim=-1
                )
                finished |= (new_tokens == config.eos_token_id)
                stop_gpu = torch.all(finished)
            with nvtx.annotate("Check Stop Condition", color="red"):
                # checking stop condition
                if stop_gpu.item():
                    print(f"All sequences finished at step {i+1}")
                    break
    
    return all_tokens

We run our script using the --capture-range=cudaProfilerApi option to start and stop the profiler programmatically. Please see the official documentation for full details on profiling from the nsys CLI.

nsys profile \
  --capture-range=cudaProfilerApi \
  --trace=cuda,nvtx,osrt \
  --output=baseline \
  python train.py

The following trace, captured for a batch size of 16 and sequence length of 100, shows the GPU idling for about 110 microseconds between steps — an eternity in the context of high-performance GPU workloads. This is a direct result of the synchronization event triggered by the EOS test.

GPU Utilization Drops Between Each Step (By Creator)

In production-grade implementations such synchronization issues are avoided by some combination of: 1) using lower-level (e.g., C/C++) code that avoids the limitations of the Python interpreter, 2) using CUDA Graphs to reduce the overhead of kernel loading, 3) moving conditional checks onto the GPU using conditional nodes, and 4) continuously and asynchronously preparing subsequent requests while the EOS check is in progress.

In the next section, we demonstrate a technique for hiding the overhead of the host-device synchronization in PyTorch using CUDA streams.

A CUDA Stream Optimization

A CUDA stream is a linear sequence of operations (kernels, memory copies, etc.) that execute in order on the GPU. While operations within a single stream are guaranteed to execute sequentially, operations in different streams can execute concurrently or overlap.

In previous posts (e.g., here and here) we demonstrated the use of CUDA streams for pipelining common AI/ML workloads, e.g., executing a model on batch i while preparing batch i+1. In this post we will use CUDA streams to enable the CPU to load the GPU kernels of step i+1 before checking the stopping criteria of step i. Contrary to our previous demonstrations of CUDA streams, our current example will not necessarily involve concurrent GPU kernel execution.

We implement an alternative token generation function that interleaves two CUDA streams, running the following operations iteratively:

1. Program the current stream, streams[i % 2], to: (A) wait for the previous stream, streams[(i+1) % 2], to finish its generation of token i-1, (B) use the updated tensors to calculate token i, (C) run the EOS test for token i on the GPU, and (D) perform a (non-blocking) copy of the EOS test result to pinned memory on the CPU.

2. On the default CUDA stream, wait for the previous stream to finish its generation of token i-1.

3. On the default CUDA stream, check whether the stopping criteria for token i-1 were met. If so, halt the generator and return. Otherwise, increment i and return to step 1.

Whereas previously the initialization of token i+1 was blocked by the EOS test on token i, the use of CUDA streams allows us to program the generation of token i+1 before we check the result of the EOS test on token i. In practice, the EOS test for token i runs on the CPU while the GPU is computing token i+1.
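Before looking at the implementation, the intended interleaving can be sanity-checked with a small host-side simulation (pure Python, purely illustrative): we log the order in which the host programs each step and checks each step's EOS result.

```python
# Host-side simulation of the ping-pong schedule: record the order in
# which the host programs steps and checks their EOS results.
def ping_pong_schedule(n_steps):
    events = []
    for i in range(n_steps):
        events.append(("program", i))        # queue step i on stream i % 2
        if i > 0:
            events.append(("check", i - 1))  # EOS test of the previous step
    events.append(("check", n_steps - 1))    # drain the final EOS result
    return events

timeline = ping_pong_schedule(3)
# ("program", 1) precedes ("check", 0): step 1 is queued before step 0's
# EOS result is examined, so the host never stalls waiting on the device.
```

The one-step lag means a single extra step may be programmed after all sequences have finished; this wasted step is the price of hiding the synchronization.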

@torch.inference_mode()
def generate_sequence_pipelined(
    model,
    max_seqlen,
    batch_size,
    use_cache=False,
    use_static_cache=False
):
    # Initialize prompts with BOS token
    all_tokens = torch.full(
        (batch_size, 1),
        config.bos_token_id,
        device=DEVICE,
        dtype=torch.long
    )
    finished = torch.zeros(batch_size, device=DEVICE, dtype=torch.bool)
    past_key_values = None
    
    # Initialize static cache if requested
    if use_cache and use_static_cache:
        past_key_values = StaticCache(
            config=config,
            max_batch_size=batch_size,
            max_cache_len=max_seqlen,
            device=DEVICE,
            dtype=model.dtype
        )
    
    # Initialize cache position tracking for static cache
    cache_positions = torch.arange(max_seqlen, device=DEVICE)
    
    # Dual streams for pipelining
    streams = [torch.cuda.Stream(), torch.cuda.Stream()]
    stop_host = [
        torch.tensor(False, pin_memory=True),
        torch.tensor(False, pin_memory=True)
    ]
    
    for i in range(max_seqlen):
        curr_idx, prev_idx = i % 2, (i+1) % 2
        curr_s, prev_s = streams[curr_idx], streams[prev_idx]
        
        # Launch iteration i in current stream
        with torch.cuda.stream(curr_s):
            # program stream to wait for the previous stream to finish
            curr_s.wait_stream(prev_s)
            current_input = (
                all_tokens if past_key_values is None
                else all_tokens[:, -1:]
            )
            cache_position = (
                cache_positions[i:i+1] if use_static_cache else None
            )
            outputs = model(
                current_input,
                past_key_values=past_key_values,
                cache_position=cache_position,
                use_cache=use_cache
            )
            past_key_values = outputs.past_key_values
            logits = outputs.logits[:, -1, :]
            new_tokens = torch.argmax(logits, dim=-1)
            all_tokens = torch.cat(
                [all_tokens, new_tokens.unsqueeze(-1)],
                dim=-1
            )
            
            finished |= (new_tokens == config.eos_token_id)
            stop_gpu = torch.all(finished)
            stop_host[curr_idx].copy_(stop_gpu, non_blocking=True)
        
        # Check previous iteration's stop signal
        torch.cuda.current_stream().wait_stream(prev_s)
        if stop_host[prev_idx].item():
            print(f"All sequences finished at step {i}")
            break
    
    return all_tokens

The image below captures the nsys trace for our new token generator:

Constant GPU Activity When Applying CUDA Streams (By Creator)

In the CUDA section of the trace we can see the use of two CUDA streams, with token generation being passed back and forth in a kind of ping-pong pattern: one stream generates all of the odd tokens and the second all of the even tokens. The CPU runs about half a step ahead of the GPU — allowing it to program step i+1 while the GPU is computing step i. The CPU-side EOS stop-check of step i-1 (in red) occurs after step i is fully programmed (and has started running). Most importantly, we now find the GPU utilization to be consistent — the idling we saw before is gone.

The CUDA stream interleaving results in an additional performance boost, as shown in the table below:

Token Generation With CUDA Streams (By Creator)

We would expect the benefit of the ping-pong solution we have implemented to depend on the ratio between the GPU idle time (i.e., the overhead of kernel loading) and the kernel computation time. To test this, we fix the sequence length at 100 and rerun the benchmark for a number of batch sizes:

Impact of Pipelining for Various Batch Size (By Creator)

As expected, the greatest performance gain, 11.6%, occurs when the batch size is smallest and the kernel computation load is at its lowest. As the kernel computation increases, the ratio of kernel-loading time to kernel-compute time decreases, and with it the impact of CUDA stream interleaving.

Note that there is some overhead to the use of CUDA streams. This can be demonstrated by comparing our interleaving solution to a token generator that skips the EOS test altogether:

Overhead of CUDA Stream Interleaving (By Creator)

The Potential Performance Pitfalls of Using CUDA Streams

CUDA streams should be used with extreme caution. When using the default stream, we can rely on PyTorch to perform any necessary synchronization when data is moved around. However, when using CUDA streams, we must ensure appropriate synchronization explicitly. In particular, we must ensure appropriate data transfer between the streams. Otherwise, we may experience CUDA errors (e.g., “device-side assert triggered”) — if we are lucky. If we are less lucky, we may experience data corruption without even knowing it. See the PyTorch CUDA stream documentation for more details on appropriate use.

For AI/ML workloads with large CUDA memory utilization, such as LLMs, another consideration is memory utilization. The PyTorch caching allocator manages memory on a per-stream basis; using multiple streams can lead to increased memory reservation and fragmentation. These can result in increased memory faults that may overshadow the potential gains from the use of streams.

Results

In the table below we summarize the runtime results of applying static caching, compilation, and pipelining on a batch of 32 sequences with a maximum sequence length of 100. The results are sorted in increasing order of performance:

Token Generation Optimization Results (By Creator)

In the case of our toy GPT-2 model, the best results — nearly 5 times the baseline performance — are achieved when employing both PyTorch compilation and the CUDA stream interleaving method discussed in this post. However, as we have seen, the impact of CUDA stream interleaving can vary greatly based on the properties of the workload and runtime environment, particularly the ratio between the kernel loading time and the kernel compute time. Please be sure to run your own benchmarks before adopting this technique.

Summary

In high-performance AI engineering, any hint of GPU under-utilization presents an opportunity for optimization. One of the primary optimization tools on NVIDIA GPUs is CUDA streams. In this post, we demonstrated their use in eliminating the idle GPU time that results from the host-device synchronization associated with early stopping in PyTorch-native autoregressive token generation. By interleaving CUDA streams in a “ping-pong” pattern, we successfully hid the latency imposed by the EOS check, which resulted in a meaningful increase in the workload’s throughput. By combining this technique with the well-known methods of model compilation and static caching, we can maximize the performance of PyTorch-native inference.
