Large language models (LLMs) are rapidly expanding their context windows, with recent models supporting sequences of 128K tokens, 256K tokens, and beyond. However, training these models at such context lengths presents significant computational and communication challenges. As context lengths grow, the memory and communication overhead of attention mechanisms scales quadratically, creating bottlenecks that traditional parallelism strategies struggle to handle efficiently.
This post demonstrates how integrating the NVSHMEM communication library into the Accelerated Linear Algebra (XLA) compiler optimizes context parallelism. This integration enables efficient training of the Llama 3 8B model in the JAX framework with sequences of up to 256K tokens. Our results show that NVSHMEM provides up to a 36% speedup over the NVIDIA Collective Communications Library (NCCL) for long-context training workloads, particularly when combined with tensor parallelism across multiple nodes.
The long-context training challenge
To understand why NVSHMEM provides significant speedups for long-context training, it's important to first understand how context parallelism works and the unique communication patterns it creates. This section explains why the fine-grained, latency-sensitive communication of ring attention makes it an ideal candidate for optimization.
Context parallelism and ring attention
Context parallelism (CP) is a parallelization strategy designed specifically for handling long sequences in transformer models. Unlike data parallelism, which splits the batch, or tensor parallelism, which splits the model, context parallelism splits the sequence dimension across multiple devices.
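As a concrete illustration (not taken from MaxText), the following minimal JAX snippet shards an activation tensor along its sequence dimension across a mesh axis named 'cp'; the axis name and tensor shapes are illustrative, and the sequence length must be divisible by the number of devices.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis ('cp') spanning all local devices
mesh = Mesh(np.array(jax.devices()), axis_names=('cp',))

# Activations shaped [batch, sequence, hidden]; only the sequence axis is split
x = jnp.zeros((1, 8192, 4096), dtype=jnp.bfloat16)
x = jax.device_put(x, NamedSharding(mesh, P(None, 'cp', None)))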
Ring attention is a memory-efficient implementation of context parallelism that uses a ring-based communication pattern. During attention computation, each device:
- Processes its local portion of the sequence
- Exchanges key-value (KV) tensors with neighboring devices in a ring topology
- Incrementally computes attention scores as KV blocks flow around the ring
This approach reduces peak memory usage while maintaining mathematical equivalence to standard attention, making it possible to train with sequences that would otherwise exceed GPU memory capacity.
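The following is a minimal, deliberately simplified JAX sketch of that loop. It assumes it runs inside shard_map or pmap with a bound axis name, and it omits the running softmax rescaling (log-sum-exp statistics) that a correct implementation needs when combining per-block results; it is meant only to show how KV blocks circulate via ppermute.
import jax
import jax.numpy as jnp

def ring_attention_sketch(q, k, v, axis_name, axis_size):
    # q, k, v: local shards of shape [seq_local, heads, head_dim]
    perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
    out = jnp.zeros_like(q)
    for _ in range(axis_size):
        # Attend local queries to the KV block currently held on this device
        scores = jnp.einsum('qhd,khd->hqk', q, k) / jnp.sqrt(q.shape[-1])
        out = out + jnp.einsum('hqk,khd->qhd', jax.nn.softmax(scores, axis=-1), v)
        # Pass the KV block to the next device in the ring; XLA lowers this
        # ppermute to a CollectivePermute, the operation NVSHMEM accelerates
        k = jax.lax.ppermute(k, axis_name, perm=perm)
        v = jax.lax.ppermute(v, axis_name, perm=perm)
    return out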
Communication patterns in ring attention
Ring attention involves frequent, fine-grained communication operations:
- Point-to-point transfers: Sending KV tensors to the next device in the ring
- Overlapped compute-communication: Computing attention on the current KV blocks while fetching the next ones
- Low-latency requirement: KV transfers are on the critical path and must complete before attention can proceed
These characteristics make ring attention a perfect candidate for low-latency communication libraries like NVSHMEM.
GPU-optimized communication with NVSHMEM
NVSHMEM is a communication library that implements the OpenSHMEM parallel programming model for NVIDIA GPUs. It provides several key features that distinguish it from traditional communication libraries, including symmetric memory, stream-aware communication, copy-engine offloading, and more, as detailed below.
Symmetric memory
NVSHMEM provides a partitioned global address space (PGAS) resident in GPU memory. Applications allocate buffers from this symmetric heap using nvshmem_malloc, and these pointers can be used directly in communication operations. For example:
// Allocate buffers from the symmetric heap (same size on every PE)
int32_t *src_d = (int32_t *)nvshmem_malloc(1024 * sizeof(int32_t));
int32_t *dest_d = (int32_t *)nvshmem_malloc(1024 * sizeof(int32_t));
// Sum-reduce 1024 int32 elements across all PEs, enqueued on the default CUDA stream
int ret = nvshmemx_int32_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dest_d, src_d, 1024, 0);


Stream-aware communication
NVSHMEM provides peer-to-peer (P2P) on-stream APIs (such as put_nbi_on_stream and signal_on_stream) to efficiently move data and provide low-latency synchronization between P2P-connected GPUs.
One of the key benefits of these APIs over traditional host-initiated communication is their ability to perform these operations with a zero streaming multiprocessor (SM) footprint by leveraging the copy engine (CE) and stream memory operations capabilities of GPU hardware. Some of the underlying CUDA interfaces include:
- Direct GPU-to-GPU transfers: Similar to cudaMemcpyAsync, but with lower latency through optimized data paths
- Fine-grained synchronization: Using cuStreamWriteValue32 and cuStreamWaitValue32 primitives for efficient signaling between devices without CPU involvement
In addition to the P2P on-stream APIs, NVSHMEM also provides collective operations commonly used in AI workloads, such as AllReduce (for example, reduce_on_stream). These collectives leverage SHARP, in-network reductions, and the multicast acceleration features of NVIDIA NVLink Switch to enable latency-optimized one-shot and throughput-optimized two-shot AllReduce algorithms. The underlying CUDA interface includes the multimem ISA, providing the additional advantage of a reduced SM footprint, as primitives such as reductions and broadcasts are offloaded to the switch.
Both of these features enable effective pipelining of compute and communication operations, because most or all of the GPU SMs remain available for compute when communication is overlapped in time on the same CUDA stream.
CUDA Graphs interoperability
NVSHMEM operations can be captured into CUDA Graphs, enabling:
- Amortized kernel launch overhead across multiple iterations
- Optimized execution scheduling by the CUDA runtime
- Seamless composition with other graph-captured operations
This composability is crucial for production training frameworks that depend on CUDA Graphs for performance optimization.
Integrating NVSHMEM and XLA
This section describes how NVSHMEM is integrated into the XLA compiler infrastructure, covering runtime flags, automatic backend selection heuristics, and the compilation flow.
Runtime control through debug options
XLA exposes a runtime flag for dynamic control:
XLA_FLAGS="--xla_gpu_experimental_enable_nvshmem=true"
This flag is defined in xla/debug_options_flags.cc and allows users to enable or disable NVSHMEM without recompilation (default value = false). The "experimental" prefix indicates that the API may evolve as the feature matures.
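For example, one common way to set the flag from Python is through the XLA_FLAGS environment variable before JAX initializes its backends (the same flag can equally be exported in the shell):
import os

# Set the flag before JAX initializes its GPU backend (that is, before the
# first JAX import or call); this appends to any XLA flags already present
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_experimental_enable_nvshmem=true"
)

import jax  # imported after setting XLA_FLAGS on purpose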
Automatic backend selection
The CollectiveBackendAssigner pass in the compilation pipeline determines which communication backend to use based on workload characteristics. This is where the intelligence of the approach lies.
Selection heuristics
The compiler analyzes each collective operation and decides whether to use NVSHMEM based on three key criteria:
- Single device: Use NVSHMEM when only one device is visible per process (no network overhead)
- Single partition: Use NVSHMEM when all participating devices in the collective operation are managed by the same process
- NVLink domain: Use NVSHMEM for intranode communication over NVIDIA NVLink fabric
In addition, message-size heuristics apply (summarized in the sketch after this list):
- AllReduce operations: Use NVSHMEM only when the message size is below a threshold (typically 16 MB). For larger messages, fall back to NCCL, which is optimized for bandwidth.
- CollectivePermute operations: Always use NVSHMEM regardless of message size (no threshold applied).
- Rationale: AllReduce benefits from NCCL's ring and tree algorithms for large messages, while the point-to-point nature of CollectivePermute makes NVSHMEM's low latency ideal at any size.
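The following Python sketch summarizes this decision logic. It is an illustration only, not the actual CollectiveBackendAssigner C++ implementation, and the parameter names are hypothetical.
# Illustrative pseudocode for the backend selection heuristics described above
ALLREDUCE_NVSHMEM_THRESHOLD = 16 * 1024 * 1024  # typical 16 MB cutoff

def choose_backend(op_kind, message_bytes, single_device, single_partition, in_nvlink_domain):
    # Topology criteria: single device, single partition, or NVLink domain
    if not (single_device or single_partition or in_nvlink_domain):
        return "NCCL"
    if op_kind == "CollectivePermute":
        # Point-to-point: NVSHMEM's low latency wins at any message size
        return "NVSHMEM"
    if op_kind == "AllReduce":
        # Small messages favor NVSHMEM latency; large ones favor NCCL bandwidth
        return "NVSHMEM" if message_bytes < ALLREDUCE_NVSHMEM_THRESHOLD else "NCCL"
    return "NCCL"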
JAX framework integration
The strength of this architecture lies in its complete transparency to Python frameworks. A JAX developer writes standard collective operations:
import jax
import numpy as np
from functools import partial
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# A 1D device mesh gives ppermute a named axis to operate over
mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('devices'), out_specs=P('devices'))
def collective_permute_example(x):
    # Shift data from each device to the next device in a ring
    perm = [(i, (i + 1) % jax.device_count()) for i in range(jax.device_count())]
    return jax.lax.ppermute(x, 'devices', perm=perm)

# The compiler automatically selects NVSHMEM when appropriate
result = collective_permute_example(data)
The XLA compiler analyzes this ppermute (collective permute) operation and automatically performs the following steps:
- Applies the selection heuristics: single device, single partition, or within an NVLink domain
- Recognizes a CollectivePermute operation (no message size threshold applies)
- Selects NVSHMEM for optimal point-to-point communication
- Generates thunks that invoke NVSHMEM host APIs at runtime
- NVSHMEM host APIs enqueue operations on CUDA streams, for example nvshmemx_float_sum_reduce_on_stream and nvshmemx_float_put_nbi_on_stream
This end-to-end integration means that high-level JAX code automatically benefits from NVSHMEM performance without requiring any user-level changes or annotations.
Experimental methodology
To evaluate NVSHMEM performance benefits, the team conducted experiments on Llama 3 8B across a range of sequence lengths (64K to 256K tokens) and parallelism configurations. This section details the model setup, hardware configuration, and the metrics used to compare NVSHMEM against the NCCL baseline.
Model configuration
The team evaluated NVSHMEM-accelerated context parallelism on the Llama 3 8B model with the following configuration:
- Model: Llama 3 8B
- Precision: BF16
- Context parallel strategy: Ring attention
- Framework: MaxText (JAX-based training framework)
- Hardware: NVIDIA GB200 NVL72
- Docker image: Available through NVIDIA/JAX-Toolbox
- JAX version: JAX 0.6.2 and later
Parallelism configurations
Various combinations of parallelism strategies were tested across different sequence lengths (Table 1).
| Sequence length | Nodes | GPUs | Context parallelism | Tensor parallelism | Fully sharded data parallelism | Sequence length per GPU after CP split |
| --- | --- | --- | --- | --- | --- | --- |
| 64K | 1-4 | 4-16 | 4-16 | 1 | 1-2 | 4K-16K |
| 128K | 2-8 | 8-32 | 8-32 | 1 | 1-2 | 4K-16K |
| 256K | 8-16 | 32-64 | 16-32 | 2 | 1-2 | 8K-16K |
Longer sequences (256K) employed tensor parallelism (TP=2) in addition to context parallelism to fit the model within GPU memory constraints.
Communication backend comparison
Each configuration was evaluated with two communication backends:
- NCCL (baseline)
- NVSHMEM-enabled implementation
Measurements:
- TFLOP/s per device: GPU computational throughput
- Step time (seconds): Time per training iteration
- Speedup: Relative performance improvement of NVSHMEM over NCCL
All metrics were averaged across iterations 3-20 (skipping the first two warmup iterations) and computed from rank 0 logs to ensure consistency.
Performance results
As shown in Table 2, the NVSHMEM performance advantage grows significantly with sequence length:
- 64K sequences: 0.3-3.9% speedup (modest improvement)
- 128K sequences: 0.7-2.4% speedup (consistent improvement)
- 256K sequences: 30.4-36.3% speedup (dramatic improvement)
This scaling behavior aligns with the ring attention communication pattern: longer sequences require more KV tensor exchanges around the ring, amplifying the benefits of NVSHMEM's lower-latency communication.
When scaling across nodes, internode communication latency becomes more critical. NVSHMEM's nonblocking host APIs and optimized data paths provide consistent benefits across 8- to 16-node deployments.
| Sequence length | Nodes | CP | TP | GPUs | Seq/GPU | Default (NCCL) TFLOP/s | NVSHMEM TFLOP/s | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 64K | 1 | 4 | 1 | 4 | 16K | 605.64 | 607.36 | +0.3% |
| 64K | 2 | 8 | 1 | 8 | 8K | 549.92 | 557.17 | +1.3% |
| 64K | 4 | 16 | 1 | 16 | 4K | 482.19 | 501.06 | +3.9% |
| 128K | 2 | 8 | 1 | 8 | 16K | 512.22 | 515.87 | +0.7% |
| 128K | 4 | 16 | 1 | 16 | 8K | 473.58 | 472.46 | -0.2% |
| 128K | 8 | 32 | 1 | 32 | 4K | 420.99 | 431.13 | +2.4% |
| 256K | 8 | 16 | 2 | 32 | 16K | 366.94 | 500.22 | +36.3% |
| 256K | 16 | 32 | 2 | 64 | 8K | 346.33 | 451.70 | +30.4% |
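For reference, the speedup column corresponds to the ratio of the two throughput columns; for example, for the first 256K row:
# Speedup derived from per-device throughput (256K / 8-node row above)
nccl_tflops, nvshmem_tflops = 366.94, 500.22
speedup = nvshmem_tflops / nccl_tflops - 1.0
print(f"{speedup:+.1%}")  # prints +36.3%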
Practical implications
Based on these results, NVSHMEM provides clear benefits for:
- Long-context training: Sequences ≥ 128K tokens where communication becomes a bottleneck
- Multinode deployments: Scaling beyond single-node NVLink domains
- Ring attention and similar patterns: Workloads with fine-grained, latency-sensitive communication
- Hybrid parallelism: Configurations combining CP, TP, and FSDP
The XLA integration makes NVSHMEM accessible to JAX. No user code changes are required; simply use an NVSHMEM-enabled XLA build and set the appropriate environment flags.
Get started with long-context model training
Training LLMs with long context windows requires efficient communication strategies that can handle fine-grained, latency-sensitive data exchanges. The integration of NVSHMEM into XLA enables transparent acceleration of context parallelism with ring attention, providing up to a 36% speedup for 256K-token sequences on Llama 3 8B.
Key takeaways:
- The NVSHMEM nonblocking host APIs and low-latency data paths are ideally suited to the ring attention communication pattern
- XLA compiler integration makes NVSHMEM accessible to high-level frameworks without requiring code changes
- Performance advantages scale with sequence length, with dramatic improvements for sequences ≥ 256K tokens
- Multinode deployments see the biggest gains, making NVSHMEM essential for production long-context training
As context windows continue to grow, low-latency communication solutions like NVSHMEM will be crucial for making long-context training practical and cost-effective. We encourage the community to try NVSHMEM-enabled XLA builds in the JAX framework and share their experiences with long-context workloads.
To get started, check out the MaxText framework, NVIDIA/JAX-Toolbox, and openxla/xla on GitHub.
Acknowledgments
We would like to express our gratitude to NVSHMEM contributors Seth Howell and Akhil Langer.
