A little optimisation goes a long way. Models like GPT-4 cost more than $100 million to train, which makes even a 1% efficiency gain worth a fortune. A powerful way to optimise the performance of machine learning models is to write some of their components directly for the GPU. Now, if you're anything like me, the mere mention of CUDA kernels is enough to send chills down your spine, as they're notoriously complex to write and debug.
Fortunately, OpenAI released Triton in 2021, a new language and compiler that abstracts away much of CUDA's complexity and allows less experienced practitioners to write performant kernels. A notable example is Unsloth, an LLM-training service that promises 30x faster training with 60% less memory usage, all thanks to replacing layers written in PyTorch with Triton kernels.
In this tutorial series, we'll learn the fundamentals of GPU architecture and how to implement high-performance Triton kernels! All of the code presented in this series will be available at https://github.com/RPegoud/Triton-Kernels.
GPU Architecture Basics
In this section, we'll go through the very basics of GPUs to get us started and write our first Triton kernel by the end of this article.
Starting from the smallest software unit, we can describe the hierarchy of execution units as follows:
- Threads: The smallest unit of work; they run the user-defined kernel code.
- Warps: The smallest scheduling unit; they are always composed of 32 parallel threads, each with its own instruction address counter and register state. Threads in a warp start together but are free to branch and execute independently.
- Thread Blocks: Groups of warps in which all threads can cooperate via shared memory and synchronisation barriers. Thread blocks are required to execute independently and in any order, in parallel or sequentially. This independence allows them to be scheduled in any order across any number of cores, so that GPU programs scale efficiently with the number of cores. We can synchronise the threads within a block at specific points in the kernel if needed, for instance to synchronise memory accesses.
- Streaming Multiprocessor (SM): A unit in charge of executing many warps in parallel. It owns shared memory and an L1 cache (holding the most recently accessed global-memory lines) and has a dedicated warp scheduler that pulls warps from the thread blocks that are ready to run.
On the hardware side, the smallest unit of work is a CUDA core, the physical Arithmetic Logic Unit (ALU) that performs arithmetic operations for a thread (or parts of it).
To summarise this section with an analogy, we can see CUDA cores as individual workers, while a warp is a squad of 32 workers given the same instruction at once. They may or may not execute this task the same way (branching) and can complete it at different points in time (independence). A thread block consists of several squads sharing a common workspace (i.e. shared memory); workers from all squads in the workspace can wait for one another to get lunch at the same time. A streaming multiprocessor is a factory floor with many squads working together and sharing tools and storage. Finally, the GPU is an entire plant, with many floors.
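If you want to see part of this hierarchy reflected on your own hardware, PyTorch exposes a few device properties. The snippet below is a minimal sketch, assuming a CUDA-capable GPU is visible to PyTorch:

```python
import torch

# Query the device properties of the first visible GPU.
props = torch.cuda.get_device_properties(0)
print(f"GPU: {props.name}")
print(f"Streaming multiprocessors: {props.multi_processor_count}")
print(f"Global memory: {props.total_memory / 1e9:.1f} GB")
```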
Optimisation Basics
When optimising deep learning models, we're juggling three main components:
- Compute: Time spent by the GPU computing floating point operations (FLOPs).
- Memory: Time spent transferring tensors inside a GPU.
- Overhead: All other operations (Python interpreter, PyTorch dispatch, …).
Keeping these components in mind helps determine the right strategy to resolve a bottleneck. For instance, increasing compute (e.g. using a more powerful GPU) doesn't help if most of the time is spent on memory transfers. Ideally, most of the time should be spent on compute, more precisely on matrix multiplications, the exact operation GPUs are optimised for.
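As a rough way to feel the difference between memory-bound and compute-bound work, the sketch below (assuming a CUDA device is available) times an elementwise addition against a matrix multiplication using CUDA events; the exact numbers will depend on your GPU:

```python
import torch

def gpu_time_ms(fn, iters: int = 100) -> float:
    """Average runtime of `fn` in milliseconds, measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):  # warm-up
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(4096, 4096, device="cuda")
print(f"Elementwise add (memory-bound): {gpu_time_ms(lambda: x + x):.3f} ms")
print(f"Matrix multiply (compute-bound): {gpu_time_ms(lambda: x @ x):.3f} ms")
```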
This implies minimising the cost paid to move data around, whether from the CPU to the GPU ("data transfer cost"), from one node to another ("network cost"), or from CUDA global memory (DRAM, cheap but slow) to CUDA shared memory (SRAM, expensive but the fastest on-device memory). The latter is called bandwidth cost and will be our main focus for now. Common strategies to reduce bandwidth costs include:
- Reusing data loaded in shared memory for multiple steps. A prime example of this is tiled matrix multiplication, which we'll cover in a future post.
- Fusing multiple operations into a single kernel (since every kernel launch implies moving data from DRAM to SRAM); for instance, we can fuse a matrix multiplication with an activation function. In general, operator fusion can provide massive performance gains, as it avoids a lot of global memory reads/writes, and any two operators present an opportunity for fusion.

In this example, we perform a matrix multiplication `x @ W` and store the result in an intermediate variable `a`. We then apply a `relu` to `a` and store the result in a variable `y`. This requires the GPU to read `x` and `W` from global memory, write the result to `a`, read from `a` again and finally write to `y`. Instead, operator fusion lets us halve the number of reads and writes to global memory by performing the matrix multiplication and applying the ReLU in a single kernel.

Triton
We'll now write our first Triton kernel, a simple vector addition. First, let's walk through how this operation is broken down and executed on a GPU.
Suppose we want to sum the entries of two vectors `X` and `Y`, each with 7 elements (`n_elements=7`). We'll instruct the GPU to tackle this problem in chunks of 3 elements at a time (`BLOCK_SIZE=3`). Therefore, to cover all 7 elements of the input vectors, the GPU will launch 3 parallel "programs", independent instances of our kernel, each with a unique program ID, `pid`:
- Program 0 is assigned elements `0, 1, 2`.
- Program 1 is assigned elements `3, 4, 5`.
- Program 2 is assigned element `6`.
Then, these programs write their results back to a vector `Z` stored in global memory.

A crucial detail is that a kernel doesn't receive the whole vector `X`; instead, it receives a pointer to the memory address of the first element, `X[0]`. In order to access the actual values of `X`, we need to load them from global memory manually.

We can access the data for each block using the program ID: `block_start = pid * BLOCK_SIZE`. From there, we can get the remaining element addresses for that block by computing `offsets = block_start + range(0, BLOCK_SIZE)` and load them into memory.

However, remember that program 2 is only assigned element 6, yet its offsets are `[6, 7, 8]`. To avoid any indexing errors, Triton lets us define a mask to identify valid target elements, here `mask = offsets < n_elements`.

We can now safely load `X` and `Y`, add them together, and write the result back to an output variable `Z` in global memory in the same way.

Let's take a closer look at the code. Here's the Triton kernel:
```python
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # pointer to the first memory entry of x
    y_ptr,  # pointer to the first memory entry of y
    output_ptr,  # pointer to the first memory entry of the output
    n_elements,  # size of x and y
    BLOCK_SIZE: tl.constexpr,  # size of a single block
):
    # --- Compute offsets and mask ---
    pid = tl.program_id(axis=0)  # block index
    block_start = pid * BLOCK_SIZE  # start index for the current block
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # index range
    mask = offsets < n_elements  # mask out-of-bound elements

    # --- Load variables from global memory ---
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # --- Operation ---
    output = x + y

    # --- Save results to global memory ---
    tl.store(pointer=output_ptr + offsets, value=output, mask=mask)
```
Let's break down some of the Triton-specific syntax:
- First, a Triton kernel is always decorated with `@triton.jit`.
- Second, some arguments need to be declared as static, meaning that they are known at compile time. This is required for `BLOCK_SIZE` and is achieved by adding the `tl.constexpr` type annotation. The other arguments are left unannotated, since they are ordinary runtime values rather than compile-time constants.
- We use `tl.program_id` to access the ID of the current block; `tl.arange` behaves similarly to NumPy's `np.arange`.
- Loading and storing variables is done by calling `tl.load` and `tl.store` with arrays of pointers. Notice that there is no `return` statement; this role is delegated to `tl.store`.
To use our kernel, we now need to write a PyTorch-level wrapper that provides memory pointers and defines a kernel grid. In general, the kernel grid is a 1D, 2D or 3D tuple containing the number of thread blocks allocated to the kernel along each axis. In our previous example, we used a 1D grid of 3 thread blocks: `grid = (3,)`.
To handle arbitrary array sizes, we default to `grid = (ceil(n_elements / BLOCK_SIZE),)`.
```python
import torch
import triton


def add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """PyTorch wrapper for `add_kernel`."""
    output = torch.zeros_like(X)  # allocate memory for the output
    n_elements = output.numel()  # number of elements in X and Y

    # cdiv = ceiling division, computes the number of blocks to use
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # calling the kernel automatically stores `BLOCK_SIZE` in `meta`
    # and updates `output` in place
    add_kernel[grid](X, Y, output, n_elements, BLOCK_SIZE=1024)
    return output
```
Here are two final notes about the wrapper:
- You might have noticed that `grid` is defined as a lambda function. This allows Triton to compute the number of thread blocks at launch time: the grid size is derived from the block size, which is stored in `meta`, a dictionary of compile-time constants exposed to the kernel.
- When calling the kernel, `output` is modified in place, so we don't need to reassign `output = add_kernel[…]`.
We can conclude this tutorial by verifying that our kernel works properly:
```python
x, y = torch.randn((2, 2048), device="cuda")

print(add(x, y))
>> tensor([ 1.8022, 0.6780, 2.8261, ..., 1.5445, 0.2563, -0.1846], device='cuda:0')

abs_difference = torch.abs((x + y) - add(x, y))
print(f"Max absolute difference: {torch.max(abs_difference)}")
>> Max absolute difference: 0.0
```
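For a slightly more rigorous check, and to get a first idea of performance, we can compare against PyTorch's native addition with `torch.testing.assert_close` and Triton's built-in `triton.testing.do_bench` helper (a rough sketch; timings will vary across GPUs):

```python
import torch
from triton.testing import do_bench

x, y = torch.randn((2, 2048), device="cuda")

# Fail loudly if the kernel and PyTorch disagree.
torch.testing.assert_close(add(x, y), x + y)

# Runtime in milliseconds.
print("triton :", do_bench(lambda: add(x, y)))
print("pytorch:", do_bench(lambda: x + y))
```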
That's it for this introduction! In the following posts, we'll learn how to implement more interesting kernels, such as tiled matrix multiplication, and see how to integrate Triton kernels into PyTorch models using `autograd`.
Until next time! 👋