Optimizing PyTorch Model Inference on AWS Graviton


Running AI/ML models can be an extremely expensive endeavor. Many of our posts have focused on a wide range of tips, tricks, and techniques for analyzing and optimizing the runtime performance of AI/ML workloads. Our argument has been twofold:

  1. Performance analysis and optimization must be an integral process of every AI/ML development project, and,
  2. Reaching meaningful performance boosts and cost reduction does not require a high degree of specialization. Any AI/ML developer can do it. Every AI/ML developer should do it.

In our previous post, we addressed the challenge of optimizing an ML inference workload on an Intel® Xeon® processor. We began by reviewing several scenarios in which a CPU may be the best choice for AI/ML inference, even in an era of multiple dedicated AI inference chips. We then introduced a toy image-classification PyTorch model and proceeded to demonstrate a wide variety of techniques for improving its runtime performance on an Amazon EC2 c7i.xlarge instance, powered by 4th Generation Intel Xeon Scalable processors. In this post, we extend our discussion to AWS’s homegrown Arm-based Graviton CPUs. We will revisit most of the optimizations we discussed in our previous posts (some of which require adaptation to the Arm processor) and assess their impact on the same toy model. Given the profound differences between the Arm and Intel processors, the path to the best-performing configuration may differ between them.

AWS Graviton

AWS Graviton is a family of processors based on Arm Neoverse CPUs, which are customized and built by AWS for optimal price-performance and energy efficiency. Their dedicated engines for vector processing (NEON and SVE/SVE2) and matrix multiplication (MMLA), and their support for Bfloat16 operations (as of Graviton3), make them a compelling candidate for running compute-intensive workloads such as AI/ML inference. To facilitate high-performance AI/ML on Graviton, the entire software stack has been optimized for its use:

  • Low-Level Compute Kernels from the Arm Compute Library (ACL) are highly optimized to leverage the Graviton hardware accelerators (e.g., SVE and MMLA).
  • ML Middleware Libraries such as oneDNN and OpenBLAS route deep learning and linear algebra operations to the specialized ACL kernels.
  • AI/ML Frameworks like PyTorch and TensorFlow are compiled and configured to make use of these optimized backends.

In this post, we will use an Amazon EC2 c8g.xlarge instance (4 vCPUs) powered by AWS Graviton4 processors, together with an AWS ARM64 PyTorch Deep Learning AMI (DLAMI).

The purpose of this post is to demonstrate tips for improving performance on an AWS Graviton instance. Importantly, our intention is not to draw a comparison between AWS Graviton and alternative chips, nor is it to advocate for the use of one chip over another. The best choice of processor depends on a whole host of considerations beyond the scope of this post. One of the important considerations will be the maximum runtime performance of your model on each chip. In other words: how much “bang” can we get for our buck? Thus, making an informed decision about the best processor is one of the motivations for optimizing runtime performance on each one.

Another motivation for optimizing our model’s performance for multiple inference devices is to increase its portability. The playing field of AI/ML is extremely dynamic, and resilience to changing circumstances is crucial for success. It is not uncommon for compute instances of certain types to suddenly become unavailable or scarce. Conversely, an increase in the capacity of AWS Graviton instances could imply their availability at steep discounts, e.g., in the Amazon EC2 Spot Instance market, presenting cost-savings opportunities that you would not want to miss out on.

Disclaimers

The code blocks we will share, the optimization steps we will discuss, and the results we will reach are intended as examples of the benefits you may see from ML performance optimization on an AWS Graviton instance. These may differ considerably from the results you might see with your own model and runtime environment. Please do not rely on the accuracy or optimality of the contents of this post. Please do not interpret the mention of any library, framework, or platform as an endorsement of its use.

Inference Optimization on AWS Graviton

As in our previous post, we will demonstrate the optimization steps on a toy image classification model:

import torch, torchvision
import time


def get_model(channels_last=False, compile=False):
    model = torchvision.models.resnet50()

    if channels_last:
        model = model.to(memory_format=torch.channels_last)

    model = model.eval()

    if compile:
        model = torch.compile(model)

    return model

def get_input(batch_size, channels_last=False):
    batch = torch.randn(batch_size, 3, 224, 224)
    if channels_last:
        batch = batch.to(memory_format=torch.channels_last)
    return batch

def get_inference_fn(model, enable_amp=False):
    def infer_fn(batch):
        with torch.inference_mode(), torch.amp.autocast(
                'cpu',
                dtype=torch.bfloat16,
                enabled=enable_amp
        ):
            output = model(batch)
        return output
    return infer_fn

def benchmark(infer_fn, batch):
    # warm-up
    for _ in range(20):
        _ = infer_fn(batch)

    iters = 100

    start = time.time()
    for _ in range(iters):
        _ = infer_fn(batch)
    end = time.time()

    return (end - start) / iters


batch_size = 1
model = get_model()
batch = get_input(batch_size)
infer_fn = get_inference_fn(model)
avg_time = benchmark(infer_fn, batch)
print(f"\nAverage samples per second: {(batch_size/avg_time):.2f}")

The initial throughput is 12 samples per second (SPS).

Upgrade to the Most Recent PyTorch Release

While the PyTorch version in our DLAMI is 2.8, the latest version of PyTorch at the time of this writing is 2.9. Given the rapid pace of development in the field of AI/ML, it is highly recommended to use the most up-to-date library packages. As our first step, we upgrade to PyTorch 2.9, which includes key updates to its Arm backend.

pip3 install -U torch torchvision --index-url https://download.pytorch.org/whl/cpu

In the case of our model in its initial configuration, upgrading the PyTorch version has no effect. However, this step is crucial for getting the most out of the optimization techniques that we will assess.
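After upgrading, it can be worth confirming which PyTorch build the interpreter actually picks up, since a stale environment silently negates the upgrade. The following is a minimal sketch that reads the installed version through the standard-library `importlib.metadata` and compares numeric components only. The helper names `parse_version` and `is_at_least` are our own illustrative choices. Local and pre-release suffixes such as `+cpu` or `.dev…` are ignored, so a dev build is treated like its release.

```python
import importlib.metadata
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric part of a version, e.g. '2.9.0+cpu' -> (2, 9, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def is_at_least(version: str, minimum: str) -> bool:
    """Return True if `version` is at least `minimum` (numeric components only)."""
    return parse_version(version) >= parse_version(minimum)


if __name__ == "__main__":
    try:
        installed = importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        print("PyTorch is not installed in this environment")
    else:
        status = "OK" if is_at_least(installed, "2.9") else "consider upgrading"
        print(f"torch {installed}: {status}")
```

Tuple comparison handles versions like `2.10` correctly, whereas a naive string comparison would rank `"2.10"` below `"2.9"`. If the third-party `packaging` library is available, `packaging.version.Version` is a more complete alternative.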

Batched Inference

To reduce kernel launch overhead and increase the utilization of the HW accelerators, we group samples together and apply batched inference. The table below demonstrates how the model throughput varies as a function of batch size:

Inference Throughput for Various Batch Sizes (by Author)
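A table like the one above can be produced by repeating the benchmark for several batch sizes and normalizing by the batch size. The following is a minimal sketch of such a sweep. The names `measure_throughput` and `sweep_batch_sizes` are hypothetical helpers, not part of the original script. It uses `time.perf_counter` rather than `time.time`, because it is monotonic and has higher resolution. A trivial stand-in workload is used so that the sketch runs without PyTorch.

```python
import time
from typing import Callable, Dict, Sequence


def measure_throughput(infer_fn: Callable, batch, batch_size: int,
                       warmup: int = 20, iters: int = 100) -> float:
    """Return the average samples per second of `infer_fn` on a fixed batch."""
    for _ in range(warmup):  # warm-up calls are excluded from the timing
        infer_fn(batch)
    start = time.perf_counter()
    for _ in range(iters):
        infer_fn(batch)
    # guard against a zero reading for extremely fast stand-in functions
    elapsed = max(time.perf_counter() - start, 1e-9)
    return batch_size * iters / elapsed


def sweep_batch_sizes(make_batch: Callable[[int], object], infer_fn: Callable,
                      batch_sizes: Sequence[int]) -> Dict[int, float]:
    """Map each batch size to its measured throughput (samples per second)."""
    return {bs: measure_throughput(infer_fn, make_batch(bs), bs)
            for bs in batch_sizes}


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without PyTorch installed.
    results = sweep_batch_sizes(lambda bs: list(range(bs)), sum, [1, 2, 4, 8])
    for bs, sps in results.items():
        print(f"batch size {bs:>2}: {sps:,.0f} samples/sec")
```

With the functions defined earlier in this post, the equivalent call would presumably be `sweep_batch_sizes(get_input, get_inference_fn(get_model()), [1, 2, 4, 8, 16])`, since `get_input` takes the batch size as its first argument.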

Memory Optimizations

We apply a number of techniques from our previous post for optimizing memory allocation and usage. These include the channels-last memory format, automatic mixed precision with the bfloat16 data type (supported from Graviton3), the TCMalloc allocation library, and huge page allocation. Please see our previous post for details. We also enable the fast math mode of the ACL GEMM kernels and caching of the kernel primitives, two optimizations that appear in the official guidelines for running PyTorch inference on Graviton.
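The channels-last format changes only the physical order in which a 4D tensor is stored. The logical shape stays (N, C, H, W), but the C values of each pixel become adjacent in memory, which many CPU convolution kernels prefer. The following is a minimal pure-Python sketch that computes the element strides of both layouts. The helper names are illustrative. If PyTorch is installed, `torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last).stride()` should report the same channels-last strides.

```python
from typing import Tuple


def contiguous_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def nchw_strides(n: int, c: int, h: int, w: int) -> Tuple[int, ...]:
    """Strides of an NCHW tensor in the default contiguous layout."""
    return contiguous_strides((n, c, h, w))


def channels_last_strides(n: int, c: int, h: int, w: int) -> Tuple[int, ...]:
    """Strides, indexed as (N, C, H, W), of a tensor stored physically as NHWC."""
    sn, sh, sw, sc = contiguous_strides((n, h, w, c))  # physical NHWC order
    return (sn, sc, sh, sw)  # reorder to the logical NCHW dimension order


if __name__ == "__main__":
    print("NCHW strides:         ", nchw_strides(1, 3, 224, 224))
    print("channels-last strides:", channels_last_strides(1, 3, 224, 224))
```

A channel stride of 1 means that stepping to the next channel of the same pixel moves a single element in memory. In the default layout, the same step jumps a full H×W plane.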

The command line instructions required to enable these optimizations are shown below:

# install TCMalloc
sudo apt-get install google-perftools

# Preload TCMalloc so it replaces the default memory allocator
export LD_PRELOAD=/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4

# Enable huge page memory allocation
export THP_MEM_ALLOC_ENABLE=1

# Enable the fast math mode of the GEMM kernels
export DNNL_DEFAULT_FPMATH_MODE=BF16

# Set the LRU cache capacity for caching the kernel primitives
export LRU_CACHE_CAPACITY=1024
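`LD_PRELOAD` is read by the dynamic loader only when a process starts. Assigning it through `os.environ` inside an already-running Python interpreter is therefore too late to swap in TCMalloc. If you prefer to launch the benchmark from Python, one option is to pass the settings to a child process. The following is a minimal sketch of that approach. `graviton_inference_env` and `run_with_env` are hypothetical helpers. The TCMalloc path is the Ubuntu/aarch64 path used above and may differ on other distributions. The remaining variables are simply copied from the shell snippet.

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional

# Path used above for Ubuntu on aarch64; it may differ on other distributions.
TCMALLOC_PATH = "/usr/lib/aarch64-linux-gnu/libtcmalloc.so.4"


def graviton_inference_env(base: Optional[Mapping[str, str]] = None,
                           tcmalloc_path: str = TCMALLOC_PATH) -> Dict[str, str]:
    """Return a copy of `base` (default: os.environ) with the settings above applied."""
    env = dict(os.environ if base is None else base)
    # Prepend TCMalloc to any existing LD_PRELOAD entries instead of overwriting them.
    existing = env.get("LD_PRELOAD", "")
    env["LD_PRELOAD"] = f"{tcmalloc_path}:{existing}" if existing else tcmalloc_path
    env["THP_MEM_ALLOC_ENABLE"] = "1"
    env["DNNL_DEFAULT_FPMATH_MODE"] = "BF16"
    env["LRU_CACHE_CAPACITY"] = "1024"
    return env


def run_with_env(script: str) -> int:
    """Run `script` in a child interpreter that starts with these settings."""
    return subprocess.call([sys.executable, script], env=graviton_inference_env())
```

The child process sees the variables from its very first instruction, so TCMalloc is loaded before PyTorch allocates anything.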

The following table captures the impact of the memory optimizations, applied successively:

ResNet-50 Memory Optimization Results (by Author)

In the case of our toy model, the  and  optimizations had the greatest impact. After applying all of the memory optimizations, the average throughput is 53.03 SPS.
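Two of the settings in this section, mixed precision and the GEMM fast math mode, trade numerical precision for speed by computing in bfloat16. Bfloat16 keeps float32's 8-bit exponent, and thus its dynamic range, but only 7 explicit mantissa bits. It is essentially the upper half of a float32. The following is a minimal sketch of float32 to bfloat16 rounding (round-to-nearest, ties-to-even) using only the standard library. It assumes the input fits in the float32 range. The exact rounding behavior of any given hardware instruction or library kernel may differ.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round a float32-representable value to the nearest bfloat16 (ties to even)."""
    if math.isnan(x):
        return x  # NaN stays NaN; skip the bit manipulation
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the upper 16 bits. Adding 0x7FFF plus the lowest kept bit
    # before truncating implements round-to-nearest, ties-to-even.
    lsb = (bits >> 16) & 1
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", rounded))[0]


if __name__ == "__main__":
    for value in (1.0, math.pi, 0.1, 1000.7):
        print(f"{value!r:>20} -> {to_bfloat16(value)!r}")
```

For example, pi becomes 3.140625. With only about two to three significant decimal digits, some models tolerate bfloat16 well while others lose accuracy, which is worth validating alongside the throughput numbers.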

Model Compilation

PyTorch compilation support for AWS Graviton is an area of focused effort by the AWS Graviton team. However, in the case of our toy model, it results in a slight reduction in throughput, from 53.03 SPS to 52.23 SPS.

Multi-Worker Inference

While multi-worker inference is typically applied in settings with many more than 4 vCPUs, we demonstrate its implementation by modifying our script to support core pinning:

if __name__ == '__main__':
    # pin CPUs according to worker rank
    import os, psutil
    rank = int(os.environ.get('RANK','0'))
    world_size = int(os.environ.get('WORLD_SIZE','1'))
    cores = list(range(psutil.cpu_count(logical=True)))
    num_cores = len(cores)
    cores_per_process = num_cores // world_size
    start_index = rank * cores_per_process
    end_index = (rank + 1) * cores_per_process
    pid = os.getpid()
    p = psutil.Process(pid)
    p.cpu_affinity(cores[start_index:end_index])

    batch_size = 8
    model = get_model(channels_last=True)
    batch = get_input(batch_size, channels_last=True)
    infer_fn = get_inference_fn(model, enable_amp=True)
    avg_time = benchmark(infer_fn, batch)
    print(f"\nAverage samples per second: {(batch_size/avg_time):.2f}")
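The pinning logic above splits the cores with integer division, so any cores left over when the core count is not divisible by the number of workers stay idle. It also assumes that every logical CPU is available to the process. The following is a minimal Linux-only sketch of a variant that uses the standard library's `os.sched_getaffinity` and `os.sched_setaffinity` instead of `psutil`. It hands leftover cores to the lowest ranks. `cores_for_rank` and `pin_current_process` are hypothetical helpers. Like the script above, they read the `RANK` and `WORLD_SIZE` variables that torchrun sets.

```python
import os
from typing import List, Sequence


def cores_for_rank(rank: int, world_size: int, cores: Sequence[int]) -> List[int]:
    """Return the contiguous slice of `cores` assigned to worker `rank`.

    When len(cores) is not divisible by world_size, the leftover cores are
    handed out one each to the lowest ranks instead of being left idle.
    """
    if world_size < 1 or not 0 <= rank < world_size:
        raise ValueError(f"invalid rank {rank} for world_size {world_size}")
    base, extra = divmod(len(cores), world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return list(cores[start:end])


def pin_current_process() -> List[int]:
    """Pin this process based on the RANK/WORLD_SIZE variables set by torchrun.

    Linux only: os.sched_getaffinity/os.sched_setaffinity are not available
    on macOS or Windows.
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    # Start from the CPUs this process is actually allowed to use.
    available = sorted(os.sched_getaffinity(0))
    assigned = cores_for_rank(rank, world_size, available)
    if not assigned:
        raise RuntimeError("more workers than available cores")
    os.sched_setaffinity(0, assigned)
    return assigned
```

Starting from `os.sched_getaffinity(0)` respects any CPU restrictions already imposed on the process, for example by a container's cpuset. On our 4-vCPU instance with 4 workers, both approaches produce the same one-core-per-worker assignment.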

We note that, contrary to other Amazon EC2 CPU instance types, each Graviton vCPU maps directly to a single physical CPU core. We use the torchrun utility to start up 4 workers, each running on a single CPU core:

export OMP_NUM_THREADS=1 # set one OpenMP thread per worker
torchrun --nproc_per_node=4 main.py

This leads to a throughput of 55.15 SPS, a 4% improvement over our previous best result.
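The relative improvements quoted in this post are measured against the best previous result. The following is a small sketch of that arithmetic, applied to the throughput figures reported so far. `relative_change_pct` is a hypothetical helper name.

```python
def relative_change_pct(baseline_sps: float, new_sps: float) -> float:
    """Percentage change in throughput relative to a baseline (positive = faster)."""
    if baseline_sps <= 0:
        raise ValueError("baseline throughput must be positive")
    return (new_sps / baseline_sps - 1.0) * 100.0


# Throughput figures (samples per second) reported in this post.
best_single_process_sps = 53.03  # after all memory optimizations
results = {
    "torch.compile": 52.23,
    "multi-worker (4 workers)": 55.15,
}

for name, sps in results.items():
    print(f"{name}: {relative_change_pct(best_single_process_sps, sps):+.1f}%")
# torch.compile: -1.5%
# multi-worker (4 workers): +4.0%
```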

INT8 Quantization for Arm

Another area of active development and continuous improvement on Arm is INT8 quantization. INT8 quantization tools are typically heavily tied to the target instance type. In our previous post, we demonstrated PyTorch 2 Export Quantization with X86 Backend through Inductor using the TorchAO (0.12.1) library. Fortunately, recent versions of TorchAO include a dedicated quantizer for Arm. The updated quantization sequence is shown below. As in our previous post, we are interested only in the potential performance impact. In practice, INT8 quantization can have a significant impact on the quality of the model and may necessitate a more sophisticated quantization strategy.

from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as aiq

def quantize_model(model):
    x = torch.randn(4, 3, 224, 224).contiguous(
                            memory_format=torch.channels_last)
    example_inputs = (x,)
    batch_dim = torch.export.Dim("batch")
    with torch.no_grad():
        exported_model = torch.export.export(
            model,
            example_inputs,
            dynamic_shapes=((batch_dim,
                             torch.export.Dim.STATIC,
                             torch.export.Dim.STATIC,
                             torch.export.Dim.STATIC),
                            )
        ).module()
    quantizer = aiq.ArmInductorQuantizer()
    quantizer.set_global(aiq.get_default_arm_inductor_quantization_config())
    prepared_model = prepare_pt2e(exported_model, quantizer)
    prepared_model(*example_inputs)
    converted_model = convert_pt2e(prepared_model)
    optimized_model = torch.compile(converted_model)
    return optimized_model


batch_size = 8
model = get_model(channels_last=True)
model = quantize_model(model)
batch = get_input(batch_size, channels_last=True)
infer_fn = get_inference_fn(model, enable_amp=True)
avg_time = benchmark(infer_fn, batch)
print(f"\nAverage samples per second: {(batch_size/avg_time):.2f}")

The resulting throughput is 56.77 SPS, a 7.1% improvement over the bfloat16 solution.
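The quality risk mentioned above comes from snapping each floating-point value onto one of at most 255 integer levels. The following is a minimal pure-Python sketch of per-tensor symmetric INT8 quantization and dequantization that illustrates this. It is a generic textbook scheme, not necessarily the configuration returned by `get_default_arm_inductor_quantization_config()`, which may differ, for example by using per-channel scales or asymmetric activations. Python's `round` uses round-half-to-even.

```python
from typing import List, Tuple


def quantize_symmetric_int8(values: List[float]) -> Tuple[List[int], float]:
    """Per-tensor symmetric INT8: q = clamp(round(x / scale), -127, 127)."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # avoid dividing by zero
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Map INT8 codes back to floating point: x_hat = q * scale."""
    return [qi * scale for qi in q]


weights = [0.02, -0.5, 0.731, 1.27, -1.27, 0.0004]
q, scale = quantize_symmetric_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q, f"scale={scale:.4f}", f"max_error={max_error:.5f}")
```

For in-range values, the round-trip error is bounded by half the scale. Small values such as 0.0004 collapse to zero, which is one reason calibration data and per-channel scales matter in practice.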

AOT Compilation Using ONNX and OpenVINO

In our previous post, we explored ahead-of-time (AOT) model compilation techniques using Open Neural Network Exchange (ONNX) and OpenVINO. Both libraries include dedicated support for running on AWS Graviton (e.g., see here and here). The experiments in this section require the following library installations:

pip install onnxruntime onnxscript openvino nncf

The following code block demonstrates the model compilation and execution on Arm using ONNX:

def export_to_onnx(model, onnx_path="resnet50.onnx"):
    dummy_input = torch.randn(4, 3, 224, 224)
    batch = torch.export.Dim("batch")
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_shapes=((batch,
                         torch.export.Dim.STATIC,
                         torch.export.Dim.STATIC,
                         torch.export.Dim.STATIC),
                        ),
        dynamo=True
    )
    return onnx_path

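def onnx_infer_fn(onnx_path):
    import onnxruntime as ort

    # session options must be configured before creating the session
    # and passed to it; otherwise they have no effect
    sess_options = ort.SessionOptions()
    # enable the bfloat16 fast-math GEMM kernels on arm64
    sess_options.add_session_config_entry(
        "mlas.enable_gemm_fastmath_arm64_bfloat16", "1")
    sess = ort.InferenceSession(
        onnx_path,
        sess_options=sess_options,
        providers=["CPUExecutionProvider"]
    )
    input_name = sess.get_inputs()[0].name

    def infer_fn(batch):
        result = sess.run(None, {input_name: batch})
        return result
    return infer_fn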

batch_size = 8
model = get_model()
onnx_path = export_to_onnx(model)
batch = get_input(batch_size).numpy()
infer_fn = onnx_infer_fn(onnx_path)
avg_time = benchmark(infer_fn, batch)
print(f"\nAverage samples per second: {(batch_size/avg_time):.2f}")

It should be noted that ONNX Runtime supports a dedicated ACL Execution Provider for running on Arm, but (as of the time of this writing) this requires a custom ONNX Runtime build, which is out of the scope of this post.

Alternatively, we can compile the model using OpenVINO. The code block below demonstrates its use, including an option for INT8 quantization using NNCF:

import openvino as ov
import nncf

def openvino_infer_fn(compiled_model):
    def infer_fn(batch):
        result = compiled_model([batch])[0]
        return result
    return infer_fn

class RandomDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 10000

    def __getitem__(self, idx):
        return torch.randn(3, 224, 224)

quantize_model = False
batch_size = 8
model = get_model()
calibration_loader = torch.utils.data.DataLoader(RandomDataset())
calibration_dataset = nncf.Dataset(calibration_loader)

if quantize_model:
    # quantize PyTorch model
    model = nncf.quantize(model, calibration_dataset)

ovm = ov.convert_model(model, example_input=torch.randn(1, 3, 224, 224))
ovm = ov.compile_model(ovm)
batch = get_input(batch_size).numpy()
infer_fn = openvino_infer_fn(ovm)
avg_time = benchmark(infer_fn, batch)
print(f"\nAverage samples per second: {(batch_size/avg_time):.2f}")

In the case of our toy model, OpenVINO compilation results in a further boost in throughput, to 63.48 SPS, but the NNCF quantization disappoints, yielding just 55.18 SPS.

Results

The results of our experiments are summarized in the table below:

ResNet50 Inference Optimization Results (by Author)

As in our previous post, we reran our experiments on a second model, a Vision Transformer (ViT) from the timm library, to demonstrate how the impact of the runtime optimizations we discussed can vary based on the details of the model. The results are captured below:

ViT Inference Optimization Results (by Author)

Summary

In this post, we reviewed a number of relatively simple optimization techniques and applied them to two toy PyTorch models. As the results demonstrated, the impact of each optimization step can vary greatly based on the details of the model, and the journey toward peak performance can take many different paths. The steps we presented in this post were just an appetizer; there are undoubtedly many more optimizations that can unlock even greater performance.

Along the way, we noted the many AI/ML libraries that have introduced deep support for the Graviton architecture, and the seemingly continuous community effort of ongoing optimization. The performance gains we achieved, combined with this apparent dedication, prove that AWS Graviton is firmly in the “big leagues” when it comes to running compute-intensive AI/ML workloads.
