Overcoming the Hidden Performance Traps of Variable-Shaped Tensors: Efficient Data Sampling in PyTorch


This post is part of a series on analyzing and optimizing PyTorch models. Throughout the series, we have advocated for the use of the PyTorch Profiler in AI model development and demonstrated the potential impact of performance optimization on the speed and cost of running AI/ML workloads. One common phenomenon we have seen is how seemingly innocent code can hamper runtime performance. In this post, we explore some of the penalties associated with the naive use of variable-shaped tensors, i.e., tensors whose shape depends on preceding computations and/or inputs. While not applicable to all situations, there are occasions when the use of variable-shaped tensors can be avoided, although this may come at the expense of additional compute and/or memory. We will demonstrate the tradeoffs of these alternatives on a toy implementation of data sampling in PyTorch.

Three Downsides of Variable-Shaped Tensors

We motivate the discussion by presenting three disadvantages of using variable-shaped tensors:

Host-Device Sync Events

In an ideal scenario, the CPU and GPU are able to run in parallel in an asynchronous manner, with the CPU continuously feeding the GPU with input samples, allocating required GPU memory, and loading GPU compute kernels, and the GPU executing the loaded kernels on the provided inputs using the allocated memory. The presence of dynamic-shaped tensors throws a wrench into this parallelism. In order to allocate the appropriate amount of memory, the CPU must wait for the GPU to report the tensor's shape, and then the GPU must wait for the CPU to allocate the memory and proceed with the kernel loading. The overhead of this sync event can cause a drop in GPU utilization and slow runtime performance.

We saw an example of this in part three of this series when we studied a naive implementation of the common cross-entropy loss that included calls to torch.nonzero and torch.unique. Both APIs return tensors whose shapes are dynamic and dependent on the contents of the input. When these functions are run on the GPU, a host-device synchronization event occurs. In the case of the cross-entropy loss, we discovered the inefficiency through the use of PyTorch Profiler and were able to easily overcome it with an alternative implementation that avoided the use of variable-shaped tensors and demonstrated significantly better runtime performance.
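
One way to surface such sync events early, aside from studying the profiler trace, is PyTorch's CUDA sync debug mode. The minimal sketch below (our own addition, not part of the original flow) enables it and calls torch.nonzero on a GPU tensor; on most recent PyTorch versions this should emit a warning pointing at the synchronizing call:

import torch

# emit a warning whenever an operation forces the CPU to wait for the GPU
torch.cuda.set_sync_debug_mode("warn")

labels = torch.randint(0, 2, (1000,), device='cuda')
# the output shape of nonzero depends on the data, so the CPU must wait for
# the GPU to report it; expect a synchronization warning here
idx = torch.nonzero(labels)

# restore the default behavior
torch.cuda.set_sync_debug_mode("default")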

Graph Compilation

In a recent post we explored the performance advantages of applying just-in-time (JIT) compilation using the torch.compile operator. One of our observations was that graph compilation provided significantly better results when the graph was static. The presence of dynamic shapes in the graph limits the extent of the optimization via compilation: In some cases, it fails completely; in others it results in lower performance gains. The same implications also apply to other types of graph compilation, such as XLA, ONNX, OpenVINO, and TensorRT.
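
The minimal sketch below (not from the original post) illustrates one way this can surface. Passing fullgraph=True to torch.compile disallows fallback to eager execution, so a graph break caused by the data-dependent output shape of torch.nonzero is reported as an error on most PyTorch versions (the exact behavior may vary with version and configuration):

import torch

def dynamic_fn(labels):
    # the output shape of nonzero depends on the tensor contents
    return torch.nonzero(labels)

compiled_fn = torch.compile(dynamic_fn, fullgraph=True)

try:
    compiled_fn(torch.randint(0, 2, (1000,), device='cuda'))
except Exception as e:
    print(f"compilation failed: {type(e).__name__}")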

Data Batching

Another optimization we have encountered in several of our posts (e.g., here) is sample batching. Batching improves performance in two primary ways:

  1. Reducing the overhead of kernel loading: Rather than loading the GPU kernels required for the computation pipeline once per input sample, the CPU can load the kernels once per batch.
  2. Maximizing parallelization across compute units: GPUs are highly parallel compute engines. The more we are able to parallelize computation, the more we can saturate the GPU and increase its utilization. By batching we can potentially increase the degree of parallelization by a factor of the batch size (see the sketch following this list).
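
The minimal sketch below (our own addition, not from the original post) illustrates both effects by comparing a Python loop of per-sample matrix multiplications against a single batched multiplication. The sizes are arbitrary and the exact speedup will depend on the GPU:

import time
import torch

x = torch.randn(1024, 256, 256, device='cuda')
w = torch.randn(256, 256, device='cuda')

def run_looped():
    # one matmul kernel launch per sample
    return [xi @ w for xi in x]

def run_batched():
    # a single batched matmul kernel
    return x @ w

for fn in (run_looped, run_batched):
    fn()  # warm-up
    torch.cuda.synchronize()
    start = time.perf_counter()
    fn()
    torch.cuda.synchronize()
    print(f"{fn.__name__}: {(time.perf_counter() - start) * 1e3:.2f} ms")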

Despite these downsides, the use of variable-shaped tensors is sometimes unavoidable. But sometimes we can modify our model implementation to avoid them. Sometimes these changes will be straightforward (as in the cross-entropy loss example). Other times they may require some creativity in coming up with a different sequence of fixed-shape PyTorch APIs that produce the same numerical result. Often, this effort can deliver meaningful rewards in runtime and cost.

In the next sections, we will study the use of variable-shaped tensors in the context of the data sampling operation. We will start with a trivial implementation and analyze its performance. We will then propose a GPU-friendly alternative that avoids the use of variable-shaped tensors.

To compare our implementations, we will use an Amazon EC2 g6e.xlarge instance with an NVIDIA L40S GPU running an AWS Deep Learning AMI (DLAMI) with PyTorch (2.8). The code we share is intended for demonstration purposes. Please do not rely on it for accuracy or optimality. Please do not interpret our mention of any framework, library, or platform as an endorsement of its use.

Sampling in AI Model Workloads

In the context of this post, sampling refers to the selection of a subset of items from a large set of candidates for the purposes of computational efficiency, balancing of data types, or regularization. Sampling is common in many AI/ML models, such as detection, ranking, and contrastive learning systems.

We define a simple variation of the sampling problem: Given a list of input tensors, each with a binary label, we are asked to return a subset of the tensors containing both positive and negative examples, in random order. If the input list contains enough samples of each label (at least half of the requested subset size from each class), the returned subset should be evenly split. If it is lacking samples of one type, these should be filled with random samples of the other type.

The code block below contains a PyTorch implementation of our sampling function. The implementation is inspired by the popular Detectron2 library (e.g., see here and here). For the experiments in this post, we fix the size of the sampled subset to one tenth of the number of input samples:

import torch

INPUT_SAMPLES = 10000
SUB_SAMPLE = INPUT_SAMPLES // 10
FEATURE_DIM = 16

def sample_data(input_array, labels):
    device = labels.device
    # indices of positive/negative samples (shapes depend on the label contents)
    positive = torch.nonzero(labels == 1, as_tuple=True)[0]
    negative = torch.nonzero(labels == 0, as_tuple=True)[0]
    # decide how many samples to draw from each class
    num_pos = min(positive.numel(), SUB_SAMPLE//2)
    num_neg = min(negative.numel(), SUB_SAMPLE//2)
    if num_neg < SUB_SAMPLE//2:
        num_pos = SUB_SAMPLE - num_neg
    elif num_pos < SUB_SAMPLE//2:
        num_neg = SUB_SAMPLE - num_pos

    # randomly select positive and negative examples
    perm1 = torch.randperm(positive.numel(), device=device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=device)[:num_neg]

    pos_idxs = positive[perm1]
    neg_idxs = negative[perm2]

    sampled_idxs = torch.cat([pos_idxs, neg_idxs], dim=0)
    rand_perm = torch.randperm(SUB_SAMPLE, device=labels.device)
    sampled_idxs = sampled_idxs[rand_perm]
    return input_array[sampled_idxs], labels[sampled_idxs]

Performance Evaluation With PyTorch Profiler

Even when not immediately obvious, the use of dynamic shapes is easily identifiable in the PyTorch Profiler trace view. We use the following function to enable PyTorch Profiler:

def profile(fn, input, labels):
    
    def export_trace(p):
        p.export_chrome_trace(f"{fn.__name__}.json")
        
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            with_stack=True,
            schedule=torch.profiler.schedule(wait=0, warmup=10, active=5),
            on_trace_ready=export_trace
    ) as prof:
        for _ in range(20):
            fn(input, labels)
            torch.cuda.synchronize()  # explicit sync for trace readability
            prof.step()

# create random input
input_samples = torch.randn((INPUT_SAMPLES, FEATURE_DIM), device='cuda')
labels = torch.randint(0, 2, (INPUT_SAMPLES,), 
                       device='cuda', dtype=torch.int64)

# run with profiler
profile(sample_data, input_samples, labels)

The image below was captured for an input size of ten million samples. It clearly shows the presence of sync events coming from the torch.nonzero calls, as well as the corresponding drops in GPU utilization:

Profiler Trace of Sampler (by Author)

The use of torch.nonzero in our implementation is clearly not ideal, but can it be avoided?

A GPU-Friendly Data Sampler

We propose an alternative implementation of our sampling function that replaces the dynamic torch.nonzero function with a creative combination of the static torch.count_nonzero, torch.topk, and other APIs:

def opt_sample_data(input, labels):
    pos_mask = labels == 1
    neg_mask = labels == 0
    # count positives/negatives without materializing variable-shaped index lists
    num_pos_idxs = torch.count_nonzero(pos_mask, dim=-1)
    num_neg_idxs = torch.count_nonzero(neg_mask, dim=-1)
    half_samples = labels.new_full((), SUB_SAMPLE // 2)
    num_pos = torch.minimum(num_pos_idxs, half_samples)
    num_neg = torch.minimum(num_neg_idxs, half_samples)
    # mirror the branching logic of sample_data using tensor ops (no host sync)
    num_pos = torch.where(
        num_neg < SUB_SAMPLE // 2,
        SUB_SAMPLE - num_neg,
        num_pos
    )
    num_neg = SUB_SAMPLE - num_pos

    # create random ordering on pos and neg entries
    rand = torch.rand_like(labels, dtype=torch.float32)
    pos_rand = torch.where(pos_mask, rand, -1)
    neg_rand = torch.where(neg_mask, rand, -1)

    # select top pos entries and invalidate others
    # since CPU doesn't know num_pos, we assume maximum to avoid sync
    top_pos_rand, top_pos_idx = torch.topk(pos_rand, k=SUB_SAMPLE)
    arange = torch.arange(SUB_SAMPLE, device=labels.device)
    if num_pos.numel() > 1:
        # unsqueeze to support batched input
        arange = arange.unsqueeze(0)
        num_pos = num_pos.unsqueeze(-1)
        num_neg = num_neg.unsqueeze(-1)
    top_pos_rand = torch.where(arange >= num_pos, -1, top_pos_rand)

    # repeat for neg entries
    top_neg_rand, top_neg_idx = torch.topk(neg_rand, k=SUB_SAMPLE)
    top_neg_rand = torch.where(arange >= num_neg, -1, top_neg_rand)

    # combine positive and negative idxs and shuffle them together
    cat_rand = torch.cat([top_pos_rand, top_neg_rand], dim=-1)
    cat_idx = torch.cat([top_pos_idx, top_neg_idx], dim=-1)
    topk_rand_idx = torch.topk(cat_rand, k=SUB_SAMPLE)[1]
    sampled_idxs = torch.gather(cat_idx, dim=-1, index=topk_rand_idx)
    # expand the index across the feature dim so gather returns full feature rows
    sampled_input = torch.gather(
        input, dim=-2,
        index=sampled_idxs.unsqueeze(-1).expand(*sampled_idxs.shape,
                                                input.shape[-1]))
    sampled_labels = torch.gather(labels, dim=-1, index=sampled_idxs)
    return sampled_input, sampled_labels
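
Before comparing performance, a quick sanity check (our own addition, not part of the original flow) can confirm that both functions behave as expected. With roughly balanced random labels, each sampler should return SUB_SAMPLE items split evenly between the two classes:

x = torch.randn((INPUT_SAMPLES, FEATURE_DIM), device='cuda')
y = torch.randint(0, 2, (INPUT_SAMPLES,), device='cuda', dtype=torch.int64)

for fn in (sample_data, opt_sample_data):
    samples, labs = fn(x, y)
    # expect a (SUB_SAMPLE, FEATURE_DIM) output and a label sum of SUB_SAMPLE // 2
    print(fn.__name__, tuple(samples.shape), labs.sum().item())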

Clearly, this function requires more memory and more operations than our first implementation. The question is: Do the performance advantages of a static, synchronization-free implementation outweigh the additional cost in memory and compute?

To evaluate the tradeoffs between the two implementations, we introduce the following benchmarking utility:

def benchmark(fn, input, labels):
    # warm-up
    for _ in range(20):
        _ = fn(input, labels)

    iters = 100
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        _ = fn(input, labels)
    end.record()
    torch.cuda.synchronize()
    avg_time = start.elapsed_time(end) / iters
    
    print(f"{fn.__name__} average step time: {(avg_time):.4f} ms")

benchmark(sample_data, input_samples, labels)
benchmark(opt_sample_data, input_samples, labels)

The following table compares the average runtime of each of the implementations for a variety of input sample sizes:

Comparative Step Time Performance — Lower is Better (by Author)

For most of the input sample sizes, the overhead of the host-device sync event is either comparable to or lower than the extra compute of the static implementation. Disappointingly, we only see a significant benefit from the sync-free alternative when the input sample size reaches ten million. Sample sizes that large are unusual in AI/ML settings. But it is not our tendency to give up so easily. As noted above, the static implementation enables other optimizations such as graph compilation and input batching.

Graph Compilation

Contrary to the original function, which fails to compile, our static implementation is fully compatible with torch.compile:

benchmark(torch.compile(opt_sample_data), input_samples, labels)

The following table includes the runtimes of our compiled function:

Comparative Step Time Performance — Lower is Better (by Author)

The results are significantly better, providing a 70-75 percent boost over the original sampler implementation in the 1-10 thousand sample range. But we still have one more optimization up our sleeve.
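
As an optional side check (our own addition, not part of the original flow), since torch.compile silently falls back to eager execution on graph breaks, passing fullgraph=True confirms that the static implementation really does compile into a single graph; it raises an error if any graph break is encountered:

# optional: fullgraph=True raises if the function cannot be captured as one graph
compiled_sampler = torch.compile(opt_sample_data, fullgraph=True)
benchmark(compiled_sampler, input_samples, labels)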

Maximizing Performance with Batched Input

Since the original implementation contains variable-shaped operations, it cannot handle batched input directly. To process a batch, we have no choice but to apply it to each input individually, in a Python loop:

BATCH_SIZE = 32

def batched_sample_data(inputs, labels):
    sampled_inputs = []
    sampled_labels = []
    for i in range(inputs.size(0)):
        inp, lab = sample_data(inputs[i], labels[i])
        sampled_inputs.append(inp)
        sampled_labels.append(lab)
    return torch.stack(sampled_inputs), torch.stack(sampled_labels)

In contrast, our optimized function supports batched inputs as is, with no changes necessary:

input_batch = torch.randn((BATCH_SIZE, INPUT_SAMPLES, FEATURE_DIM),
                          device='cuda')
labels = torch.randint(0, 2, (BATCH_SIZE, INPUT_SAMPLES),
                       device='cuda', dtype=torch.int64)

benchmark(batched_sample_data, input_batch, labels)
benchmark(opt_sample_data, input_batch, labels)
benchmark(torch.compile(opt_sample_data), input_batch, labels)

The table below compares the step times of our sampling functions on a batch size of 32:

Step Time Performance on Batched Input — Lower is Better (by Author)

Now the results are definitive: By using a static implementation of the data sampler, we are able to boost performance by 2X to 52X (!!) over the variable-shaped option, depending on the input sample size.

Note that although our experiments were run on a GPU device, the model compilation and input batching optimizations also apply to a CPU environment. Thus, avoiding variable shapes can have implications for AI/ML model performance on CPU as well.
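
The sketch below (our own addition, assuming the tensors and functions defined above are in scope) shows how a CPU variant of the comparison might look. The CUDA-event timer in our benchmark utility does not apply on CPU, so a simple wall-clock timer is used instead:

import time

cpu_inputs = torch.randn((BATCH_SIZE, INPUT_SAMPLES, FEATURE_DIM))
cpu_labels = torch.randint(0, 2, (BATCH_SIZE, INPUT_SAMPLES), dtype=torch.int64)

def cpu_benchmark(fn, inputs, labels, iters=10):
    # warm-up (also triggers compilation for torch.compile-wrapped functions)
    for _ in range(3):
        fn(inputs, labels)
    start = time.perf_counter()
    for _ in range(iters):
        fn(inputs, labels)
    avg_time = (time.perf_counter() - start) / iters * 1e3
    print(f"{getattr(fn, '__name__', 'compiled')} average step time: {avg_time:.4f} ms")

cpu_benchmark(batched_sample_data, cpu_inputs, cpu_labels)
cpu_benchmark(opt_sample_data, cpu_inputs, cpu_labels)
cpu_benchmark(torch.compile(opt_sample_data), cpu_inputs, cpu_labels)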

Summary

The optimization process we demonstrated in this post generalizes beyond the particular case of data sampling:

  • Discovery via Performance Profiling: Using the PyTorch Profiler, we were able to identify drops in GPU utilization and discover their source: the presence of variable-shaped tensors resulting from the torch.nonzero operation.
  • An Alternative Implementation: Our profiling findings allowed us to develop an alternative implementation that achieved the same goal while avoiding the use of variable-shaped tensors. However, this step came at the cost of additional compute and memory overhead. As seen in our initial benchmarks, the sync-free alternative demonstrated worse performance on common input sizes.
  • Unlocking Further Potential for Optimization: The true breakthrough came because the static-shaped implementation was compilation-friendly and supported batching. These optimizations provided performance gains that dwarfed the initial overhead, resulting in a 2X to 52X speedup over the original implementation.

Naturally, not all stories will end as happily as ours. In many cases, we may come across PyTorch code that performs poorly on the GPU but does not have an alternative implementation, or has one that requires significantly more compute resources. Nonetheless, given the potential for meaningful gains in performance and reductions in cost, the process of identifying runtime inefficiencies and exploring alternative implementations is an important part of AI/ML development.
