Learning Triton One Kernel at a Time: Softmax


In the previous article of this series, we tackled one of the most fundamental operations in all fields of computer science: matrix multiplication. It is heavily used in neural networks to compute the activations of linear layers. However, activations on their own are difficult to interpret, since their values and statistics (mean, variance, min-max amplitude) can vary wildly from layer to layer. This is one of the reasons why we use activation functions, for instance the logistic function (aka sigmoid), which projects any real number into the [0; 1] range.

The softmax function, also known as the normalised exponential function, is a multi-dimensional generalisation of the sigmoid. It converts a vector of raw scores (logits) into a probability distribution over M classes. We can interpret it as a weighted average that behaves as a smooth function and can be conveniently differentiated. It is a key component of dot-product attention, language modelling, and multinomial logistic regression.

In this article, we'll cover:

  1. Implementing an efficient softmax kernel in Triton.
  2. Implementing the backward pass (autograd).
  3. Optimisation: cache modifiers and auto-tuning.

If you aren't familiar with Triton yet, check out the previous articles!

Definition

The softmax is defined as follows:
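
Writing $z \in \mathbb{R}^M$ for the vector of logits and $M$ for the number of classes, the $i$-th output is

$$\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{M} e^{z_j}}, \qquad i = 1, \dots, M.$$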

The normalisation by the sum of exponentials ensures that the output vector sums to 1, so that it can be interpreted as a valid probability distribution.

Note that this formulation of the softmax is highly sensitive to numerical overflow. Recall that the maximum value a standard float16 can represent is 65,504, which is roughly exp(11). This means that any input value greater than ~11 will result in exp(z_i) exceeding the representable range, leading to overflow.

A common trick to mitigate this issue is to subtract the maximum value of the input vector from every element, so that the new maximum is 0 before exponentiation and 1 after.
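
Concretely, with $m = \max_k z_k$, the numerically stable formulation reads

$$\mathrm{softmax}(z)_i = \frac{e^{z_i - m}}{\sum_{j=1}^{M} e^{z_j - m}},$$

which is mathematically identical to the original definition, since the common factor $e^{-m}$ cancels between numerator and denominator.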

Naive Implementation

As you can see, computing the softmax involves two reduction operations: a max and a sum. A naive algorithm requires three separate passes over the input vector: first to compute the maximum, then the sum, and finally the normalised outputs.

Here’s what a naive Numpy implementation looks like:
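
The sketch below is a minimal, loop-based version that makes the three passes explicit (the function name is illustrative):

import numpy as np

def softmax_naive(x: np.ndarray) -> np.ndarray:
    # Pass 1: maximum, for numerical stability
    m = -np.inf
    for x_i in x:
        m = max(m, x_i)
    # Pass 2: sum of shifted exponentials
    d = 0.0
    for x_i in x:
        d += np.exp(x_i - m)
    # Pass 3: normalised outputs
    return np.exp(x - m) / d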

A recurring theme in this Triton series is minimising high-latency global memory accesses. Our current NumPy implementation requires three separate reads of the full input vector, which is highly inefficient.

Online Softmax

Fortunately, we can use a clever trick, known as the online softmax, to fuse the max and sum steps, reducing the number of memory reads to two.

First, we define the sum of exponentials recursively. In the following set of equalities, m_i refers to the maximum over x up to the i-th index.
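
Writing $d_i = \sum_{k \le i} e^{x_k - m_i}$ for the running sum of exponentials, the recursion can be sketched as

$$d_i = d_{i-1}\, e^{m_{i-1} - m_i} + e^{x_i - m_i}, \qquad m_i = \max(m_{i-1}, x_i).$$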

This equality allows us to compute the sum of exponentials iteratively, using only the maximum value seen so far. We can leverage it to fuse the first and second loops of the naive implementation and compute the maximum and the sum of exponentials in a single pass.

Our algorithm thus becomes a single fused pass that jointly updates the running maximum and the running sum of exponentials, followed by a final pass that normalises the outputs.

This is easily translated to NumPy:
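
The sketch below is one possible version; variable names mirror the recursion above:

import numpy as np

def softmax_online(x: np.ndarray) -> np.ndarray:
    # Fused pass: running maximum m and running sum of exponentials d
    m = -np.inf
    d = 0.0
    for x_i in x:
        m_new = max(m, x_i)
        # Rescale the running sum to the new maximum, then add the current term
        d = d * np.exp(m - m_new) + np.exp(x_i - m_new)
        m = m_new
    # Final pass: normalised outputs
    return np.exp(x - m) / d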

Now that we understand the basic principles behind the softmax, we'll implement it in Triton, starting with the simple, single-block version and building up to the online, multi-block formulation. Ultimately, we want our kernel to behave like a PyTorch module and be compatible with autograd.

Unfortunately, from PyTorch's perspective, Triton kernels behave like black boxes: the operations they perform are not traced by autograd. This requires us to implement the backward pass ourselves and explicitly specify how gradients should be computed. Let's brush up on our beloved chain rule and derive the softmax gradient.

Gradient

Since the outputs of the softmax are strictly positive, we can use the logarithmic derivative to make the derivation of the gradient easier. Here, we take the derivative of the log of the output and apply the chain rule:
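
Writing $s = \mathrm{softmax}(z)$ for the output vector (the notation used in the rest of this section), this reads

$$\frac{\partial \log s_i}{\partial z_j} = \frac{1}{s_i}\,\frac{\partial s_i}{\partial z_j}, \qquad \text{with} \quad \log s_i = z_i - \log \sum_k e^{z_k}.$$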

From there, we rearrange the terms and follow these steps:
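
In condensed form: differentiating $\log s_i = z_i - \log \sum_k e^{z_k}$ with respect to $z_j$ and combining it with the identity above yields the softmax Jacobian,

$$\frac{\partial \log s_i}{\partial z_j} = \delta_{ij} - s_j \quad\Longrightarrow\quad \frac{\partial s_i}{\partial z_j} = s_i\,(\delta_{ij} - s_j),$$

where $\delta_{ij}$ is the Kronecker delta (1 if $i = j$, 0 otherwise).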

Now assume that we have some upstream gradient, for example generated by a loss function (e.g. a cross-entropy loss). We get the following expression of the gradient:
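
With $g_j = \partial L / \partial s_j$ denoting the upstream gradient, the chain rule combined with the Jacobian above gives

$$\frac{\partial L}{\partial z_i} = \sum_j g_j\, \frac{\partial s_j}{\partial z_i} = \sum_j g_j\, s_j\, \delta_{ij} - s_i \sum_j g_j\, s_j = s_i \Big( g_i - \sum_j g_j\, s_j \Big).$$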

The simplification of the left term in the last equality is due to the fact that δ_ij is only equal to 1 when j = i, collapsing the sum over j to a single term.

Triton Implementation

Single Block Softmax

Now that we have worked through the derivation of the gradient, we can write the forward and backward softmax kernels. First, let's focus on the PyTorch wrapper to understand how the single-block implementation works at a high level. Given a 2D input tensor, the forward and backward kernels will process all rows in parallel.

For simplicity, we'll define BLOCK_SIZE to be large enough to handle all columns at once. Specifically, we'll set it to the next power of two greater than or equal to the number of columns, as required by Triton.

Then, we'll define our `grid` to be the number of rows (it could potentially also handle a batch dimension).

The PyTorch wrapper for our SoftmaxSingleBlock is a class inheriting from torch.autograd.Function that implements forward and backward. Both methods take a ctx argument, which we'll use to cache the softmax outputs during the forward pass and reuse them during the backward pass.
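
Here is a minimal sketch of such a wrapper. The kernel names _softmax_fwd_kernel and _softmax_bwd_kernel are illustrative (they match the kernel sketches later in this article), and calculate_settings is the helper shown just below:

import torch

class SoftmaxSingleBlock(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        n_rows, n_cols = x.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        out = torch.empty_like(x)
        # One program per row, each handling the full row in a single block
        _softmax_fwd_kernel[(n_rows,)](
            out, x, x.stride(0), n_cols,
            BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps,
        )
        # Cache the softmax outputs for the backward pass
        ctx.save_for_backward(out)
        ctx.BLOCK_SIZE, ctx.num_warps = BLOCK_SIZE, num_warps
        return out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (out,) = ctx.saved_tensors
        n_rows, n_cols = out.shape
        grad_out = grad_out.contiguous()  # ensure unit stride along each row
        grad_in = torch.empty_like(out)
        _softmax_bwd_kernel[(n_rows,)](
            grad_in, grad_out, out, out.stride(0), n_cols,
            BLOCK_SIZE=ctx.BLOCK_SIZE, num_warps=ctx.num_warps,
        )
        return grad_in

def softmax_sb(x: torch.Tensor) -> torch.Tensor:
    return SoftmaxSingleBlock.apply(x)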

Both kernels are pretty straightforward: we start by loading the row inputs using the same syntax as in my previous vector addition article. Notice that BLOCK_SIZE and num_warps are computed using a calculate_settings function. This function comes from the Unsloth library and has been reused in other kernel libraries such as LigerKernel (which the kernels in this article are loosely based on); it provides heuristics to tune both variables:

def calculate_settings(n: int) -> tuple[int, int]:
    MAX_FUSED_SIZE = 65536  # maximum number of elements a single block will handle
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        # we remove this assertion later in this article
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
        )
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps

Then, we implement the regular (numerically stable) softmax for the forward pass and the gradient expression derived above for the backward pass. The only novelty here compared to previous articles is the use of cache modifiers, which tell the compiler how to cache and evict data. For now, we'll only focus on three cache modifiers:

  • .ca (Cache at all levels): tells the compiler to load the data into both the L1 and L2 caches, suggesting that it may be reused soon. This modifier should be used when the data is small enough to fit into L1 (~128–192 KB per SM on an A100) and will likely be accessed repeatedly.
  • .cs (Streaming): treats the data as streaming; it will be used once and then discarded to free up space in L1.
  • .wb (Write-back): normal cached write; the data will remain in the cache hierarchy, which is useful if the output may be reused.

In the following kernels, we'll use the .ca modifier for loads, since we perform multiple operations on the loaded data. For stores, we'll use .cs in the forward pass, as the outputs won't be immediately reused, and .wb in the backward pass since, in the context of autograd (i.e. the chain rule), gradient outputs will be consumed by downstream kernels.
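
Below is a minimal sketch of both single-block kernels under these assumptions (illustrative names; the pointer/stride signature matches the wrapper sketch above):

import triton
import triton.language as tl

@triton.jit
def _softmax_fwd_kernel(out_ptr, in_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    # Load the full row once, keeping it in L1/L2 (.ca) since it is reused below
    x = tl.load(in_ptr + row * stride + cols, mask=mask, other=-float("inf"),
                cache_modifier=".ca")
    x = x - tl.max(x, axis=0)  # subtract the row maximum for numerical stability
    num = tl.exp(x)
    y = num / tl.sum(num, axis=0)
    # Stream the outputs out (.cs): this kernel does not reuse them
    tl.store(out_ptr + row * stride + cols, y, mask=mask, cache_modifier=".cs")

@triton.jit
def _softmax_bwd_kernel(din_ptr, dout_ptr, out_ptr, stride, n_cols,
                        BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    dy = tl.load(dout_ptr + row * stride + cols, mask=mask, other=0.0,
                 cache_modifier=".ca")
    y = tl.load(out_ptr + row * stride + cols, mask=mask, other=0.0,
                cache_modifier=".ca")
    # dL/dz_i = y_i * (dy_i - sum_j dy_j * y_j), the gradient derived above
    dx = y * (dy - tl.sum(dy * y, axis=0))
    # Write-back (.wb): downstream kernels will consume these gradients
    tl.store(din_ptr + row * stride + cols, dx, mask=mask, cache_modifier=".wb")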

Multi-Block Softmax

Now, let's take a look at the online formulation of the softmax. In this section, we implement a multi-block variant of the previous kernel. This version will use BLOCK_SIZE < n_cols; in other words, we'll only load a tile of BLOCK_SIZE elements at a time, similar to how we handled tiled GEMM in the last tutorial. Now you might ask: "how do we choose the block size?"

This is a great occasion to introduce Triton's autotune utility. Supplied with a list of configurations, autotune will perform a grid search to determine and cache the best configuration for a given input shape. This process is repeated each time a new input shape is passed to the kernel.

Here, we perform a grid search over the block size and the number of warps using the following utility function:

from itertools import product

# --- Multi Block Tuning ---
BLOCK_SIZES = [256, 512, 1024, 2048, 4096, 8192]
NUM_WARPS = [2, 4, 8, 16]

def get_autotune_config(
    block_sizes: list[int], num_warps: list[int]
) -> list[triton.Config]:
    return [
        triton.Config(kwargs={"BLOCK_SIZE": bs}, num_warps=nw)
        for (bs, nw) in list(product(block_sizes, num_warps))
    ]

We can now decorate our multi-block kernels with autotune and pass the list of configs; key="n_cols" indicates that the optimal config depends on the number of columns of the input.

The implementation of these kernels is conceptually very close to the online softmax we covered before; the main difference is that we iterate over tiles (not over single elements like in NumPy), which requires some adjustments. For instance, we add a sum over the tile in the d update, and the backward kernel now requires two iterations as well. A sketch of the autotuned forward kernel is shown below.
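
This is a minimal sketch of the multi-block forward kernel, assuming the same pointer/stride signature and imports as the single-block version (names are illustrative; the backward kernel follows the same two-loop structure):

@triton.autotune(
    configs=get_autotune_config(BLOCK_SIZES, NUM_WARPS),
    key=["n_cols"],  # re-tune whenever the number of columns changes
)
@triton.jit
def _softmax_mb_fwd_kernel(out_ptr, in_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    in_row = in_ptr + row * stride
    out_row = out_ptr + row * stride

    # First tile initialises the running maximum m and running sum d
    cols = tl.arange(0, BLOCK_SIZE)
    x = tl.load(in_row + cols, mask=cols < n_cols, other=-float("inf"),
                cache_modifier=".ca")
    m = tl.max(x, axis=0)
    d = tl.sum(tl.exp(x - m), axis=0)

    # Remaining tiles: online update of m and d (the d update sums over the tile)
    for start in range(BLOCK_SIZE, n_cols, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        x = tl.load(in_row + cols, mask=cols < n_cols, other=-float("inf"),
                    cache_modifier=".ca")
        m_new = tl.maximum(m, tl.max(x, axis=0))
        # Rescale the running sum to the new maximum, then add the tile's contribution
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(x - m_new), axis=0)
        m = m_new

    # Second pass: normalise each tile and stream the results out
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(in_row + cols, mask=mask, other=-float("inf"),
                    cache_modifier=".ca")
        tl.store(out_row + cols, tl.exp(x - m) / d, mask=mask, cache_modifier=".cs")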


Testing and Benchmarking

We can now run a forward and backward pass with both kernels and confirm that they match the PyTorch baselines:

from copy import deepcopy

import torch

def validate_kernel(kernel_fn: callable) -> None:
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch.random.manual_seed(0)

    # Generate inputs
    x = torch.randn((256, 512), device=device) # triton input
    x.requires_grad = True
    xt = deepcopy(x) # torch input

    triton_output = kernel_fn(x)
    torch_output = torch.softmax(xt, dim=1)
    torch.testing.assert_close(triton_output, torch_output) # test fwd kernel

    # Setup fake labels
    y = torch.zeros_like(x)
    inds = (torch.arange(0, y.shape[0]), torch.randint(0, 3, (y.shape[0],)))
    y[inds] = 1

    # Define loss and run backward pass
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(torch_output, y)
    loss.backward()

    # Save gradient tensor for later
    torch_xgrad = xt.grad.detach().clone()
    triton_loss = loss_fn(triton_output, y)
    triton_loss.backward()
    torch.testing.assert_close(x.grad, torch_xgrad) # test grad outputs

validate_kernel(softmax_sb)
validate_kernel(softmax_mb)

Finally, we benchmark our implementations against the PyTorch baseline using the following snippet:

# --- Source: Triton softmax tutorial ---
DEVICE = torch.device("cuda:0")  # assumed here; the original snippet defines DEVICE elsewhere
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],  # argument names to use as an x-axis for the plot
        x_vals=[
            128 * i for i in range(2, 100)
        ],  # different possible values for `x_name`
        line_arg="provider",  # argument name whose value corresponds to a different line in the plot
        line_vals=[
            "triton_single_block",
            "triton_multi_block",
            "torch",
        ],  # possible values for `line_arg`
        line_names=[
            "Triton_single_block",
            "Triton_multi_block",
            "Torch",
        ],  # label name for the lines
        styles=[("blue", "-"), ("green", "-"), ("red", "-")],
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={"M": 4096},  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == "torch":
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == "triton_single_block":
        torch.cuda.synchronize()
        ms = triton.testing.do_bench(lambda: softmax_sb(x))
        torch.cuda.synchronize()
    if provider == "triton_multi_block":
        torch.cuda.synchronize()
        ms = triton.testing.do_bench(lambda: softmax_mb(x))
        torch.cuda.synchronize()
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)

benchmark.run(show_plots=True, print_data=True)

Good news! Our single-block kernel consistently outperforms the PyTorch baseline, while the multi-block variant falls off for inputs with more than 6k columns:

Considering larger inputs, we can make several observations:

  1. The multi-block kernel eventually stabilises around 900 GB/s of throughput, surpassing the PyTorch baseline for inputs with more than 30k columns.
  2. Interestingly, it looks like the multi-block variant starts to dominate for inputs with more than 60k columns.
  3. Even though we exceed the maximum block size with the single-block variant, the kernel still runs smoothly. Indeed, Triton automatically manages the block size under the hood: when n_cols is larger than the hardware limit, Triton will split the input and iterate over it. However, this appears to be slower than the multi-block approach.

To go further, we could combine both approaches in a single kernel that explicitly selects the optimal implementation based on the input size. This way, we would benefit from the high performance of the single-block kernel on small inputs and from the higher throughput of the multi-block variant for inputs with more than 60k columns.

This concludes the third episode of this Triton series, thanks again for your support!

In the next article, we'll leverage the online softmax formulation in the context of Flash Attention.

Until next time! 👋

Resources:
