Increasing Transformer Model Efficiency Through Attention Layer Optimization


How paying “better” attention can drive ML cost savings


Introduced in the landmark 2017 paper “Attention Is All You Need” (Vaswani et al., 2017), the Transformer architecture is widely considered one of the most influential scientific breakthroughs of the past decade. At the core of the Transformer is the attention mechanism, a novel approach that enables AI models to understand complex structures by focusing on different parts of input sequences based on the task at hand. Originally demonstrated in the world of natural language processing, the success of the Transformer architecture has quickly spread to many other domains, including speech recognition, scene understanding, reinforcement learning, protein structure prediction, and more. However, attention layers are highly resource-intensive, and as these layers become the standard across increasingly large models, the costs associated with their training and deployment have surged. This has created an urgent need for strategies that reduce the computational cost of this core layer so as to improve the efficiency and scalability of Transformer-based AI models.

In this post, we will explore several tools for optimizing attention in PyTorch. Our focus will be on methods that maintain the accuracy of the attention layer. These will include PyTorch SDPA, FlashAttention, TransformerEngine Attention, FlexAttention, and xFormer attention. Other methods that reduce the computational cost via approximation of the attention calculation (e.g., DeepSpeed’s Sparse Attention, Longformer, Linformer, and more) will not be considered. Additionally, we will not discuss general optimization techniques that, while beneficial to attention performance, are not specific to the attention computation itself (e.g., FP8 training, model sharding, and more).

Importantly, attention optimization is an active area of research with new methods coming out on a fairly regular basis. Our goal is to increase your awareness of some of the existing solutions and provide you with a foundation for further exploration and experimentation. The code we will share below is intended for demonstration purposes only — we make no claims regarding its accuracy, optimality, or robustness. Please do not interpret our mention of any platforms, libraries, or optimization techniques as an endorsement of their use. The best options for you will depend greatly on the specifics of your own use case.

Many thanks to Yitzhak Levi for his contributions to this post.

To facilitate our discussion, we build a Vision Transformer (ViT)-backed classification model using the popular timm Python package (version 0.9.7). We will use this model to illustrate the performance impact of various attention kernels.

We begin by defining a simplified Transformer block that allows for programming the attention function by passing it into its constructor. Since attention implementations assume specific input tensor formats, we also include an option for controlling the format, ensuring compatibility with the attention kernel of our choosing.

# general imports
import os, time, functools

# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# timm imports
from timm.models.vision_transformer import VisionTransformer
from timm.layers import Mlp

IMG_SIZE = 224
BATCH_SIZE = 128

# Define ViT settings
NUM_HEADS = 16
HEAD_DIM = 64
DEPTH = 24
PATCH_SIZE = 16
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 196

class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            format=None,
            dim: int = 768,
            num_heads: int = 12,
            **kwargs
    ) -> None:
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=dim * 4,
        )
        permute = (2, 0, 3, 1, 4)
        self.permute_attn = functools.partial(torch.transpose,
                                              dim0=1, dim1=2)

        if format == 'bshd':
            permute = (2, 0, 1, 3, 4)
            self.permute_attn = nn.Identity()
        self.permute_qkv = functools.partial(torch.permute, dims=permute)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x_in)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # permute tensor based on the specified format
        qkv = self.permute_qkv(qkv)
        q, k, v = qkv.unbind(0)
        # use the attention function specified by the user
        x = self.attn_fn(q, k, v)
        # permute output according to the specified format
        x = self.permute_attn(x).reshape(B, N, C)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x

We define a randomly generated dataset that we will use to feed our model during training.

# Use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, IMG_SIZE, IMG_SIZE],
                                 dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label

Next, we define our ViT training function. While our example focuses on demonstrating a training workload, it is important to emphasize that optimizing the attention layer is equally, if not more, important during model inference.

The training function we define accepts the customized Transformer block and a flag that controls the use of torch.compile.

def train_fn(block_fn, compile):
    torch.random.manual_seed(0)
    device = torch.device("cuda:0")
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    train_set = FakeDataset()
    train_loader = DataLoader(
        train_set, batch_size=BATCH_SIZE,
        num_workers=12, pin_memory=True, drop_last=True)

    model = VisionTransformer(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        embed_dim=NUM_HEADS*HEAD_DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        class_token=False,
        global_pool="avg",
        block_fn=block_fn
    ).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters())

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(train_loader):
        # Copy data to GPU
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, label)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 100:
            break

    print(f'average step time: {summ / count}')

# define compiled and uncompiled variants of our train function
train = functools.partial(train_fn, compile=False)
train_compile = functools.partial(train_fn, compile=True)

In the code block below, we define a PyTorch-native attention function and use it to train our ViT model:

def attn_fn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x

block_fn = functools.partial(MyAttentionBlock, attn_fn=attn_fn)

print('Default Attention')
train(block_fn)
print('Compiled Default Attention')
train_compile(block_fn)

We ran this on an NVIDIA H100 with CUDA 12.4 and PyTorch 2.5.1. The uncompiled variant resulted in an average step time of 370 milliseconds (ms), while the compiled variant improved to 242 ms. We will use these results as a baseline for comparison as we consider alternative solutions for performing the attention computation.

One of the easiest ways to boost the performance of our attention layers in PyTorch is to use the scaled_dot_product_attention (SDPA) function. Currently in beta, PyTorch SDPA consolidates multiple kernel-level optimizations and dynamically selects the most efficient one based on the input’s properties. Supported backends (as of now) include: FlashAttention-2, Memory-Efficient Attention, a C++-based Math Attention, and CuDNN. These backends fuse together high-level operations while employing GPU-level optimizations to increase compute efficiency and memory utilization.

SDPA is continuously evolving, with new and improved backend implementations being introduced regularly. Staying up to date with the latest PyTorch releases is key to leveraging the most recent performance improvements. For example, PyTorch 2.5 introduced an updated CuDNN backend featuring a specialized SDPA primitive specifically tailored for training on NVIDIA Hopper architecture GPUs.
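As a side note, recent PyTorch releases also expose a scoped alternative to the global backend flags used in our helper function below: the torch.nn.attention.sdpa_kernel context manager. The following is a minimal sketch (assuming PyTorch 2.3 or later; the tensor shapes are illustrative only and this snippet is not part of our experiments):

# scoped SDPA backend selection via the sdpa_kernel context manager
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention

# illustrative tensors in the (batch, heads, seq, head_dim) layout SDPA expects
q = k = v = torch.randn(1, 16, 196, 64, device='cuda', dtype=torch.bfloat16)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    # restrict SDPA to the FlashAttention backend within this scope;
    # an error is raised if the inputs are unsupported by that backend
    out = scaled_dot_product_attention(q, k, v)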

In the code block below, we iterate through the list of supported backends and assess the runtime performance of training with each one. We use a helper function, set_sdpa_backend, for programming the SDPA backend:

from torch.nn.functional import scaled_dot_product_attention as sdpa

def set_sdpa_backend(backend):
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(False)
    torch.backends.cuda.enable_cudnn_sdp(False)

    if backend in ['flash_sdp', 'all']:
        torch.backends.cuda.enable_flash_sdp(True)
    if backend in ['mem_efficient_sdp', 'all']:
        torch.backends.cuda.enable_mem_efficient_sdp(True)
    if backend in ['math_sdp', 'all']:
        torch.backends.cuda.enable_math_sdp(True)
    if backend in ['cudnn_sdp', 'all']:
        torch.backends.cuda.enable_cudnn_sdp(True)

for backend in ['flash_sdp', 'mem_efficient_sdp',
                'math_sdp', 'cudnn_sdp']:
    set_sdpa_backend(backend)
    block_fn = functools.partial(MyAttentionBlock,
                                 attn_fn=sdpa)

    print(f'PyTorch SDPA - {backend}')
    train(block_fn)
    print(f'Compiled PyTorch SDPA - {backend}')
    train_compile(block_fn)

We summarize our interim results in the table below:

Step times for different attention functions (lower is better) — by Author

While the choice of SDPA backend has a noticeable impact on performance when running in eager mode, the optimizations performed by model compilation appear to overshadow the differences between the attention kernels. Once again, we caution against drawing any conclusions from these results, as the performance impact of different attention functions can vary significantly depending on the specific model and use case.

While PyTorch SDPA is a great starting point, using third-party attention kernels can help accelerate your ML workloads further. These alternatives often come with added flexibility, offering a wider range of configuration options for attention. Some may also include optimizations tailored for specific hardware accelerators or newer GPU architectures.

In this section, we will explore some of the third-party attention kernels available and evaluate their potential impact on runtime performance.

FlashAttention-3

While PyTorch SDPA supports a FlashAttention backend, more advanced FlashAttention implementations can be found in the flash-attn library. Here we will explore the FlashAttention-3 beta release, which boasts a speedup of up to 2x compared to FlashAttention-2. Given the early stage of its development, FlashAttention-3 can only be installed directly from the GitHub repository, and its use is limited to certain head dimensions. Additionally, it does not yet support model compilation. In the following code block, we configure our Transformer block to use flash-attn-3 while setting the attention input format to “bshd” (batch, sequence, head, depth) to meet the expectations of the library.

# flash attention 3
from flash_attn_interface import flash_attn_func as fa3
attn_fn = lambda q, k, v: fa3(q, k, v)[0]
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')

print(f'Flash Attention 3')
train(block_fn)

The resulting step time was 240 ms, making it 5% faster than SDPA with its FlashAttention backend.
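Given that flash-attn-3 is a beta release installed from source, you may want to guard the import and fall back to PyTorch SDPA when the library is unavailable. Below is a minimal sketch of such a guard (our own convenience pattern, not part of the flash-attn API; the attn_format variable is introduced here purely for illustration):

# optional guard: fall back to PyTorch SDPA if flash-attn-3 is unavailable
try:
    from flash_attn_interface import flash_attn_func as fa3
    attn_fn = lambda q, k, v: fa3(q, k, v)[0]
    attn_format = 'bshd'   # flash-attn expects (batch, seq, heads, head_dim)
except ImportError:
    from torch.nn.functional import scaled_dot_product_attention as attn_fn
    attn_format = None     # default (batch, heads, seq, head_dim) layout

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format=attn_format)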

Transformer Engine

Transformer Engine (TE) is a specialized library designed to accelerate Transformer models on NVIDIA GPUs. TE is updated regularly with optimizations that leverage the capabilities of the latest NVIDIA hardware and software offerings, giving users access to specialized kernels long before they are integrated into general-purpose frameworks such as PyTorch.

In the code block below, we use DotProductAttention from TE version 1.11.0. Similar to PyTorch SDPA, TE supports a number of backends, which are controlled via environment variables. Here we demonstrate the use of the NVTE_FUSED_ATTN backend.

def set_te_backend(backend):
    # must be applied before the first use of
    # transformer_engine.pytorch.attention
    os.environ["NVTE_FLASH_ATTN"] = '0'
    os.environ["NVTE_FUSED_ATTN"] = '0'
    os.environ["NVTE_UNFUSED_ATTN"] = '0'
    if backend == 'flash':
        os.environ["NVTE_FLASH_ATTN"] = '1'
    if backend == 'fused':
        os.environ["NVTE_FUSED_ATTN"] = '1'
    if backend == 'unfused':
        os.environ["NVTE_UNFUSED_ATTN"] = '1'

from transformer_engine.pytorch.attention import DotProductAttention
set_te_backend('fused')
attn_fn = DotProductAttention(NUM_HEADS, HEAD_DIM, NUM_HEADS,
                              qkv_format='bshd',
                              # disable masking (default is causal mask)
                              attn_mask_type='no_mask')

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')

print(f'Transformer Engine Attention')
train(block_fn)
print(f'Compiled Transformer Engine Attention')
train_compile(block_fn)

TE attention resulted in average step times of 243 ms and 204 ms for the eager and compiled model variants, respectively.

xFormer Attention

Underlying the memory-efficient backend of PyTorch SDPA is an attention kernel provided by the xFormers library. Once again, we can go to the source to benefit from the latest kernel optimizations and from the full set of API capabilities. In the following code block, we use the memory_efficient_attention operator from xFormers version 0.0.28.

# xformer memory efficient attention
from xformers.ops import memory_efficient_attention as mea
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea,
                             format='bshd')

print(f'xFormer Attention ')
train(block_fn)
print(f'Compiled xFormer Attention ')
train_compile(block_fn)

The eager model variant resulted in an average step time of 246 ms, making it 10.5% faster than the SDPA memory-efficient kernel. The compiled variant resulted in a step time of 203 ms.

Results

The table below summarizes our experiments:

Step times for different attention functions (lower is better) — by Author

The winner for the eager model was flash-attn-3, with an average step time 54% faster than our baseline model, a speedup that translates directly into reduced training costs. In compiled mode, the performance across the optimized kernels was more or less equal, with the fastest implementations achieving 202 ms, a 20% improvement compared to the baseline experiment.

As mentioned above, the precise impact of these savings depends greatly on the model definition. To assess this variability, we reran the experiments using modified settings that increased the attention sequence length to 3136 tokens.

IMG_SIZE = 224
BATCH_SIZE = 8

# Define ViT settings
NUM_HEADS = 12
HEAD_DIM = 64
DEPTH = 6
PATCH_SIZE = 4
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 3136

The results are summarized in the table below:

Results for large sequence length (lower is better) — by Author

Our immediate observation is that when the sequence length is larger, the performance impact of the attention kernels is far more pronounced. Once again, flash-attn-3 came out in front in eager execution mode — this time with a ~5x increase in performance compared to the PyTorch-native function. For the compiled model, we see that the TE kernel broke away from the pack with an overall best step time of 53 ms.

So far, we have focused on the standard attention function. However, sometimes we may want to use a variant of the standard attention computation in which we either mask out some of the values of intermediate tensors or apply some operation to them. These types of changes may interfere with our ability to use the optimized attention blocks we covered above. In this section, we discuss some of the ways to handle this:

Leverage Advanced Kernel APIs:
Many optimized attention kernels provide extensive APIs with controls for customizing the attention computation. Before implementing a new solution, explore these APIs to determine whether they already support your required functionality.

Implement a custom kernel:
If the existing APIs do not meet your needs, you can consider creating your own custom attention implementation. In previous posts (e.g., here) we discussed some of the pros and cons of custom kernel development. Achieving optimal performance can be extremely difficult. If you do go down this path, one approach might be to start with an existing (optimal) kernel and apply minimal changes to integrate the desired modification.

Use FlexAttention:
A recent addition to PyTorch, FlexAttention empowers users to implement a wide variety of attention variants without needing to compromise on performance. Denoting the result of the dot product of the query and key tokens as the score, flex_attention allows for programming either a score_mod function or a block_mask mask that is automatically applied to the score tensor. See the documentation as well as the accompanying attention-gym repository for examples of the types of operations that the API enables.

FlexAttention works by compiling the score_mod operator into the attention operator, thereby creating a single fused kernel. It also leverages the sparsity of block_masks to avoid unnecessary computation. The benchmarks reported in the FlexAttention documentation show considerable performance gains for a variety of use cases.

Let’s see both score_mod and block_mask in action.

Score Mod Example — Soft-Capping with Tanh

Soft-capping is a common technique used to control the size of the logits (e.g., see here). The following code block extends our PyTorch-native attention kernel with soft-capping:

def softcap_attn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    # apply soft-capping
    attn = 30 * torch.tanh(attn/30)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x

In the code block below, we train our model, first with our PyTorch-native kernel, and then with the optimized Flex Attention API. These experiments were run with the 3136-length sequence settings.

# flex attention imports
from torch.nn.attention.flex_attention import (
    create_block_mask,
    create_mask,
    flex_attention
)
compiled_flex = torch.compile(flex_attention)

# score_mod definition
def tanh_softcap(score, b, h, q_idx, kv_idx):
    return 30 * torch.tanh(score/30)

block_fn = functools.partial(MyAttentionBlock, attn_fn=softcap_attn)

print(f'Attention with Softcap')
train(block_fn)
print(f'Compiled Attention with Softcap')
train_compile(block_fn)

flex_fn = functools.partial(flex_attention, score_mod=tanh_softcap)
compiled_flex_fn = functools.partial(compiled_flex, score_mod=tanh_softcap)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                                      attn_fn=compiled_flex_fn)

print(f'Flex Attention with Softcap')
train(compiled_block_fn)
print(f'Compiled Flex Attention with Softcap')
train_compile(block_fn)

The results of the experiments are captured in the table below:

Soft-cap step time results (lower is better) — by Author

The impact of the Flex Attention kernel is clearly evident, delivering performance boosts of roughly 3.5x in eager mode and 1.5x in compiled mode.

Mask Mod Example — Neighborhood Masking

We assess the mask_mod functionality by applying a sparse mask to our attention score. Recall that each token in our sequence represents a patch in our 2D input image. We modify our kernel so that each token only attends to other tokens that are within a 5×5 window in the corresponding 2D token array.

# convert the token id to a 2D index
def seq_indx_to_2d(idx):
    n_row_patches = IMG_SIZE // PATCH_SIZE
    r_ind = idx // n_row_patches
    c_ind = idx % n_row_patches
    return r_ind, c_ind

# only attend to tokens in a 5x5 surrounding window in our 2D token array
def mask_mod(b, h, q_idx, kv_idx):
    q_r, q_c = seq_indx_to_2d(q_idx)
    kv_r, kv_c = seq_indx_to_2d(kv_idx)
    return torch.logical_and(torch.abs(q_r-kv_r) < 5, torch.abs(q_c-kv_c) < 5)

As a baseline for our experiment, we use PyTorch SDPA, which includes support for passing in an attention mask. The following block includes the masked SDPA experiment followed by the Flex Attention implementation:

# materialize the mask to use in SDPA
mask = create_mask(mask_mod, 1, 1, SEQ_LEN, SEQ_LEN, device='cuda')

set_sdpa_backend('all')
masked_sdpa = functools.partial(sdpa, attn_mask=mask)
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=masked_sdpa)
print(f'Masked SDPA Attention')
train(block_fn)
print(f'Compiled Masked SDPA Attention')
train_compile(block_fn)

block_mask = create_block_mask(mask_mod, None, None, SEQ_LEN, SEQ_LEN)
flex_fn = functools.partial(flex_attention, block_mask=block_mask)
compiled_flex_fn = functools.partial(compiled_flex, block_mask=block_mask)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                                      attn_fn=compiled_flex_fn)

print(f'Masked Flex Attention')
train(compiled_block_fn)
print(f'Compiled Masked Flex Attention')
train_compile(block_fn)

The results of the experiments are captured below:

Masked attention step time results (lower is better) — by Author

Once again, Flex Attention offers a considerable performance boost, amounting to 2.19x in eager mode and 2.59x in compiled mode.

Flex Attention Limitations

Although we have succeeded in demonstrating the power and potential of Flex Attention, there are a few limitations that should be noted:

  1. Limited Scope of Modifications: With Flex Attention you can (as of the time of this writing) only modify the attention score (the result of the dot product between the query and key tokens). It does not support changes at other stages of the attention computation.
  2. Dependency on torch.compile: Given the reliance on torch.compile, great care must be taken to avoid excessive recompilations, which could greatly degrade runtime performance. For instance, while the support for Document Masking is very compelling, it will only perform as expected if the sum of the lengths of all of the documents remains fixed.
  3. No Support for Trainable Parameters in score_mod: At the time of this writing, Flex Attention does not support a score_mod implementation that includes trainable parameters. For example, while the documentation highlights support for relative position encodings, these are commonly implemented with trainable parameters (rather than fixed values), which cannot currently be accommodated. A fixed-bias alternative that is supported is sketched after this list.
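As an illustration of what is currently supported, here is a minimal sketch (our own example, not taken from the FlexAttention documentation) of a score_mod that applies a fixed, ALiBi-style relative position bias with a hard-coded slope. Since the slope is a constant rather than a trainable parameter, it stays within the current constraints:

# a fixed (non-trainable) relative position bias, ALiBi-style;
# the slope value is a constant chosen purely for illustration
ALIBI_SLOPE = 0.125

def rel_bias_score_mod(score, b, h, q_idx, kv_idx):
    # penalize attention between distant tokens; no trainable parameters
    return score - ALIBI_SLOPE * torch.abs(q_idx - kv_idx)

rel_bias_flex = functools.partial(flex_attention, score_mod=rel_bias_score_mod)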

In the face of these limitations, we can return to one of the other optimization opportunities discussed above.

As the reliance on Transformer architectures and attention layers in ML models increases, so does the need for tools and techniques for optimizing these components. In this post, we have explored a number of attention kernel variants, each with its own unique properties, capabilities, and limitations. Importantly, one size does not fit all — different models and use cases will warrant the use of different kernels and different optimization strategies. This underscores the importance of having a wide variety of tools and techniques for optimizing attention layers.

In a future post, we hope to further explore attention layer optimization by focusing on how some of the tools we discussed can be applied to tackle the challenge of handling variable-sized input sequences. Stay tuned…
