Learning Triton One Kernel at a Time: Matrix Multiplication


Matrix multiplication is arguably the most common operation performed by GPUs. It is the fundamental building block of linear algebra and shows up across a wide spectrum of fields such as graphics, physics simulations, and scientific computing, while being ubiquitous in machine learning.

In this article, we'll break down the conceptual implementation of general matrix-matrix multiplication (GEMM) while introducing several optimisation concepts such as tiling and memory coalescing. Finally, we'll implement GEMM in Triton!

Naive GEMM

Let's start simple: we want to multiply two matrices X and Y with shapes (M,N) and (N,K) respectively. The output matrix Z = X @ Y will therefore have shape (M,K).

This operation involves computing the dot product of every row of X with every column of Y. A simple NumPy implementation might look something like this:
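The original snippet isn't reproduced here, so below is a minimal sketch of what such a naive implementation could look like (the function name and structure are illustrative, not the article's exact code):

import numpy as np

def naive_matmul(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    # X: (M, N), Y: (N, K) -> Z: (M, K)
    M, N = X.shape
    N_, K = Y.shape
    assert N == N_, "inner dimensions must match"
    Z = np.zeros((M, K), dtype=X.dtype)
    for m in range(M):                          # for every row of X...
        for k in range(K):                      # ...and every column of Y...
            Z[m, k] = np.dot(X[m, :], Y[:, k])  # ...compute one dot product
    return Z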

While easy to write, read and understand, this implementation is very inefficient in terms of memory access and caching. As mentioned in the first article of this series, a fundamental aspect of GPU optimisation is minimising data transfers.

However, our current implementation loads a row from X, iteratively loads all K columns of Y, computes their dot products, and repeats the process for each row of X. This results in a total of M(K+1) load operations: one row of X plus K columns of Y, repeated for each of the M rows.

Naive matrix multiplication: purple and blue tiles represent the vectors involved in the dot product at each time step, and green cells the computed output values.

As seen in the animation, the memory access pattern is wasteful: every column of Y is loaded M times. As an analogy, this is like running to the grocery store (global memory) every time you need a new ingredient for a dish instead of preparing all the ingredients on your kitchen counter (shared memory). Ideally, we would like to minimise the number of times each chunk of data is loaded and maximise its reuse once loaded. This leaves us with two main axes of optimisation:

  1. How can we improve the access pattern to minimise redundant loads?
  2. How much data can we load at once, and where should it be stored on the GPU?

Tiled GEMM

As mentioned previously, the naive approach to GEMM results in many redundant loads, which induces unnecessary overhead. Ideally, we would like to load each chunk of data exactly once and perform all the operations in which it is used before dropping it from memory.

An elegant approach to this problem is tiling, which involves dividing large matrices into smaller sub-matrices, or tiles. Consider two matrices X and Y with shapes (4,6) and (6,4) respectively; X @ Y results in a matrix Z with shape (4,4).

To compute the first element of Z, Z[0,0], we need to take the dot product between the first row of X and the first column of Y: Z[0,0] = dot(X[0, :], Y[:, 0]). We can also break this dot product into smaller chunks, for instance in groups of three elements: Z[0,0] = dot(X[0, 0:3], Y[0:3, 0]) + dot(X[0, 3:6], Y[3:6, 0]).

Alternatively, we can extend this approach to two dimensions and compute a whole (2,2) block of Z at a time: Z[0:2, 0:2] = dot(X[0:2, 0:2], Y[0:2, 0:2]) + dot(X[0:2, 2:4], Y[2:4, 0:2]) + dot(X[0:2, 4:6], Y[4:6, 0:2]).
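To make the accumulation concrete, here is a small NumPy sketch of tiled GEMM (an illustrative example, not the article's code) that computes Z one (2,2) block at a time:

import numpy as np

def tiled_matmul(X: np.ndarray, Y: np.ndarray, block: int = 2) -> np.ndarray:
    # X: (M, N), Y: (N, K); assumes every dimension is divisible by `block`
    M, N = X.shape
    _, K = Y.shape
    Z = np.zeros((M, K), dtype=X.dtype)
    for m in range(0, M, block):            # row index of the output block
        for k in range(0, K, block):        # column index of the output block
            acc = np.zeros((block, block), dtype=X.dtype)
            for n in range(0, N, block):    # iterate over the shared dimension
                # accumulate the partial (block, block) dot products
                acc += X[m:m + block, n:n + block] @ Y[n:n + block, k:k + block]
            Z[m:m + block, k:k + block] = acc
    return Z

X, Y = np.random.rand(4, 6), np.random.rand(6, 4)
assert np.allclose(tiled_matmul(X, Y), X @ Y)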

Here is a visual representation of tiled matrix multiplication:

Tiled matrix multiplication. The computation is split into several "tiles" of X and Y (highlighted in pale blue and purple), each containing several blocks (dark blue and purple). In each block, we compute dot products (green cells in X and Y). These dot products are accumulated across the blocks of a tile to compute the output values in Z (the accumulation is represented by colours from orange to green).

The above animation illustrates how data is reused in tiled GEMM. For each pair of (2,2) blocks in X and Y, we compute 4 dot products, which results in a (2,2) partial output matrix. Since each tile contains 3 blocks, we need to accumulate 3 of these partial matrices to compute the final (2,2) output in Z. This accumulation is represented by the coloured cells in Z.

In the kitchen analogy, this is like fetching ingredients from the store and preparing them on the kitchen counter (i.e. the small shared memory), reusing them several times before going back to the store.

Importantly, reusing loaded data over multiple steps allows this approach to drastically reduce the number of load operations. With (2,2) blocks, each loaded row of X and column of Y is used in two dot products. We therefore perform twice as many operations with each block of loaded data, roughly halving the number of load operations! This generalises to larger blocks as well: using a (32,32) block would reduce the number of loads by a factor of around 32.

Now you are probably wondering: how large can these blocks be? To answer this question, let's recall how memory is managed on modern GPUs.

GPU Memory Hierarchy

We distinguish four main types of memory on Nvidia GPUs. Here, we take the example of an A100:

  • Registers: The fastest and smallest type of memory on the GPU, residing directly inside each Streaming Multiprocessor (SM). On the A100, each SM provides 256 KB of register file space (65,536 × 32-bit registers), distributed among its threads. Each thread gets its own private 32-bit registers for storing temporary variables and intermediate results, avoiding memory traffic altogether. However, register usage per thread directly affects occupancy: using too many registers per thread limits how many threads can run concurrently.
  • L1/Shared Memory: On an A100, each SM has 192 KB of SRAM that can be flexibly configured as either a hardware-managed L1 cache or a programmer-managed shared memory. For performance-critical kernels like matrix multiplication, we explicitly use this space as shared memory to stage data tiles close to the compute units, bypassing the L1 cache entirely. This gives us fine-grained control over data reuse.
  • L2 cache: This cache is slower than L1 but much larger, with around 40 MB shared across all SMs on the A100. It serves as a global cache for both data and instructions, reducing the number of accesses to high-latency HBM. The L2 cache is coherent across SMs, meaning that updates from one SM are visible to others, enabling synchronisation between thread blocks. Its bandwidth can reach several terabytes per second, acting as a buffer between the fast on-chip SRAM and the slower HBM.
  • High Bandwidth Memory (HBM): This is the device memory, with a capacity of either 40 GB or 80 GB depending on the A100 model. It provides extremely high bandwidth (up to 2 TB/s on the 80 GB variant) but with much higher latency than on-chip caches. HBM is where large tensors, model weights, and datasets reside during execution. Since accessing HBM is expensive, efficient kernels aim to minimise data movement and maximise on-chip data reuse via registers and shared memory.

As you can see, the memory hierarchy generally trades capacity off against latency. Therefore, maximising performance boils down to loading data from HBM into shared memory efficiently and reusing it as much as possible.

GPU Memory Hierarchy, from fastest/smallest (top) to slowest/largest (bottom).

Choosing our block size is critical. We want blocks to be large enough to create plenty of parallel work, but small enough that their data fits in the SM's shared memory and registers. A BLOCK_SIZE of 64 is a common starting point since it is a multiple of the warp size (32 threads), ensuring full hardware utilisation.
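As a rough back-of-the-envelope check (my own estimate, not a figure from the article), we can verify that two fp32 tiles of this size easily fit on-chip:

# Back-of-the-envelope estimate (assumption, not taken from the article or a profiler)
BLOCK_SIZE = 64
bytes_per_fp32 = 4
tile_bytes = BLOCK_SIZE * BLOCK_SIZE * bytes_per_fp32    # 16 KB per (64, 64) fp32 tile
print(f"{2 * tile_bytes / 1024:.0f} KB staged per iteration (X tile + Y tile), "
      f"out of 192 KB of L1/shared memory per A100 SM")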

Parallel Tiled GEMM

With these considerations in mind, a natural follow-up to our tiled GEMM is to parallelise the computation of each pair of tiles over several thread blocks, as depicted in the following animation.

Parallel Tiled Matrix Multiplication. The iteration over tiles is replaced by a parallel operation over multiple thread blocks.

Memory Coalescing

Before writing tiled GEMM in Triton, we need to consider one last detail: memory coalescing, a technique that enables optimal use of global memory bandwidth. Memory coalescing is achieved when consecutive threads in a warp access consecutive memory addresses. Imagine a librarian fetching books for a client: if all the books sit side by side on a shelf, they can grab them in one go. In contrast, if the books are scattered across different shelves, they have to fetch them one by one, which takes significantly longer.

To understand how this applies to our case, note that matrices are stored linearly in memory; in other words, a (2,2) matrix is stored as a sequence of 4 consecutive elements. Frameworks like PyTorch adopt a row-major layout, meaning that the elements of each row are contiguous in memory. For instance, the elements of our (2,2) matrix would be stored as follows: [(0,0), (0,1), (1,0), (1,1)]. Notice that elements of the same row are contiguous (touching), while elements of the same column are strided (separated by one element).

PyTorch stores matrices in row-major layout. Elements of a row are contiguous in memory, while elements of a column are strided.

This implies that rows can be loaded with coalesced reads, but columns cannot. However, we need to access the columns of Y to compute dot products. To maximise performance, a good practice is to transpose Y so that we iterate over its rows rather than its columns.

However, transposing Y is not enough to change its layout in memory. As mentioned previously, PyTorch stores matrices in a flat array. Each matrix dimension is associated with a stride attribute, denoting the jump needed to go from one element to the next along this dimension. For instance, a (10,10) matrix would have strides=(10,1). Indeed, starting from element [0,0], element [1,0] is 10 memory slots (i.e. one row) away, whereas element [0,1] is adjacent.

When transposing a tensor, PyTorch does not modify the layout in memory but simply recomputes the strides. To make the transpose effective from a memory standpoint, we need to call Y.T.contiguous().
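A quick way to see this in practice (a small illustrative snippet, not from the article):

import torch

Y = torch.randn(6, 4)
print(Y.shape, Y.stride())         # torch.Size([6, 4]) (4, 1)  -> row-major layout
print(Y.T.shape, Y.T.stride())     # torch.Size([4, 6]) (1, 4)  -> same buffer, strides swapped
print(Y.T.is_contiguous())         # False: the memory layout has not changed
Yt = Y.T.contiguous()              # materialises the transpose in memory
print(Yt.shape, Yt.stride())       # torch.Size([4, 6]) (6, 1)  -> rows are contiguous again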

These are the steps required to load the columns of Y efficiently; however, we will have to transpose the loaded blocks inside the kernel to perform the dot product properly: z_block = tl.dot(X_block, Y_block.T).

Representation of Y, Y.T and Y.T.contiguous(), showing their block representation and memory layout. The transpose operation changes the logical view of the matrix but does not modify its memory layout. This is why we need to add .contiguous() to enable coalesced reads on rows.

Triton Implementation

From here on, we first describe the kernel without memory coalescing, to keep the logic and pointer arithmetic simple, before summarising the changes required to make the load operations on Y's columns coalesced.

Let's start by focusing on the PyTorch wrapper around the kernel. We need to read M, N, K from the input matrices and compute their strides, since these constants will be useful later in the kernel. Then, we define the BLOCK_SIZE and declare the grid.
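Since the wrapper itself isn't shown above, here is a sketch of what it could look like (a minimal illustration assuming the kernel signature used below; the article's exact code is on GitHub). The name block_matmul matches the function benchmarked later:

import torch
import triton

def block_matmul(X: torch.Tensor, Y: torch.Tensor, BLOCK_SIZE: int = 64) -> torch.Tensor:
    M, N = X.shape
    _, K = Y.shape
    Z = torch.empty((M, K), device="cuda", dtype=torch.float32)

    # strides (in elements) are passed to the kernel to build block pointers
    x_stride_m, x_stride_n = X.stride()
    y_stride_n, y_stride_k = Y.stride()
    z_stride_m, z_stride_k = Z.stride()

    # one program per (BLOCK_SIZE, BLOCK_SIZE) output block of Z
    grid = (triton.cdiv(M, BLOCK_SIZE), triton.cdiv(K, BLOCK_SIZE))

    block_matmul_kernel[grid](
        X, x_stride_m, x_stride_n,
        Y, y_stride_n, y_stride_k,
        Z, z_stride_m, z_stride_k,
        M, N, K,
        BLOCK_SIZE,
    )
    return Z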

Now let's dive into the actual kernel code. We are going to use Triton's make_block_ptr utility, which simplifies the pointer arithmetic. We create one block pointer per matrix and pass the matrix shape, its strides, and the size of the block as inputs. In addition, we specify the offsets: the coordinates of the top-left element of the current block. For X, this corresponds to (m_idx * BLOCK_SIZE, 0), where m_idx is the index of the current block along the M dimension.

From there, we define z_acc, a zero matrix that accumulates the partial dot products as we iterate through tiles. We then iterate along the shared dimension N, loading blocks of size (BLOCK_SIZE, BLOCK_SIZE) and accumulating their dot products in z_acc, before moving the block pointers along the shared dimension with .advance.

You may have noticed that when loading data, we use boundary_check and padding_option instead of mask and other as in the previous article. These arguments are specific to block pointers: they specify which axes to check for out-of-bounds accesses (here (0, 1) for x and y) and how to treat the invalid values. Here we pad with zeros so that out-of-bounds values are ignored in the dot product.
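Putting these pieces together, a sketch of the (non-coalesced) kernel could look like the following (a minimal illustration consistent with the wrapper sketched above, not the article's exact code):

import triton
import triton.language as tl

@triton.jit
def block_matmul_kernel(
    X_ptr, X_m_stride, X_n_stride,
    Y_ptr, Y_n_stride, Y_k_stride,
    Z_ptr, Z_m_stride, Z_k_stride,
    M, N, K,
    BLOCK_SIZE: tl.constexpr,
):
    # indices of the current output block along M and K
    m_idx = tl.program_id(axis=0)
    k_idx = tl.program_id(axis=1)

    # block pointers describe a (BLOCK_SIZE, BLOCK_SIZE) window into each matrix
    x_block_ptr = tl.make_block_ptr(
        base=X_ptr,
        shape=(M, N),
        strides=(X_m_stride, X_n_stride),
        offsets=(m_idx * BLOCK_SIZE, 0),
        block_shape=(BLOCK_SIZE, BLOCK_SIZE),
        order=(1, 0),
    )
    y_block_ptr = tl.make_block_ptr(
        base=Y_ptr,
        shape=(N, K),
        strides=(Y_n_stride, Y_k_stride),
        offsets=(0, k_idx * BLOCK_SIZE),
        block_shape=(BLOCK_SIZE, BLOCK_SIZE),
        order=(1, 0),
    )
    z_block_ptr = tl.make_block_ptr(
        base=Z_ptr,
        shape=(M, K),
        strides=(Z_m_stride, Z_k_stride),
        offsets=(m_idx * BLOCK_SIZE, k_idx * BLOCK_SIZE),
        block_shape=(BLOCK_SIZE, BLOCK_SIZE),
        order=(1, 0),
    )

    # accumulator for the partial dot products of this output block
    z_acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
    for _ in range(0, N, BLOCK_SIZE):
        x = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
        y = tl.load(y_block_ptr, boundary_check=(0, 1), padding_option="zero")
        z_acc += tl.dot(x, y)
        # move both block pointers along the shared dimension N
        x_block_ptr = tl.advance(x_block_ptr, offsets=(0, BLOCK_SIZE))
        y_block_ptr = tl.advance(y_block_ptr, offsets=(BLOCK_SIZE, 0))

    tl.store(z_block_ptr, z_acc, boundary_check=(0, 1))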

We can now take a look at the performance of this kernel using the following function:

import numpy as np
import torch
import triton
from tqdm import tqdm

def bench(fn: callable, x: torch.Tensor, y: torch.Tensor, repeat: int):
  flops = []
  med_latency = []

  for _ in tqdm(range(repeat), desc=f"Benchmarking {fn.__name__}"):
    # do_bench returns the median latency (in ms) when quantiles=[0.5]
    latency_ms = triton.testing.do_bench(
      lambda: fn(x, y),
      quantiles=[0.5],
      )
    n_flops = 2 * M * N * K # matmul roughly requires 2*M*N*K operations
    tflops = n_flops / (latency_ms / 1e3) / 1e12

    med_latency.append(latency_ms)
    flops.append(tflops)

  flops = np.array(flops)
  med_latency = np.array(med_latency)
  print(f"Absolute Error: {torch.sum(torch.abs(x @ y - fn(x, y)))}")
  print(f"Median Latency: {med_latency.mean():.4f} ± {med_latency.std():.3f} ms")
  print(f"Throughput: {flops.mean():.4f} ± {flops.std():.3f} TeraFLOPS")

M = 8192
N = 6144
K = 4096

X = torch.randn((M, N), device="cuda", dtype=torch.float32)
Y = torch.randn((N, K), device="cuda", dtype=torch.float32)

bench(block_matmul, X, Y, repeat=10)

We get the following output (using a T4 GPU on Colab):

Absolute Error: 0.0 # the kernel outputs the right result!
Median Latency: 130.7831 ± 1.794 ms
Throughput: 3.1533 ± 0.043 TeraFLOPS

Now let's review the changes required for coalesced loads on Y: we mainly need to flip the shape, strides and offsets when defining the block pointer for Y. In addition, we update the block pointer so that it advances along the column dimension (previously the row dimension). The full code for this implementation is available on GitHub.

@triton.jit
def coalesced_block_matmul_kernel(
    X_ptr, X_m_stride, X_n_stride,
    Y_ptr, Y_k_stride, Y_n_stride,
    Z_ptr, Z_m_stride, Z_k_stride,
    M, N, K,
    BLOCK_SIZE: tl.constexpr,
):
    ... 
    y_block_ptr = tl.make_block_ptr(
        base=Y_ptr,
        # flip the shape, strides and offsets to match Y.T
        shape=(K, N),
        strides=(Y_k_stride, Y_n_stride), 
        offsets=(k_idx * BLOCK_SIZE, 0),
        block_shape=(BLOCK_SIZE, BLOCK_SIZE),
        order=(0, 1),
    )
    ...

    for _ in range(0, N, BLOCK_SIZE):
        ... # loads
        z_acc += tl.dot(x, y.T)  # transpose Y back for dot product
        x_block_ptr = tl.advance(x_block_ptr, offsets=(0, BLOCK_SIZE))
        # advance the block pointer along columns of Y.T (i.e. rows of Y)
        y_block_ptr = tl.advance(y_block_ptr, offsets=(0, BLOCK_SIZE))

    tl.store(pointer=z_block_ptr, value=z_acc, boundary_check=(0, 1))

def coalesced_block_matmul(X, Y):
    Y = Y.T.contiguous()  # Y is now (K,N)
    M, N = X.shape
    K, _ = Y.shape
    Z = torch.empty((M, K), device="cuda")

    x_stride_m, x_stride_n = X.stride()
    y_stride_k, y_stride_n = Y.stride()
    z_stride_m, z_stride_k = Z.stride()

    ...  # define BLOCK_SIZE and grid

    coalesced_block_matmul_kernel[grid](
        X, x_stride_m, x_stride_n,
        Y, y_stride_k, y_stride_n,  # strides of the (K, N) contiguous Y, matching Y_k_stride, Y_n_stride
        Z, z_stride_m, z_stride_k,
        M, N, K,
        BLOCK_SIZE,
    )

    return Z

Here are the results of our benchmark for the kernel with coalesced loads on Y:

Absolute Error: 0.0 # Again, the kernel is correct!
Median Latency: 261.9420 ± 0.858 ms
Throughput: 1.5741 ± 0.005 TeraFLOPS

Surprisingly, the throughput of this second kernel is only about half of what we obtained with the first one, despite the more efficient load operations 🤔

A quick inspection with Nsight (Nvidia's kernel profiler, more on that in a future article) reveals that the transpose operation inside the kernel creates a "traffic jam". Specifically, the transpose causes shared-memory bank conflicts, leaving threads idle most of the time. Notably, the warp scheduler has no eligible warp to dispatch 87.6% of the time, as warps wait for the bank conflicts to resolve. Moreover, the report reads:

Metric Name                  Metric Unit   Metric Value
--------------------------   -----------   ------------
DRAM Throughput              %                     8.20
Compute (SM) Throughput      %                    21.14

This indicates that the kernel is latency bound (i.e. neither memory nor compute bound; refer to the previous article for more details). In contrast, the first kernel is compute bound (i.e. increasing compute throughput would improve performance), since its compute throughput is high compared to its DRAM throughput:

Metric Name                  Metric Unit   Metric Value
--------------------------   -----------   ------------
DRAM Throughput              %                    29.35
Compute (SM) Throughput      %                    74.39

Conclusion

This experiment highlights the importance of profiling and empirical validation. Even well-intentioned optimisations like coalescing memory accesses can introduce new bottlenecks if they are not evaluated carefully. The first kernel, though simpler, was compute bound and better matched the hardware's characteristics.

In the next articles of this series, we'll implement a softmax kernel, paying particular attention to integrating Triton with PyTorch's autograd and to profiling kernels with Nsight.

Until next time! 👋
