Accelerating Long-Context Model Training in JAX and XLA

Large language models (LLMs) are rapidly expanding their context windows, with recent models supporting sequences of 128K tokens, 256K tokens, and beyond. However, training these models at such long context lengths presents significant computational and communication challenges. As context lengths grow, the memory and communication overhead of the attention mechanism scales quadratically, creating bottlenecks that traditional parallelism strategies struggle to handle efficiently.

This post demonstrates how integrating the NVSHMEM communication library into the Accelerated Linear Algebra (XLA) compiler optimizes context parallelism. The integration enables efficient training of the Llama 3 8B model in the JAX framework with sequences of up to 256K tokens. Our results show that NVSHMEM provides up to a 36% speedup over the NVIDIA Collective Communications Library (NCCL) for long-context training workloads, particularly when combined with tensor parallelism across multiple nodes.

The long-context training challenge

To understand why NVSHMEM provides significant speedups for long-context training, it helps to first understand how context parallelism works and the unique communication patterns it creates. This section explains why the fine-grained, latency-sensitive communication of ring attention makes it an ideal candidate for optimization.

Context parallelism and ring attention

Context parallelism (CP) is a parallelization strategy designed specifically for handling long sequences in transformer models. Unlike data parallelism, which splits the batch, or tensor parallelism, which splits the model, context parallelism splits the sequence dimension across multiple devices.

Ring attention is a memory-efficient implementation of context parallelism that uses a ring-based communication pattern. During attention computation, each device:

  • Processes its local portion of the sequence
  • Exchanges key-value (KV) tensors with neighboring devices in a ring topology
  • Incrementally computes attention scores as KV blocks flow around the ring

This approach reduces peak memory usage while maintaining mathematical equivalence to standard attention, making it possible to train with sequences that would otherwise exceed GPU memory capacity.
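
For illustration, the following is a minimal JAX sketch of a ring attention inner loop. It is not the MaxText implementation: the softmax bookkeeping is simplified (no running-max rescaling or causal masking), and it assumes it is called inside a shard_map or pmap region whose context-parallel axis is named cp.

import jax
import jax.numpy as jnp

def ring_attention_step(q, k, v, axis_name="cp"):
    """One simplified ring attention pass over the context-parallel axis.

    q, k, v: this device's sequence shard, shape (seq_local, heads, head_dim).
    Must be called inside shard_map/pmap with `axis_name` bound to the CP axis.
    """
    ring_size = jax.lax.psum(1, axis_name)           # number of devices in the ring
    perm = [(i, (i + 1) % ring_size) for i in range(ring_size)]

    out = jnp.zeros_like(q)                          # unnormalized attention output
    denom = jnp.zeros(q.shape[:-1] + (1,), q.dtype)  # running softmax denominator

    for _ in range(ring_size):
        # Attend local queries to the KV block currently held by this device
        scores = jnp.einsum("qhd,khd->qhk", q, k) / jnp.sqrt(q.shape[-1])
        weights = jnp.exp(scores)                    # simplified: no max subtraction
        out = out + jnp.einsum("qhk,khd->qhd", weights, v)
        denom = denom + weights.sum(axis=-1, keepdims=True)
        # Rotate the KV block to the next device in the ring; the compiler can
        # overlap this transfer with the next iteration's attention compute
        k = jax.lax.ppermute(k, axis_name, perm=perm)
        v = jax.lax.ppermute(v, axis_name, perm=perm)
    return out / denom

Each iteration computes attention against the locally held KV block while ppermute rotates the KV blocks one step around the ring.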

Communication patterns in ring attention

Ring attention involves frequent, fine-grained communication operations:

  • Point-to-point transfers: Sending KV tensors to the next device in the ring
  • Overlapped compute-communication: Computing attention on the current KV blocks while fetching the next blocks
  • Low-latency requirement: KV transfers are on the critical path and must complete before attention can proceed

These characteristics make ring attention a perfect candidate for low-latency communication libraries like NVSHMEM.

GPU-optimized communication with NVSHMEM

NVSHMEM is a communication library that implements the OpenSHMEM parallel programming model for NVIDIA GPUs. It provides several key features that distinguish it from traditional communication libraries, including symmetric memory, stream-aware communication, copy engine offloading, and more, as detailed below.

Symmetric memory

NVSHMEM provides a partitioned global address space (PGAS) resident in GPU memory. Applications allocate buffers from this symmetric heap using nvshmem_malloc, and these pointers can be used directly in communication operations. For example:

// Allocate buffers from the symmetric heap; every PE gets a matching allocation
int32_t *src_d = (int32_t *)nvshmem_malloc(1024 * sizeof(int32_t));
int32_t *dest_d = (int32_t *)nvshmem_malloc(1024 * sizeof(int32_t));
// Sum-reduce 1024 elements across NVSHMEM_TEAM_WORLD, enqueued on the default stream
int ret = nvshmemx_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dest_d, src_d, 1024, 0);
Figure 1. Symmetric memory heap in NVSHMEM. Symmetric (shared) and private memory regions exist at each PE; the aggregation of the shared memory segments across all PEs is referred to as a partitioned global address space (PGAS).

Stream-aware communication

NVSHMEM provides peer-to-peer (P2P) on-stream APIs (such as put_nbi_on_stream and signal_on_stream) to move data efficiently and provide low-latency synchronization over P2P-connected GPUs.

One of the key benefits of these APIs over traditional host-initiated communication is their ability to perform these operations with a zero-SM footprint by leveraging the copy engine (CE) and stream memory operations capabilities of GPU hardware. Some of the underlying CUDA interfaces include:

  • Direct GPU-to-GPU transfers: Similar to cudaMemcpyAsync, but with lower latency through optimized data paths
  • Fine-grained synchronization: Using the cuStreamWriteValue32 and cuStreamWaitValue32 primitives for efficient signaling between devices without CPU involvement

In addition to the P2P on-stream APIs, NVSHMEM provides the popular collective operations (reduce_on_stream, for example) commonly used in AI workloads, such as AllReduce. These collectives leverage the SHARP, in-network reduction, and multicast acceleration features of NVIDIA NVLink Switch to enable latency-optimized one-shot and throughput-optimized two-shot AllReduce algorithms. The underlying CUDA interface includes the multimem ISA, providing the additional advantage of a reduced SM footprint, as primitives such as reductions and broadcasts are offloaded to the switch.

Each of these features enables effective pipelining of compute and communication operations, since most or all of the GPU SMs remain available for compute when communication is overlapped in time on the same CUDA stream.

CUDA Graphs interoperability

NVSHMEM operations can be captured into CUDA Graphs, enabling:

  • Amortized kernel launch overhead across multiple iterations
  • Optimized execution scheduling by the CUDA runtime
  • Seamless composition with other graph-captured operations

This composability is crucial for production training frameworks that depend on CUDA Graphs for performance optimization.

Integrating NVSHMEM and XLA

This section describes how NVSHMEM is integrated into the XLA compiler infrastructure, covering runtime flags, automatic backend selection heuristics, and the compilation flow.

Runtime control through debug options

XLA exposes a runtime flag for dynamic control:

XLA_FLAGS="--xla_gpu_experimental_enable_nvshmem=true"

This flag is defined in xla/debug_options_flags.cc and allows users to enable or disable NVSHMEM without recompilation (default value = false). The “experimental” prefix indicates that the API may evolve as the feature matures.
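
Alternatively, the same flag can be appended to XLA_FLAGS from Python, as in the minimal sketch below; it must run before JAX initializes its GPU backend.

import os

# Must be set before the first `import jax`, since XLA reads the flags
# when the GPU backend is initialized.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_experimental_enable_nvshmem=true"
)

import jax  # NVSHMEM-aware backend selection is now enabled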

Automatic backend selection

The CollectiveBackendAssigner pass in the compilation pipeline determines which communication backend to use based on workload characteristics. This is where the intelligence of the approach lies.

Selection heuristics

The compiler analyzes each collective operation and decides whether to use NVSHMEM based on three key criteria:

  1. Single device: Use NVSHMEM when only one device is visible per process (no network overhead)
  2. Single partition: Use NVSHMEM when all participating devices in the collective operation are managed by the same process
  3. NVLink domain: Use NVSHMEM for intranode communication over the NVIDIA NVLink fabric

In addition, message-size heuristics apply (see the sketch after this list):

  • AllReduce operations: Only use NVSHMEM if the message size is below a threshold (typically 16 MB). For larger messages, fall back to NCCL, which is optimized for bandwidth.
  • CollectivePermute operations: Always use NVSHMEM regardless of message size (no threshold applied).
  • Rationale: AllReduce benefits from NCCL's ring and tree algorithms for large messages, while the point-to-point nature of CollectivePermute makes NVSHMEM's low latency ideal at any size.
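
The sketch below summarizes these heuristics in Python. It is an illustration of the decision logic described above, not the actual CollectiveBackendAssigner implementation, and the 16 MB threshold is the typical value mentioned earlier.

ALLREDUCE_NVSHMEM_THRESHOLD_BYTES = 16 * 1024 * 1024  # typical threshold noted above

def select_backend(op_kind, message_bytes, single_device, single_partition, in_nvlink_domain):
    # NVSHMEM is only considered when one of the topology criteria holds
    if not (single_device or single_partition or in_nvlink_domain):
        return "NCCL"
    if op_kind == "CollectivePermute":
        return "NVSHMEM"  # latency-bound point-to-point: no size threshold
    if op_kind == "AllReduce" and message_bytes < ALLREDUCE_NVSHMEM_THRESHOLD_BYTES:
        return "NVSHMEM"  # small, latency-sensitive reductions
    return "NCCL"         # large messages: bandwidth-optimized ring/tree algorithms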

JAX framework integration

The strength of this architecture lies in its complete transparency to Python frameworks. A JAX developer writes standard collective operations:

import jax
import jax.numpy as jnp

@jax.jit
def collective_permute_example(x):
    # Shift data from each device to the next device in a ring
    axis_name = 'devices'
    perm = [(i, (i + 1) % jax.device_count()) for i in range(jax.device_count())]
    return jax.lax.ppermute(x, axis_name, perm=perm)

# The compiler automatically selects NVSHMEM when appropriate
result = collective_permute_example(data)

The XLA compiler analyzes this ppermute (collective permute) operation and automatically performs the following steps:

  • Applies heuristics: single device, single partition, or within the NVLink domain
  • Recognizes a CollectivePermute operation (no message size threshold applies)
  • Selects NVSHMEM for optimal point-to-point communication
  • Generates thunks that invoke NVSHMEM host APIs at runtime
  • NVSHMEM host APIs enqueue operations on CUDA streams, for example: nvshmemx_float_sum_reduce_on_stream, nvshmemx_float_put_nbi_on_stream

This end-to-end integration means that high-level JAX code automatically benefits from NVSHMEM performance without requiring any user-level changes or annotations.

Experimental methodology

To evaluate the NVSHMEM performance benefits, the team conducted experiments on Llama 3 8B across a range of sequence lengths (64K to 256K tokens) and parallelism configurations. This section details the model setup, hardware configuration, and the metrics used to compare NVSHMEM against the NCCL baseline.

Model configuration

The team evaluated NVSHMEM-accelerated context parallelism on the Llama 3 8B model with the following configuration.

  • Model: Llama 3 8B
  • Precision: BF16
  • Context parallel strategy: Ring attention
  • Framework: MaxText (JAX-based training framework)
  • Hardware: NVIDIA GB200 NVL72
  • Docker image: Available through NVIDIA/JAX-Toolbox
  • JAX version: JAX 0.6.2 and later 

Parallelism configurations

Various combinations of parallelism strategies were tested across different sequence lengths (Table 1).

| Sequence length | Nodes | GPUs  | Context parallelism | Tensor parallelism | Fully sharded data parallelism | Sequence length per GPU after CP split |
| 64K             | 1-4   | 4-16  | 4-16                | 1                  | 1-2                            | 4K-16K                                 |
| 128K            | 2-8   | 8-32  | 8-32                | 1                  | 1-2                            | 4K-16K                                 |
| 256K            | 8-16  | 32-64 | 16-32               | 2                  | 1-2                            | 8K-16K                                 |
Table 1. Parallelism configurations tested across different sequence lengths

Longer sequences (256K) employed tensor parallelism (TP=2) in addition to context parallelism to fit the model within GPU memory constraints.
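
To illustrate how such a hybrid layout maps onto a JAX device mesh, the following sketch builds a mesh for the 256K-token, 32-GPU row of Table 1 (FSDP=1, CP=16, TP=2). The axis names are illustrative and not necessarily the configuration keys used by MaxText.

import numpy as np
import jax
from jax.sharding import Mesh

# Assumes 32 visible devices, arranged as FSDP=1 x CP=16 x TP=2
# to match the 256K-token row in Table 1.
devices = np.asarray(jax.devices()).reshape(1, 16, 2)
mesh = Mesh(devices, axis_names=("fsdp", "cp", "tp"))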

Communication backend comparison

Each configuration was evaluated with two communication backends:

  1. NCCL (baseline)
  2. NVSHMEM-enabled implementation

Measurements:

  • TFLOP/s per device: GPU computational throughput
  • Step time (seconds): Time per training iteration
  • Speedup: Relative performance improvement of NVSHMEM over NCCL

All metrics were averaged across iterations 3-20 (skipping the first two warmup iterations) and computed from rank 0 logs to ensure consistency.
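
A minimal sketch of this averaging, assuming the per-step throughput values have already been parsed from the rank 0 logs into a list:

def mean_steady_state(step_tflops):
    # Skip the first two warmup iterations and average iterations 3-20
    steady = step_tflops[2:20]
    return sum(steady) / len(steady)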

Performance results

As shown in Table 2, the NVSHMEM performance advantage grows significantly with sequence length:

  • 64K sequences: 0.3-3.9% speedup (modest improvement)
  • 128K sequences: 0.7-2.4% speedup (consistent improvement)
  • 256K sequences: 30.4-36.3% speedup (dramatic improvement)

This scaling behavior aligns with the ring attention communication pattern: longer sequences require more KV tensor exchanges around the ring, amplifying the benefit of NVSHMEM's lower-latency communication.

When scaling across nodes, internode communication latency becomes more critical. NVSHMEM's nonblocking host APIs and optimized data paths provide consistent advantages across 8-16 node deployments.

| Sequence length | Nodes | CP | TP | GPUs | Seq/GPU | Default TFLOP/s | NVSHMEM TFLOP/s | Speedup |
| 64K             | 1     | 4  | 1  | 4    | 16K     | 605.64          | 607.36          | +0.3%   |
| 64K             | 2     | 8  | 1  | 8    | 8K      | 549.92          | 557.17          | +1.3%   |
| 64K             | 4     | 16 | 1  | 16   | 4K      | 482.19          | 501.06          | +3.9%   |
| 128K            | 2     | 8  | 1  | 8    | 16K     | 512.22          | 515.87          | +0.7%   |
| 128K            | 4     | 16 | 1  | 16   | 8K      | 473.58          | 472.46          | -0.2%   |
| 128K            | 8     | 32 | 1  | 32   | 4K      | 420.99          | 431.13          | +2.4%   |
| 256K            | 8     | 16 | 2  | 32   | 16K     | 366.94          | 500.22          | +36.3%  |
| 256K            | 16    | 32 | 2  | 64   | 8K      | 346.33          | 451.70          | +30.4%  |
Table 2. Performance comparison of default (NCCL) and NVSHMEM backends across different configurations
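
The speedup column is the relative TFLOP/s improvement of NVSHMEM over the NCCL baseline; for example, the 256K, 32-GPU row gives 500.22 / 366.94 − 1 ≈ 36.3%. A one-line helper:

def relative_speedup(nvshmem_tflops, baseline_tflops):
    # e.g. relative_speedup(500.22, 366.94) ≈ 0.363 (+36.3%)
    return nvshmem_tflops / baseline_tflops - 1.0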

Practical implications

Based on these results, NVSHMEM provides clear benefits for:

  • Long-context training: Sequences ≥ 128K tokens where communication becomes a bottleneck
  • Multinode deployments: Scaling beyond single-node NVLink domains
  • Ring attention and similar patterns: Workloads with fine-grained, latency-sensitive communication
  • Hybrid parallelism: Configurations combining CP, TP, and FSDP

The XLA integration makes NVSHMEM accessible to JAX. No user code changes are required; simply use an NVSHMEM-enabled XLA build and set the appropriate environment flags.

Get started with long-context model training

Training LLMs with long context windows requires efficient communication strategies that can handle fine-grained, latency-sensitive data exchanges. The integration of NVSHMEM into XLA enables transparent acceleration of context parallelism with ring attention, providing up to a 36% speedup for 256K-token sequences on Llama 3 8B.

Key takeaways:

  • NVSHMEM's nonblocking host APIs and low-latency data paths are ideally suited to the ring attention communication pattern
  • XLA compiler integration makes NVSHMEM accessible to high-level frameworks without requiring code changes
  • Performance advantages scale with sequence length, with dramatic improvements for sequences ≥ 256K tokens
  • Multinode deployments see the biggest gains, making NVSHMEM essential for production long-context training

As context windows continue to grow, low-latency communication solutions like NVSHMEM will be crucial for making long-context training practical and cost-effective. We encourage the community to try NVSHMEM-enabled XLA builds in the JAX framework and share their experiences with long-context workloads.

To get started, check out the MaxText framework, NVIDIA/JAX-Toolbox, and openxla/xla on GitHub.

Acknowledgments

We would like to express our gratitude to NVSHMEM contributors Seth Howell and Akhil Langer.


