More than a billion per day: that’s a low estimate of how many requests ChatGPT handles every day, a number which is unlikely to go down soon. For every request and every generated token, we run inference on a model with billions of parameters. This is why model optimization is paramount at every level: when dealing with this kind of scale, even a 1% latency or power gain can bring huge savings.
But where might that gain come from? Model architectures are already well established, and popular models have had quantized weights for a long time now. However, one crucial level at which we can optimize model inference remains: the kernel level. Kernels are the algorithms executed when you perform any operation in your network: there are matrix multiplication kernels, convolution kernels, batch normalization kernels, etc. Kernels are low-level, highly optimized algorithms, often tailored to the device they run on. They are notoriously long and hard to write, and require a good understanding of the inner workings of the GPU.
Kernels are essential for running operations in neural networks: without a kernel, an operation effectively cannot be used. For this reason, new innovations often launch with a “day 0” kernel, typically optimized only for the latest Nvidia hardware. This approach excludes many other devices, particularly AMD GPUs, which, despite offering comparable or superior specs, are sometimes overlooked by kernel developers. Hugging Face collaborated with AMD to deliver state-of-the-art performance on AMD platforms and make it benefit the open-source community. As part of this partnership, we decided with AMD to focus on delivering open-source optimized kernels to improve the performance of serving Llama 3.1 405B in FP8 on a node of 8 MI300X GPUs using VLLM.
In this blog post, we’ll explore how we optimized performance for the MI300X and how each kernel was individually fine-tuned. But first, let’s look at the performance gains achieved using our custom kernels. By combining the following three optimized kernels:
- Fused residual connection, RMS norm and FP8 conversion kernel
- Fused SwiGLU activation and FP8 conversion kernel
- Skinny GEMM kernel
we achieved significant speedups when running VLLM on a node powered by MI300X GPUs.
Measurements were taken with an input size of 1 and an output size of 128 to mimic the decoding regime. We measure decoding latency as the median over 30 iterations.
These performance gains were measured in VLLM, but you can also use the kernels individually, as described in the “How to” section that follows.
How to use these kernels
The hf-rocm-kernels repo
All kernels described previously are available in the hf-rocm-kernels repository located here.
In it, you will find instructions on how to install the package, the source code for each kernel, their respective Python bindings, various benchmarking scripts and a test suite. Using the benchmarking scripts and an MI300X, you can even reproduce the results from this blog post. To ensure you get the same results for Torch or VLLM, you can use the same container as we did.
You can also use the repo as a base to build your own kernels: it has instructions on how to bind a CUDA-style kernel to Python and a simple sample kernel.
You can also take a look at branches under development for new kernels, like a compute-and-communicate kernel as described here.
Integration in VLLM
The kernels described will soon be integrated in the AMD fork of the VLLM project, but if you want to see how you could do something like that yourself, you can check out this branch and this document.
We’ll start with a quick refresher on the architecture of the device we are working on: the MI300X. Then, we’ll look at the state of our model’s inference before optimizing it. This will allow us to identify bottlenecks and know which custom kernels we need to write. Then, we’ll go through each kernel we have written, which will give us an opportunity to explore kernel optimization from many angles.
A quick introduction to the MI300X
Before we dive into optimizing GPU code, we need to understand how a GPU works. There are many resources out there that already do a great job of explaining the inner workings of your GPU, which I’ll link here, here and here. We’re still going to run through the different levels of the GPU, as a quick refresher.
If you want to skip the refresher and go directly to the details of our custom kernels, click here!
Threads
The smallest unit of work in the GPU is the thread. Any time work is done on a GPU, it’s because a thread executed an instruction. Instructions are basic operations like additions, multiplications, conversions from one data type to another, or loads and stores. Each thread has its own memory, called registers (or VGPRs), which only it can access. A thread can have a maximum of 256 registers, each 32 bits wide. Below is a representation of a thread with access to its 256 VGPRs.
Except for load and store instructions, threads can only execute instructions on their own registers. For instance, to add two vectors A and B together, each thread is going to 1) load an element of A into its registers and 2) another from B, then 3) perform the addition and store the result in another register, and finally 4) store the value from that register in memory. That’s a total of 4 instructions.
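A toy Python model (not GPU code) can make these four steps concrete. Each hypothetical `Thread` object below owns a private list of registers and counts the instructions it executes; the loop runs the threads one after another only because Python is sequential, whereas on the GPU they would run concurrently.

```python
class Thread:
    """Toy model of a GPU thread: private registers plus an instruction counter."""

    def __init__(self, tid, n_regs=256):
        self.tid = tid
        self.regs = [0.0] * n_regs  # private registers: only this thread touches them
        self.instructions = 0

    def load(self, reg, memory, index):
        self.regs[reg] = memory[index]
        self.instructions += 1

    def add(self, dst, src_a, src_b):
        self.regs[dst] = self.regs[src_a] + self.regs[src_b]
        self.instructions += 1

    def store(self, reg, memory, index):
        memory[index] = self.regs[reg]
        self.instructions += 1


def vector_add(a, b):
    """C = A + B with one toy thread per element."""
    c = [0.0] * len(a)
    threads = [Thread(tid) for tid in range(len(a))]
    for t in threads:
        t.load(0, a, t.tid)   # 1) load an element of A into register 0
        t.load(1, b, t.tid)   # 2) load an element of B into register 1
        t.add(2, 0, 1)        # 3) add them into register 2
        t.store(2, c, t.tid)  # 4) store register 2 back to memory
    return c, threads


c, threads = vector_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0])
print(c, [t.instructions for t in threads])  # [11.0, 22.0, 33.0] [4, 4, 4]
```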
Warps
The next unit of work is the warp: each warp consists of 64 threads. Warps don’t have their own memory, but they are of interest to us because all threads in a warp must execute the same instruction at the same time. This is both a guarantee and a constraint.
Warps also allow threads to exchange information from their registers with other threads in the same warp. Although different threads in a warp have access to different data, the fact that they all must execute the same instructions means that, when writing a kernel, warp-level behavior is what you need to think about.
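As a sketch of such an exchange, here is a Python simulation of a butterfly (“XOR shuffle”) reduction, a common way to sum one value per thread across a warp without going through memory. It is only a model: on AMD hardware this kind of exchange is done with cross-lane operations such as HIP’s `__shfl_xor`, and the helper names below are hypothetical.

```python
WARP_SIZE = 64


def shuffle_xor(values, mask):
    """Every lane reads the value held by lane (lane XOR mask), all at the same time."""
    return [values[lane ^ mask] for lane in range(len(values))]


def warp_allreduce_sum(values):
    """Butterfly reduction: after log2(64) = 6 exchanges, every lane holds the total."""
    assert len(values) == WARP_SIZE
    vals = list(values)
    mask = WARP_SIZE // 2
    while mask >= 1:
        partner = shuffle_xor(vals, mask)  # the same instruction on all 64 lanes
        vals = [v + p for v, p in zip(vals, partner)]
        mask //= 2
    return vals


lanes = warp_allreduce_sum(list(range(WARP_SIZE)))
print(lanes[0], lanes[63])  # 2016 2016
```

Because every lane executes the same shuffle-and-add at each step, this pattern fits the “same instruction at the same time” constraint of a warp.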
Compute units
Warps are bundled together into thread blocks: thread blocks are software abstractions, but they run on a hardware component called a compute unit (CU). A single compute unit can run multiple thread blocks at once, but it can only fit 16 warps.
Each compute unit has a dedicated L1 cache and shared memory. The L1 cache can’t be controlled or allocated and helps with data reuse for all warps located on the CU. Conversely, shared memory can be allocated and used as storage shared by all warps. For instance, when we want all warps (and thus threads) in a compute unit to access the same buffer, we allocate it in shared memory. Both shared memory and L1 cache are fast to access because they are “close” to the threads.
Thread blocks also offer the ability to synchronize all threads running inside them: this is quite useful when dealing with operations that affect shared memory, like initializing an array in shared memory to zero, or reduction operations. In general, when writing a kernel, thread blocks are the highest level to consider: it’s very hard to synchronize different thread blocks or make them interact in any way whatsoever.
Kernel throughput is tightly linked to the number of compute units on the GPU: the more CUs there are, the more thread blocks can run at the same time, which increases throughput if you manage to use all the CUs.
XCDs
Compute units are then grouped into accelerator complex dies (XCDs), which hold 38 compute units each. Although CUs may not interact with each other, they all share an L2 cache which you can’t control but which can still prove useful when reusing data. For instance, when accessing memory, having two compute units located on the same XCD access the same data will reduce loading latency significantly. The L2 cache is quite large: it has a size of 4MB, while shared memory has a size of 64kB and the L1 cache holds 32kB.
The entire GPU (MI300X)
By assembling 8 XCDs (which gives us 8 * 38 = 304 CUs) and adding a final level of cache (called infinity cache, with 256MB) and a huge amount of video RAM (192GB), we get the MI300X.
All XCDs, and thus all threads, have access to the VRAM, but getting there is quite slow. As you get further away from the thread level, memory becomes slower to access but has a larger size and a larger scope, meaning it serves more threads. When optimizing a kernel, there’s always a balance to strike between doing many operations and loading lots of data, but in general, you want to access the VRAM (commonly called global memory) as little as possible.
Looking at this figure, we can see why GPUs are called “massively parallel”: here, we have 304 compute units, which can each run 16 warps, each with 64 threads. This means that we can have up to 311,296 threads running at the same time, each executing an instruction of its own.
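The numbers quoted in this section can be tallied in a few lines of Python. The sizes are the ones given above; for simplicity, this sketch treats kB, MB and GB as powers of 1024.

```python
XCDS = 8
CUS_PER_XCD = 38
WARPS_PER_CU = 16
THREADS_PER_WARP = 64

cus = XCDS * CUS_PER_XCD                             # compute units on the GPU
max_threads = cus * WARPS_PER_CU * THREADS_PER_WARP  # threads resident at the same time

KB, MB, GB = 1024, 1024**2, 1024**3
# (level, scope, size in bytes), from closest to farthest from the threads
memory_levels = [
    ("registers", "thread", 256 * 4),  # 256 VGPRs of 32 bits each
    ("L1 cache", "CU", 32 * KB),
    ("shared memory", "CU", 64 * KB),
    ("L2 cache", "XCD", 4 * MB),
    ("infinity cache", "GPU", 256 * MB),
    ("VRAM", "GPU", 192 * GB),
]

print(cus, max_threads)  # 304 311296
for name, scope, size in memory_levels:
    print(f"{name:<15} per {scope:<6} {size:>15,} bytes")
```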
Keep in mind that an instruction is something basic like an addition, so even simple routines like Newton’s method can take quite a long time to run on a single thread. GPUs are not optimized for instructions to run fast, i.e. for the latency of each instruction to be low: that would be a latency-oriented device. They are optimized for many threads to run together, consuming and outputting a large amount of data: the GPU is a throughput-oriented device.
When optimizing a kernel for the GPU, we adapt accordingly: it is better to have an algorithm run a few instructions on many threads at once than to have it run many instructions on a few threads. Hence we call algorithms running on GPUs “parallel”.
Three things can get in the way of such algorithms running efficiently: when there is a lot of data to load (memory bound), when there are many operations to perform (compute bound), or when threads must work together (synchronization overhead).
Day 0 performance evaluation
When optimizing a workload, the very first thing to do, before writing a single line of code, is to profile the current state of the workload.
In our case, we are going to profile the model inference in VLLM to get an idea of how much time each operation takes. This will help identify major bottlenecks and which kernels we should tackle first for maximum speedup. For instance, here is the breakdown for batch size 32:
We can see the different parts of the network through each slice:
- the “Attention*” slice, where we grouped RoPE, attention and KV cache kernels;
- the “Attention GEMMs”, which consist of two projections, QKV and Output;
- the “Communications”, which consist of two all-reduce operations, one after the Attention block and one after the MLP block, which are there because we are working in tensor parallel (TP8);
- the “MLP GEMMs”, which consist of the two projections made in the MLP, Gate / Up and Down;
- the “RMS norm” and “SwiGLU” slices, one for each kernel (note that the RMS norm kernel is called twice per block, once before the Attention and once before the MLP);
- the “Other” slice, which groups the kernels that we didn’t tag as part of a larger category because their impact is minor.
We can already see that most of the latency comes from GEMMs and communications, but also that attention and the operations surrounding it are not a major contributor to latency. This may come as a surprise, because many papers focus on attention and reducing its cost, but it seems that through a combination of KV caching and FlashAttention, which has already been optimized in VLLM, this part may not be a top priority.
Surprisingly, the two calls made to the “RMS norm” kernel are quite costly, so there may be a large benefit to optimizing that kernel. Together with the SwiGLU kernel, they represent 15% of the total latency, which is not negligible. All in all, working on those two kernels, plus trying to achieve a small speedup on GEMMs, may be our best course of action. To check that this performance breakdown is not a fluke, we can look at other batch sizes:
We can see that the pattern that emerged for batch size 32 holds for other batch sizes, albeit with the latency contribution of GEMMs and communications growing as the batch size increases. Also, it seems that batch size 32 is an outlier when it comes to GEMM latency: this is probably because the GEMMs chosen for batch size 32 have been manually tuned, or because batch size 32 presents good memory alignment patterns, so GEMMs for batch size 32 are faster than for batch size 24 or 28.
Now that we have identified some hot spots to optimize, let’s look at the first kernel we wrote: the RMS norm kernel.
RMS norm kernel
In each decoder block, we have two main parts: an attention block and an MLP block. Each begins with a residual connection between two inputs: the current hidden states $X$ and the residual $R$. Both have the same shape, which is $n$ rows (as many as there are tokens) and $d$ columns. After they are added together, we apply a row-wise Root Mean Square (RMS) norm to $X$ and, since the model is in FP8, we quantize $X$ to FP8 using a scale $s$. Simply fusing those three operations into a single kernel can deliver a nice performance boost. Mathematically, the operations we have to perform are the following:
$$
\begin{aligned}
(1) \quad & X \leftarrow X + R, \qquad R \leftarrow X \\
(2) \quad & V_i = \sum_{j=1}^{d} X_{i,j}^2 \\
(3) \quad & X_{i,j} \leftarrow \mathrm{FP8}\left( \frac{w_j \, X_{i,j}}{s \, \sqrt{V_i / d + \varepsilon}} \right)
\end{aligned}
$$
where $w$ is a $d$-sized weight vector and $\varepsilon$ is a small constant for numerical stability.
Steps (1) and (3) are pretty basic. For step (1), we just need to assign each thread to a different location in the tensor, load some elements of $X$ and $R$, add them and store back $R$. For step (3), each thread performs some scalar operations (addition, square root, division) and a conversion to FP8. Each thread can do all of this on its own: this is perfectly suited to the parallel nature of the GPU. The step to watch out for is (2): we need to sum over the $d$ columns, which means either each thread visits all $d$ columns, or threads need to exchange data. The greater $d$ is, the more data each thread would have to load with the first option, so the less viable it becomes. We are going to pick the second option: synchronize threads at the block level and have them exchange data through shared memory. Each thread is going to accumulate a part of $V_i$ on its own, and then we are going to sum all of those parts across the thread block, which is what we call a reduction. Since $V_i$ is computed across an entire row, we are going to assign one thread block to each row.
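To make the three steps and the block-level reduction concrete, here is a plain-Python model of what one thread block does for one row. It is a sketch, not the kernel from the repository, and it assumes: 64 threads per block, 8-element chunks per thread (one 128-bit FP16 load), ε = 1e-6, and an FP8 “conversion” modelled as a clamp to an assumed ±448 range with no rounding. The function name and its parameters are hypothetical.

```python
import math


def fused_add_rmsnorm_fp8_row(x, r, w, s, eps=1e-6, n_threads=64, vec=8, fp8_max=448.0):
    """Model of one thread block handling one row of length d.

    Returns (output row before any real FP8 rounding, new residual row).
    """
    d = len(x)

    # Step (1): residual connection; the updated row is also the new residual R.
    x = [xi + ri for xi, ri in zip(x, r)]
    new_r = list(x)

    # Step (2), thread level: each thread accumulates a partial sum of squares over
    # the chunks of `vec` elements it owns. Thread t starts at t * vec and strides by
    # n_threads * vec, so neighbouring threads read neighbouring chunks.
    partial = [0.0] * n_threads
    for t in range(n_threads):
        for start in range(t * vec, d, n_threads * vec):
            for j in range(start, min(start + vec, d)):
                partial[t] += x[j] * x[j]

    # Step (2), block level: reduce the partial sums (via shared memory on a GPU).
    v = sum(partial)

    # Step (3): normalize, apply the weight, divide by the scale, clamp to the FP8 range.
    inv = 1.0 / (s * math.sqrt(v / d + eps))
    out = [max(-fp8_max, min(fp8_max, w[j] * x[j] * inv)) for j in range(d)]
    return out, new_r
```

The strided ownership pattern matters because, in the real kernel, it determines which elements a warp fetches together in each load.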
Compared to out-of-the-box PyTorch, the bare-bones version of this kernel brings about a 10x speedup. But this isn’t enough: there are still many optimizations we can add on top of it.
Optimization: memory-related
In terms of latency, one of the most costly operations is accessing VRAM, also called global memory. Luckily, there are some easy-to-follow principles that can dramatically reduce the cost of loading data.
First, we can look at how much data a single thread can load in a single instruction: using the MI300X instruction guide, we see that the largest load we can make from global memory is 128 bits wide. Since we are loading FP16 data, we are going to load 128b / 16b = 8 elements per load. For FP32 elements, this would correspond to 4 elements per load.
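Python’s `struct` module gives a CPU-side analogy for this arithmetic: a single unpack of a 16-byte window yields 8 FP16 values (format code `e`) or 4 FP32 values (format code `f`). The sketch only illustrates the element counts; it says nothing about how the GPU performs the load.

```python
import struct

FP16_X8 = struct.Struct("<8e")  # 8 half-precision values
FP32_X4 = struct.Struct("<4f")  # 4 single-precision values
assert FP16_X8.size == FP32_X4.size == 128 // 8  # both are 16 bytes: one 128-bit load

row_fp16 = struct.pack("<16e", *[float(i) for i in range(16)])
row_fp32 = struct.pack("<8f", *[float(i) for i in range(8)])

# One "load" per call: a single unpack of a 16-byte window
first_fp16 = FP16_X8.unpack_from(row_fp16, 0)    # elements 0..7
second_fp16 = FP16_X8.unpack_from(row_fp16, 16)  # elements 8..15 (byte offset 16)
first_fp32 = FP32_X4.unpack_from(row_fp32, 0)    # elements 0..3
print(first_fp16, first_fp32)
```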
Second, we make sure memory accesses are coalesced. Since each thread is part of a warp, when one thread reaches a “load” instruction, all other threads in the warp do too. For efficiency’s sake, these “load” instructions are then bundled together across the warp. The warp then collectively fetches the data needed and each thread gets the data it requires. Maximum efficiency is reached when the warp fetches a single chunk of data without any gap in it: this is what we call contiguous data. An issue arises when we need to load more data than can be loaded in a single “load” instruction, as illustrated below.
In this hypothetical scenario, we have two threads in the same warp. They need to collectively load 16 FP32 elements, without any constraint on which thread loads which element. This is a typical “reduction” situation.
Since a thread can only load 4 FP32 elements per instruction, we have at least two ways of reading the data, represented in scenarios (a) and (b). To decide which scenario is better, we need to look at this from the warp’s perspective, not the thread’s.
In scenario (a), the first load fetches elements 0, 1, 2, 3, 8, 9, 10, 11: we see that the data is not contiguous, because there is a gap between elements 3 and 8. In scenario (b), on the other hand, the first load fetches elements 0, 1, 2, 3, 4, 5, 6, 7: we load contiguous data. The same goes for the second load. Thus scenario (b) is better. Although in scenario (a) we end up with 8 contiguous elements per thread, this doesn’t matter: what matters is whether or not the warp loads contiguous data. This matters because if the warp can only load 8 contiguous elements in a single cycle, then each load of scenario (a) is processed in two cycles, while in scenario (b), each load needs only one cycle.
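The two access patterns can be enumerated in a few lines of Python. The sketch below lists, for each warp-wide load, which of the 16 FP32 elements the two threads fetch together, and counts the gap-free runs in each load; the function names are hypothetical.

```python
def warp_load_indices(scenario, n_threads=2, per_load=4, n_loads=2):
    """Element indices fetched by the whole warp at each load instruction."""
    loads = []
    for l in range(n_loads):
        idx = []
        for t in range(n_threads):
            if scenario == "a":  # each thread owns one contiguous block of 8 elements
                start = t * per_load * n_loads + l * per_load
            else:                # "b": threads interleave, so each load is one contiguous span
                start = l * per_load * n_threads + t * per_load
            idx.extend(range(start, start + per_load))
        loads.append(sorted(idx))
    return loads


def contiguous_runs(indices):
    """Number of gap-free runs in a sorted list of indices."""
    return 1 + sum(1 for a, b in zip(indices, indices[1:]) if b != a + 1)


for scenario in ("a", "b"):
    loads = warp_load_indices(scenario)
    print(scenario, loads, [contiguous_runs(load) for load in loads])
# a [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]] [2, 2]
# b [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]] [1, 1]
```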
Third, we reduce the number of stores: when we look at steps (1) and (3), we can see that only two stores are needed: one for $R$ and one for $X$. After step (1), we can already store $R$ and be done with it. But we still need to access the modified version of $X$ after step (2) is done. To do that, we can store the modified version of $X$ in global memory, reload it after step (2) is done and rely on cache hits when reloading it. Or, if a row of $X$ is small enough, we can store its modified version in shared memory: if $X$ is in FP16 and we only have one thread block per CU, then we can store 64KB / 2B = 32 * 1024 elements in shared memory per thread block. In the case of Llama 405B, $d$ is equal to 16384, so a row fits. Using shared memory provides a nice speedup over relying on cache hits, especially when many thread blocks are active at once: if the L1 cache isn’t big enough to fit the whole of $X$, then we have to rely on the L2 cache, which is shared by 38 CUs.
Apart from memory access, we could also optimize computational efficiency, but we’ll leave that for the next kernel, as the optimizations are similar in both cases.
Results
When we apply the optimizations discussed above, we get the following results:
| Number of rows | Torch (μs) | VLLM (μs) | Ours (μs) |
|---|---|---|---|
| 1 | 38.8998 | 5.5145 | 4.18138 |
| 2 | 43.2469 | 5.65645 | 4.36976 |
| 4 | 41.1304 | 5.6893 | 4.37628 |
| 8 | 43.8883 | 5.72275 | 4.39081 |
| 16 | 46.8876 | 5.85667 | 4.48165 |
| 32 | 55.2276 | 6.08502 | 4.72017 |
| 64 | 75.6086 | 6.4629 | 5.54214 |
| 128 | 98.1122 | 7.49166 | 6.27341 |
| 256 | 119.727 | 11.8812 | 10.739 |
| 512 | 195.782 | 23.1595 | 18.5549 |
| 1024 | 355.42 | 44.8143 | 34.7204 |
| 2048 | 671.513 | 81.2089 | 73.35 |
These results are for an FP16 input tensor of shape [X, 16384], where X is the number of rows.
The most basic version of our kernel, referred to as “Pointwise”, has no memory-related optimization and already shows at least a 4x speedup over Torch. It is less efficient than VLLM’s implementation of the kernel, but our “Vectorized” implementation beats both “Pointwise” and VLLM. This is the version of the kernel that implements coalesced 128-bit loads, and it is only surpassed by the “Vectorized + SMEM” implementation (SMEM stands for shared memory), which offers a notably better speedup ratio than VLLM for both low and high batch sizes.
SwiGLU kernel
In the MLP block, after the kernel we have just described, comes a projection which we have referred to so far as the “Gate / Up” projection. We call it that because the “Gate / Up” projection is actually a concatenation of two projections with the same input: “Gate” and “Up”. Thus, we will write the result of the “Gate / Up” projection as $[G \,|\, U]$, where $|$ is the concatenation operator applied along the column axis. $G$ and $U$ have the same dimensions. The reason we need those two projections is the SwiGLU activation function that comes right after, whose result $Y$ is defined by equation (4).
The SwiGLU activation function is followed by the “Down” projection, which in our case is in FP8, so we also need to quantize $Y$ as shown in equation (5):
$$
\begin{aligned}
(4) \quad & Y = \sigma(G) \odot G \odot U \\
(5) \quad & Y_{\mathrm{FP8}} = \mathrm{FP8}\left( \frac{Y}{s} \right)
\end{aligned}
$$
where $\sigma$ is the sigmoid function, $\sigma(x) = 1 / (1 + e^{-x})$, $\odot$ is the element-wise product and $s$ is the FP8 scale. We are going to write a fused kernel that takes care of all of this. For this kernel, the optimizations described for the RMS norm kernel are still relevant, with the exception of the shared memory buffer. We will focus here on computation-related optimizations.
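Before optimizing, it helps to pin down what the fused kernel computes. The Python reference below models equations (4) and (5) for one row, i.e. $Y = \sigma(G) \odot G \odot U$ followed by division by the scale $s$. It assumes that the “Gate / Up” output stores $G$ in the first half of the columns and $U$ in the second, and it models the FP8 cast as a clamp to an assumed ±448 range without rounding. The function names are hypothetical.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # avoids overflow of exp(-z) for very negative z
    return ez / (1.0 + ez)


def swiglu_fp8_reference(gate_up_row, s, fp8_max=448.0):
    """Row-wise model of equations (4) and (5) for one row of [G | U]."""
    half = len(gate_up_row) // 2
    g, u = gate_up_row[:half], gate_up_row[half:]  # assumes G first, then U
    y = [gi * sigmoid(gi) * ui for gi, ui in zip(g, u)]  # (4), element-wise
    return [max(-fp8_max, min(fp8_max, yi / s)) for yi in y]  # (5), cast modelled as a clamp


print(swiglu_fp8_reference([0.0, 2.0, 3.0, -1.0], s=1.0))
```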
Optimization: compute-related
There are two ways we are going to increase the speed of our kernels: increasing the amount of work done by each instruction executed, and using faster instructions.
To increase the amount of work done per instruction, we can use packed instructions. Packed instructions are useful when we want to apply the same operator to several elements: rather than executing one instruction per element, we execute one instruction over a vector of elements. On a CPU, packed (or vectorized) instructions are the bread and butter of single-threaded optimization, as the AVX family of instructions can attest. There are only a few packed instructions on the GPU, but they can be quite useful in the right place.
On the MI300X there are, among others, packed instructions for FP16 addition and multiplication, which we will use for both steps. There is also a packed conversion from FP32 to FP8, which can provide a nice performance boost compared to the non-packed conversion. As a matter of fact, there is no conversion to FP8 from any data type other than FP32, so for the RMS norm kernel and this one, we have to go to FP32 precision in order to convert to FP8.
However, this is not an issue in this kernel: the sigmoid function requires us to compute an exponential, an operation that greatly benefits from FP32 precision. And this is an instance where we can optimize computation by using a faster instruction: since $e^x = 2^{x \log_2(e)}$, instead of using the exp instruction, we scale the input by $\log_2(e)$ and use the exp2 instruction, which is much faster. We suffer an almost negligible loss in precision but reduce latency.
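The substitution relies on the identity $e^x = 2^{x \log_2(e)}$. The Python sketch below only checks that identity numerically in double precision; it says nothing about the speed or the FP32 accuracy of the GPU instructions. The function names are hypothetical.

```python
import math

LOG2_E = math.log2(math.e)  # about 1.4427


def exp_via_exp2(x):
    """e**x rewritten as 2**(x * log2(e)): one multiply, then a base-2 exponential."""
    return 2.0 ** (x * LOG2_E)


def sigmoid_via_exp2(x):
    return 1.0 / (1.0 + exp_via_exp2(-x))


for x in (-10.0, -1.0, 0.0, 0.5, 3.0, 10.0):
    rel_err = abs(exp_via_exp2(x) - math.exp(x)) / math.exp(x)
    print(f"x={x:>6}: relative error {rel_err:.1e}")
```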
Results
We get the following table for an FP16 input tensor of shape [X, 16384], where X is the number of rows:
| Number of rows | 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 | 2048 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Torch (μs) | 40.2731 | 29.923 | 35.305 | 23.5763 | 22.4738 | 25.3445 | 31.5829 | 40.3194 | 53.5369 | 79.8037 | 124.873 | 243.202 |
| VLLM (μs) | 3.84116 | 3.86192 | 3.92937 | 3.94151 | 4.01047 | 4.02421 | 4.08943 | 4.20317 | 4.48755 | 7.48465 | 13.7389 | 25.4306 |
| Ours (μs) | 1.92981 | 1.93904 | 1.93524 | 1.99316 | 2.00415 | 1.91563 | 2.04498 | 2.61763 | 3.57726 | 5.47608 | 10.0482 | 19.8957 |
| Speedup (VLLM / Ours) | 1.990434291 | 1.991665979 | 2.030430334 | 1.977518112 | 2.001082753 | 2.100724044 | 1.999740829 | 1.605715857 | 1.254465708 | 1.366789747 | 1.367299616 | 1.278195791 |
With memory and compute optimizations tailored for the MI300X, we get a kernel that is more than 14 times faster than Torch on average and from about 25% to 110% faster than VLLM’s kernel.
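The speedup ratios can be recomputed from the table with a few lines of Python (latencies in μs, copied from the table):

```python
torch_us = [40.2731, 29.923, 35.305, 23.5763, 22.4738, 25.3445,
            31.5829, 40.3194, 53.5369, 79.8037, 124.873, 243.202]
vllm_us = [3.84116, 3.86192, 3.92937, 3.94151, 4.01047, 4.02421,
           4.08943, 4.20317, 4.48755, 7.48465, 13.7389, 25.4306]
ours_us = [1.92981, 1.93904, 1.93524, 1.99316, 2.00415, 1.91563,
           2.04498, 2.61763, 3.57726, 5.47608, 10.0482, 19.8957]

vs_torch = [t / o for t, o in zip(torch_us, ours_us)]  # speedup over Torch per row count
vs_vllm = [v / o for v, o in zip(vllm_us, ours_us)]    # speedup over VLLM per row count
mean_torch = sum(vs_torch) / len(vs_torch)

print(f"mean speedup over Torch: {mean_torch:.2f}x")
print(f"speedup over VLLM: {min(vs_vllm):.3f}x to {max(vs_vllm):.3f}x")
```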
Skinny GEMM kernel
As we have seen earlier, about 60% of the model’s inference latency comes from projections, which rely on GEMM kernels. GEMM kernels are heavily optimized in dedicated libraries such as hipBLASLt and rocBLAS on AMD, so writing a custom kernel that performs better in all cases is quite hard. But if we focus on some edge cases that are relevant to us, and write a GEMM kernel for those specific cases, then there is a chance our custom kernel may be faster than those in the dedicated libraries.
In both prefill and decoding, the input of any of the network’s projections has as many rows as there are tokens being processed. And during decoding, the number of tokens being processed is equal to the batch size. So during decoding, the number of input rows of all GEMM kernels is equal to the batch size, which for our purposes ranges between 1 and 256. We are going to focus on very low batch sizes.
When we have a GEMM whose input has few rows and many columns, we say that the GEMM is skinny. We have a specific term for such GEMMs because they are ill-suited to the classic GEMM algorithm we run on GPUs. Usually, the efficiency of GEMM kernels comes from tiling: we divide the result matrix into many sub-matrices, called tiles, and we assign each tile to a different compute unit (CU). If we have many tiles, we can use many CUs and GPU usage is high. This is illustrated in the figure below.
But if the input has only a few rows, then only a few tiles can be formed, which results in only a few active compute units, hence low GPU utilization:
Skinny GEMMs are fundamentally inconvenient for the GPU. In the next part, we will see how a custom kernel that assumes a skinny GEMM context can make them more convenient.
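A back-of-the-envelope Python sketch shows the effect. It assumes one output tile per CU and uses hypothetical 64x64 tiles and a hypothetical 2048-column output; neither value comes from the kernels discussed here.

```python
import math

NUM_CUS = 304  # compute units on the MI300X


def num_tiles(m, n, tile_m, tile_n):
    """Output tiles needed to cover an (m x n) result."""
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)


def cu_usage(m, n, tile_m=64, tile_n=64, num_cus=NUM_CUS):
    """Fraction of CUs given a tile, assuming one tile per CU (illustrative only)."""
    return min(1.0, num_tiles(m, n, tile_m, tile_n) / num_cus)


for rows in (8, 64, 512, 2048):
    print(f"{rows:>5} rows -> {num_tiles(rows, 2048, 64, 64):>5} tiles, "
          f"{cu_usage(rows, 2048):.0%} of CUs")
```

With 8 rows, every tile row is mostly padding and only a small fraction of the CUs receives any work.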
Optimization: split-K
Since the main issue of skinny GEMMs is that we use too few compute units, the first thing we can do is figure out a way to use more. To do this, we can exploit the following mind-breaking formula:
$$
c_{ij} = \sum_{k=1}^{K} a_{ik} b_{kj} = \sum_{k=1}^{K/2} a_{ik} b_{kj} + \sum_{k=K/2+1}^{K} a_{ik} b_{kj}
$$
Thanks to the associativity of the sum, we can split the main GEMM along the shared axis (commonly called the K axis) and replace one GEMM with several sub-GEMMs that are executed concurrently. Each sub-GEMM is going to use as many CUs as the main one would have, so the number of CUs used is multiplied by the number of times we split the K axis. This is shown in the figure below:
Here, we set split-K equal to 2 and thus double the number of CUs used at once. Since we get partial results, we need to add them up after both sub-GEMMs are done. What may seem counter-intuitive is that we are adding an operation, summing the partial results, yet we claim to reduce the latency of the overall process. But since each CU needs to go through the entire K axis to compute its result, and we are cutting that axis in two, the amount of work done by each CU is also cut in two. If the amount of work saved this way outweighs the amount of work added by summing up the final results, then we have an overall optimization. This is generally true as long as K is large and the original GEMM uses less than 50% of the GPU.
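The split-K idea can be checked with a small, pure-Python sketch: `split_k_matmul` (a hypothetical name) cuts the K axis into slices, computes one partial GEMM per slice, which on a GPU would be independent work for different CUs, and adds the partial results at the end. Integer inputs keep the comparison exact.

```python
def matmul(a, b):
    """Reference GEMM: a is m x k, b is k x n (lists of lists)."""
    k = len(b)
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_k_matmul(a, b, splits):
    """Split the K axis into `splits` slices, run one sub-GEMM per slice, then sum."""
    k = len(b)
    bounds = [k * s // splits for s in range(splits + 1)]
    partials = []
    for lo, hi in zip(bounds, bounds[1:]):
        a_slice = [row[lo:hi] for row in a]
        b_slice = b[lo:hi]
        partials.append(matmul(a_slice, b_slice))  # independent: could run on other CUs
    m, n = len(a), len(b[0])
    # Final reduction of the partial results: the extra step split-K introduces
    return [[sum(p[i][j] for p in partials) for j in range(n)] for i in range(m)]


a = [[1, 2, 3, 4, 5, 6]]                           # skinny input: 1 row, K = 6
b = [[p + j for j in range(3)] for p in range(6)]  # K x N weights, N = 3
print(matmul(a, b), split_k_matmul(a, b, splits=2))  # [[70, 91, 112]] [[70, 91, 112]]
```

The sketch assumes `splits` is at most K, so that no slice is empty.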
Optimization: removing padding
If we assume that, thanks to split-K, most compute units are busy with their own tile, we can narrow the scope of optimization to the compute unit level. We are going to look at how the actual matrix multiplication is done, and how we can speed it up.
In state-of-the-art GPUs like the MI300X, matrix multiplication is handled by dedicated hardware units called tensor cores. Tensor cores only perform matrix multiplications, but they do so at very high speed.
The format of a tensor core instruction is mfma_MxNxK..., where mfma stands for matrix fused multiply-add, M is the number of rows of the left-hand matrix, N the number of columns of the right-hand matrix, and K the shared dimension of both. We illustrate a hypothetical instruction mfma_2x2x4 below:
There are only a few tensor core instructions, but for any MxNxK triplet that has one, using the dedicated tensor core instruction is much faster than any other alternative.
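In Python, the semantics of such an instruction (not its performance) amount to a small multiply-accumulate: the product is added to an accumulator C, which is what lets a kernel chain instructions along the K axis. The `mfma` function below is a hypothetical model, not an actual intrinsic.

```python
def mfma(a, b, c):
    """Semantics of a hypothetical mfma_MxNxK: returns D = A @ B + C.

    a is M x K, b is K x N, c is M x N (lists of lists).
    """
    m, k, n = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a)
    assert len(c) == m and all(len(row) == n for row in c)
    return [[c[i][j] + sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


# mfma_2x2x4, then a second call accumulating into the first result: a K = 8
# product computed as two K = 4 steps, the way a kernel walks along K.
a_lo, a_hi = [[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 1, 1, 1], [0, 0, 0, 0]]
b_lo = [[1, 0], [0, 1], [1, 0], [0, 1]]
b_hi = [[1, 1], [1, 1], [1, 1], [1, 1]]
acc = mfma(a_lo, b_lo, [[0, 0], [0, 0]])
acc = mfma(a_hi, b_hi, acc)
print(acc)  # [[8, 10], [12, 14]]
```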
Tensor core instructions also come in two flavours: “dense” and “sparse”. Dense instructions correspond to standard matrix multiplication. Sparse instructions assume that the left-hand side matrix has a 4:2 structured sparsity pattern, which means that two out of every 4 elements along the matrix’s K axis are zero. Mathematically, for any $i, k$ such that $k \in 4\mathbb{N}$, we have at least two zeros in $\{A_{i,k}, A_{i,k+1}, A_{i,k+2}, A_{i,k+3}\}$. Below is an example of a sparse matrix.
Let’s get back to our model, Llama 405B in FP8. For FP8, we only have two dense tensor core instructions: 16x16x32 and 32x32x16. We also have one sparse instruction of size 16x16x64.
For an input with 8 rows, using even the smallest dense instruction, 16x16x32, means that we have to add 8 rows of padding to our input, which is a waste of compute resources. One may wonder whether we can use the sparse instruction instead: after all, a 16-row matrix with 4:2 sparsity has half of its coefficients equal to zero, so its non-zero coefficients can be fully described by a dense 8-row matrix. Conversely, if we have a dense 8-row matrix, we can fit all of its data into a 16-row matrix with 4:2 sparsity. And the benefit of using the sparse instruction is clear: the dense instruction has K=32 while the sparse instruction has K=64. For the same number of cycles, the sparse instruction has twice the depth. We illustrate this sparsity trick in the figure below with a 1-row input and the 2x2x4 dense instruction and its sparse 2x2x8 counterpart.
Using this trick, we can notably speed up our GEMM for any input with 8 rows or fewer, which results in a reduction in per-token latency for any decoding batch with 8 requests or fewer.
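The sparsity trick can be modelled in Python. This sketch assumes one possible packing, which may differ from the one used in the figure and in the actual kernel: within each aligned group of 4 input values along K, the first two go to one sparse row and the last two to a second sparse row, and the two output rows are summed at the end. Each sparse row is then stored as 2 values plus their positions per group of 4, roughly the way structured-sparse operands are encoded. All function names are hypothetical.

```python
def is_4_2_sparse(row):
    """True if every aligned group of 4 elements along K has at least two zeros."""
    return all(sum(v == 0 for v in row[k:k + 4]) >= 2 for k in range(0, len(row), 4))


def split_dense_row(x):
    """Spread one dense row over two 4:2-sparse rows (one possible packing)."""
    assert len(x) % 4 == 0
    top, bottom = [0] * len(x), [0] * len(x)
    for k in range(0, len(x), 4):
        top[k], top[k + 1] = x[k], x[k + 1]                # first half of the group
        bottom[k + 2], bottom[k + 3] = x[k + 2], x[k + 3]  # second half of the group
    return top, bottom


def compress(row):
    """Keep 2 (position, value) pairs per group of 4; the row must be 4:2 sparse."""
    assert is_4_2_sparse(row)
    pairs = []
    for k in range(0, len(row), 4):
        group = range(k, k + 4)
        keep = [p for p in group if row[p] != 0][:2]
        keep += [p for p in group if p not in keep][:2 - len(keep)]  # pad with zeros
        pairs.extend((p, row[p]) for p in keep)
    return pairs


def sparse_row_times(pairs, b):
    """Row vector times B, touching only the stored (position, value) pairs."""
    return [sum(v * b[p][j] for p, v in pairs) for j in range(len(b[0]))]


x = [1, 2, 3, 4, 5, 6, 7, 8]    # 1-row dense input, K = 8
b = [[p, 1] for p in range(8)]  # K x N weights, N = 2
top, bottom = split_dense_row(x)
y = [t + u for t, u in zip(sparse_row_times(compress(top), b),
                           sparse_row_times(compress(bottom), b))]
print(is_4_2_sparse(top), is_4_2_sparse(bottom), y)  # True True [168, 36]
```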
Optimization: warp specialization and asynchronous execution
We have seen that in a skinny GEMM, the fact that we have a small number of rows limits the number of output tiles, which in turn limits GPU utilization. But the small number of rows also limits the number of rows each output tile has, which in turn reduces what we call arithmetic intensity. Simply put, arithmetic intensity is the amount of work done divided by the amount of data loaded to do that work. Let us compare two examples:
$$
A = \sum_{i=1}^{n} x_i \qquad \text{and} \qquad B = \sum_{i=1}^{n} a^i
$$
where $x$ is an $n$-sized vector and $a$ is a scalar.
To compute $A$, we load $n$ elements and perform $n$ additions. To compute $B$, we load 1 element and perform $n$ additions and $n$ multiplications. So the “arithmetic intensity” of computing $A$ is $1$ while that of $B$ is $2n$: the computation of $B$ is more “arithmetically intensive” than the computation of $A$. What we see here is that the lower the arithmetic intensity, the more data we need to load to perform a given amount of work.
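The counting can be reproduced with a short Python sketch, taking $A = \sum_{i=1}^{n} x_i$ and $B = \sum_{i=1}^{n} a^i$ and counting every load, addition and multiplication. The function names are hypothetical.

```python
def count_sum_of_vector(x):
    """A = sum of x_i: one load and one addition per element."""
    loads = ops = 0
    acc = 0.0
    for xi in x:
        loads += 1   # load x_i
        acc += xi    # one addition
        ops += 1
    return acc, ops / loads


def count_power_sum(a, n):
    """B = sum of a**i for i = 1..n: a single load, then n multiplications and n additions."""
    loads, ops = 1, 0  # load a once
    power, acc = 1.0, 0.0
    for _ in range(n):
        power *= a     # one multiplication
        acc += power   # one addition
        ops += 2
    return acc, ops / loads


print(count_sum_of_vector([1.0] * 16))  # (16.0, 1.0)
print(count_power_sum(2.0, 16))         # (131070.0, 32.0)
```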
Why does this matter to us? Well, we have seen that loading data from VRAM has a high latency cost, which is bad news for the GPU. In other words, workloads with low arithmetic intensity are ill-suited to the GPU, and it turns out skinny GEMMs have lower arithmetic intensity than their non-skinny counterparts. This becomes intuitive when looking at the figure below: we can see that when we divide the amount of data loaded by two, we divide the number of output coefficients by 4, due to the quadratic nature of the GEMM’s dimensions.
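The same bookkeeping applies to an output tile. Assuming an (m × n) tile computed over a depth k, the tile loads m·k input values and k·n weight values and performs m·n·k multiply-adds. The sketch below (hypothetical function name, illustrative tile sizes) shows that halving both tile dimensions halves the data loaded but divides the outputs by 4, and that a skinny tile has a lower intensity than a square one.

```python
def tile_stats(m, n, k):
    """Loads, outputs and arithmetic intensity (multiply-adds per loaded value) of a tile."""
    loaded = m * k + k * n   # input rows + weight columns read for this tile
    outputs = m * n          # output coefficients produced
    intensity = (m * n * k) / loaded
    return loaded, outputs, intensity


big = tile_stats(64, 64, 32)
small = tile_stats(32, 32, 32)
skinny = tile_stats(8, 64, 32)
print(big, small, skinny)
```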
In a skinny GEMM, the number of rows of the output tile is limited, and so is the arithmetic intensity. This already means that we will have to load a lot of data to compute an output tile. Moreover, since we are using FP8 arithmetic, computation is quite fast, so we cannot rely on computation time to hide the latency of data loading. All in all, it would be ideal to have more threads in charge of loading data than threads in charge of computing the result.
To achieve this, we are going to use a technique called warp specialization. Instead of having all warps in the thread block execute the same instructions, we will dedicate some warps to loading data only and some to computing the results only. The warps in charge of loading data are called producers and those that compute the results are called consumers. Producers and consumers work asynchronously: producers first load data from the VRAM, which is slow, and make it available to the consumers by storing it in a shared memory buffer. Until data is available in shared memory, the consumer is idle. Once the data is made available, the consumer loads it from shared memory, which is fast, and computes the result.
Coordination of producers and consumers is achieved through a queue stored in shared memory. When a producer finishes storing data in shared memory buffer $i$, it changes the state of the $i$-th variable of the queue to signal that data is available there. The consumer watches for this and starts loading the data afterwards. When it is done, it changes the $i$-th variable of the queue to signal that the data in buffer $i$ can be overwritten.
In the figure below, we represent the steps involved in a simple asynchronous GEMM with one producer, one consumer and a queue of size 2.
What makes the whole process work is that once buffer $i$ is filled by a producer, the producer can start working on buffer $i+1$ without waiting for the consumer to have loaded the data from buffer $i$. The goal is to have a queue large enough for producers to be constantly filling buffers and consumers to be constantly consuming them. The size of the queue is constrained by the size of the shared memory.
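The producer/consumer protocol can be modelled on the CPU to see the queue states at work. In this sketch, Python threads stand in for warps, a list stands in for the shared memory buffers, and a `threading.Condition` stands in for polling the per-buffer state flags. On the MI300X the real kernel has no such primitive and spins on flags, so treat this only as an illustration of the EMPTY/FULL handshake and the modular indexing. The names `SharedQueue`, `producer`, `consumer` and `async_dot` are hypothetical, and the "GEMM" is reduced to a dot product.

```python
import threading

EMPTY, FULL = 0, 1


class SharedQueue:
    """A ring of `size` buffers plus one state flag per buffer."""

    def __init__(self, size):
        self.size = size
        self.buffers = [None] * size
        self.state = [EMPTY] * size
        self.cond = threading.Condition()

    def wait_for_state(self, slot, wanted):
        with self.cond:
            self.cond.wait_for(lambda: self.state[slot] == wanted)

    def set_state(self, slot, new_state):
        with self.cond:
            self.state[slot] = new_state
            self.cond.notify_all()


def producer(q, x, w):
    for step in range(len(x)):
        slot = step % q.size                   # modular indexing into the ring
        q.wait_for_state(slot, EMPTY)          # wait until the consumer freed it
        q.buffers[slot] = (x[step], w[step])   # stands in for the slow VRAM load
        q.set_state(slot, FULL)                # signal: data available in `slot`
        # no waiting here: the producer moves straight on to the next slot


def consumer(q, num_steps, out):
    acc = 0.0
    for step in range(num_steps):
        slot = step % q.size
        q.wait_for_state(slot, FULL)           # idle until the data has landed
        xi, wi = q.buffers[slot]               # fast "shared memory" read
        q.set_state(slot, EMPTY)               # signal: `slot` may be overwritten
        acc += xi * wi                         # compute on the local copy
    out.append(acc)


def async_dot(x, w, queue_size=2):
    q = SharedQueue(queue_size)
    out = []
    threads = [
        threading.Thread(target=producer, args=(q, x, w)),
        threading.Thread(target=consumer, args=(q, len(x), out)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out[0]


print(async_dot([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]))  # 70.0
```

Because the consumer releases a slot as soon as it has copied the data out, the producer can already refill that slot while the consumer is still computing.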
We also have to tune the ratio of producers to consumers: we said that our arithmetic intensity is low, so we need to load a lot of data to do a relatively fast computation. Hence, we are going to have many producer warps (typically 8 or 10) for a few consumer warps (something like 2 or 3). Moreover, we can exploit the fact that the GEMM is skinny by having separate producers for the input (the skinny matrix) and the weights (the non-skinny matrix). To make the output tile larger in the dimension in which it is not constrained, which is the columns dimension, we allocate more producers to the weights.
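The post does not say how the split between input and weight producers is chosen, so the function below is a hypothetical heuristic rather than the kernel’s actual tuning rule. It divides producer warps in proportion to the number of elements each side loads per K-step (`tile_rows * k` for the input versus `tile_cols * k` for the weights) while keeping at least one producer on each side.

```python
def split_producers(total_producers, tile_rows, tile_cols):
    """Hypothetical heuristic: split producer warps between the input and the
    weights in proportion to the elements each side loads per K-step."""
    if total_producers < 2:
        raise ValueError("need at least one producer for each matrix")
    weight_share = tile_cols / (tile_rows + tile_cols)
    weight_producers = round(total_producers * weight_share)
    # keep at least one producer on each side
    weight_producers = max(1, min(weight_producers, total_producers - 1))
    return total_producers - weight_producers, weight_producers


print(split_producers(8, 8, 64))    # (1, 7)
print(split_producers(10, 1, 256))  # (1, 9)
```

With a skinny input, almost all producers end up on the weights, which is what lets the tile grow along the columns.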
For a more in-depth look at asynchronous GEMMs, I encourage you to check out this blog post. A lot of its contents are not applicable in our case, though: the MI300X has no warp-level barriers, only a single thread block-level barrier. This led to “fun” shenanigans like ASM to make sure warps waited at their barriers, that shared memory loads and stores were resolved before checking the barrier state, and careful handling of the modular nature of the queue. All this would be out of place here, but I encourage you to check out the code or ask away in the comments. A deep dive on the details of async handling might be coming in the future.
Through warp specialization and asynchronous work, we can adapt our kernel to the low arithmetic intensity workload, but is that enough to come out ahead of libraries like hipBLASLt? The answer is yes, in some cases.
Results
Since Torch already binds a highly optimized GEMM taken from AMD’s linear algebra library, we are not going to get speedups in the same range as for the last two kernels.
We are first going to look at the three GEMM shapes that are of interest to us: namely, the GEMM dimensions of the QKV projection, the Gate / Up projection and the Down projection. The Output projection is left out because its dimensions don’t correspond to the skinny GEMM case.
| M (rows) | N (cols) | K (depth) | Torch time (μs) | SkG time (μs) | Speedup (Torch / SkG) |
|---|---|---|---|---|---|
| 1 | 2304 | 16384 | 14.938 ± 0.292 | 11.685 ± 0.299 | 127.84 % |
| 8 | 2304 | 16384 | 16.300 ± 0.282 | 12.342 ± 0.375 | 132.07 % |
| 16 | 2304 | 16384 | 16.693 ± 0.233 | 13.909 ± 0.295 | 120.02 % |
| 32 | 2304 | 16384 | 16.817 ± 0.124 | 17.021 ± 0.133 | 98.80 % |
| 1 | 13312 | 16384 | 77.636 ± 0.364 | 54.717 ± 0.628 | 141.88 % |
| 8 | 13312 | 16384 | 80.031 ± 0.449 | 58.355 ± 0.612 | 137.15 % |
| 16 | 13312 | 16384 | 75.236 ± 0.378 | 59.973 ± 1.922 | 125.45 % |
| 32 | 13312 | 16384 | 82.198 ± 0.590 | 69.483 ± 1.672 | 118.30 % |
| 1 | 16384 | 6656 | 31.066 ± 0.193 | 27.613 ± 0.218 | 112.51 % |
| 8 | 16384 | 6656 | 31.559 ± 0.200 | 28.134 ± 0.209 | 112.17 % |
| 16 | 16384 | 6656 | 31.671 ± 0.250 | 30.233 ± 0.267 | 104.76 % |
| 32 | 16384 | 6656 | 35.561 ± 0.335 | 35.052 ± 1.365 | 101.45 % |
Measurements are taken after 500 warmup iterations, over 2000 profiling iterations, using CUDA graphs and multiple weight tensors to avoid cache hits.
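The measurement protocol (warmup, many timed iterations, rotating through several weight tensors, reporting mean ± standard deviation) can be sketched as a generic harness. This is a CPU-side illustration under stated assumptions: `benchmark` is a hypothetical helper, and it times with `time.perf_counter`. A real GPU harness has to account for asynchronous kernel launches (device events or synchronization) and, as in the post, replay CUDA graphs, none of which is shown here.

```python
import statistics
import time


def benchmark(fn, inputs, weight_pool, warmup=500, iters=2000):
    """Time fn(inputs, w) in microseconds, cycling through weight_pool so that
    consecutive calls do not reuse the same (possibly cached) weights."""
    for i in range(warmup):
        fn(inputs, weight_pool[i % len(weight_pool)])
    samples = []
    for i in range(iters):
        w = weight_pool[i % len(weight_pool)]
        start = time.perf_counter()
        fn(inputs, w)
        samples.append((time.perf_counter() - start) * 1e6)
    return statistics.mean(samples), statistics.stdev(samples)
```

For example, `benchmark(lambda x, w: sum(a * b for a, b in zip(x, w)), [1.0] * 100, [[1.0] * 100, [2.0] * 100], warmup=5, iters=50)` returns a `(mean_us, std_us)` pair in the same format as the table.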
In order, the GEMM dimensions shown above correspond to the QKV projection (N = 2304 and K = 16384), the Gate / Up projection (N = 13312 and K = 16384) and the Down projection (N = 16384 and K = 6656). For these dimensions, which the kernel has been tuned for, there is a notable speedup for a low number of rows (M = 1, 8, 16), but less so for more rows (M = 32), where SkG is even slightly slower than Torch on the QKV projection. The speedup is most notable for the row counts where we can use our sparsity trick (M = 1, 8), probably because Torch pads everything to 16 rows to use the smallest MFMA instruction.
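The Speedup column appears to be the Torch time divided by the SkG time: recomputing it from the two timing columns reproduces every row to within about 0.01 percentage points. The short script below does that recomputation from the values copied out of the table and flags the only configuration where SkG is slower.

```python
# (M, N, K, torch_us, skg_us, reported_speedup_pct), copied from the table above
ROWS = [
    (1, 2304, 16384, 14.938, 11.685, 127.84),
    (8, 2304, 16384, 16.300, 12.342, 132.07),
    (16, 2304, 16384, 16.693, 13.909, 120.02),
    (32, 2304, 16384, 16.817, 17.021, 98.80),
    (1, 13312, 16384, 77.636, 54.717, 141.88),
    (8, 13312, 16384, 80.031, 58.355, 137.15),
    (16, 13312, 16384, 75.236, 59.973, 125.45),
    (32, 13312, 16384, 82.198, 69.483, 118.30),
    (1, 16384, 6656, 31.066, 27.613, 112.51),
    (8, 16384, 6656, 31.559, 28.134, 112.17),
    (16, 16384, 6656, 31.671, 30.233, 104.76),
    (32, 16384, 6656, 35.561, 35.052, 101.45),
]


def speedup_pct(torch_us, skg_us):
    """Speedup as a percentage: Torch time divided by SkG time."""
    return 100.0 * torch_us / skg_us


for m, n, k, t, s, reported in ROWS:
    ours = speedup_pct(t, s)
    flag = "" if ours >= 100 else "  <- SkG slower"
    print(f"M={m:>2} N={n:>5} K={k:>5}: {ours:6.2f} % (table: {reported:.2f} %){flag}")
```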
Conclusion
In this post, we explored only a handful of the many kernel optimization techniques available. If you’re interested in experimenting with them, feel free to dive into the hf-rocm-kernels repository and start tinkering! And if you develop a kernel you like and want to distribute it, be sure to check out kernel-builder and kernels, two Hugging Face packages designed to help kernel builders make their work widely available and more impactful.