Warp v1.10 expands JAX integration with automatic differentiation support and multi-device jax.pmap() compatibility. The tile programming model has been enhanced with axis-specific reductions, component-level indexing, and convenience functions for creating tiles.
Performance has been significantly improved in several areas: BVH operations now support in-place rebuilding for CUDA graphs and configurable leaf sizes, built-in function calls from Python are up to 70× faster, and additional sparse matrix and FEM operations can now be captured in CUDA graphs.
Additional usability improvements include negative indexing and slicing for arrays, atomic bitwise operations, and new built-in functions including error functions and type casting.
Important: This release removes the warp.sim module (deprecated since v1.8), which has been superseded by the Newton physics engine. See the Announcements section below for migration guidance and other upcoming changes.
For a complete list of changes, see the full changelog.
New features
JAX automatic differentiation (experimental)
Warp now supports experimental automatic differentiation with JAX, allowing kernels to take part in JAX automatic differentiation workflows. This feature is contributed by @mehdiataei and builds on earlier work by @jaro-sevcik. It enables computing gradients through Warp kernels using jax.grad() by passing enable_backward=True to jax_kernel().
Key capabilities include:
- Single and multiple output kernels: Compute gradients for kernels with one or more output arrays
- Static input auto-detection: Scalar inputs are automatically treated as static (non-differentiable) arguments
- Vector and matrix arrays: Arrays of composite types like wp.vec2 or wp.mat22 are fully supported
- Multi-device execution: Compatible with jax.pmap() for distributed forward and backward passes across multiple GPUs
import jax
import warp as wp
from warp.jax_experimental import jax_kernel

@wp.kernel
def my_kernel(a: wp.array(dtype=float), out: wp.array(dtype=float)):
    i = wp.tid()
    out[i] = a[i] ** 2.0

# Enable automatic differentiation
jax_func = jax_kernel(my_kernel, num_outputs=1, enable_backward=True)

# Compute gradients through the kernel
input_array = jax.numpy.arange(4, dtype=jax.numpy.float32)  # example input
grad_fn = jax.grad(lambda a: jax.numpy.sum(jax_func(a)[0]))
gradient = grad_fn(input_array)  # gradient: [2*a[0], 2*a[1], ...]
This feature is experimental and has some current limitations. See the JAX Automatic Differentiation documentation for complete examples, usage details, and limitations.
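One way to sanity-check a gradient such as the one above is to compare it against a finite-difference estimate. The sketch below is plain Python, not Warp or JAX code: it applies a central difference to the same sum-of-squares reduction and compares the result with the analytic gradient 2*a. The helper names `loss` and `finite_difference_grad` are made up for illustration.

```python
def loss(a):
    """Same reduction as the example kernel: sum of squared elements."""
    return sum(x ** 2.0 for x in a)


def finite_difference_grad(f, a, eps=1e-6):
    """Central-difference estimate of the gradient of f at the point a."""
    grad = []
    for i in range(len(a)):
        upper = list(a)
        lower = list(a)
        upper[i] += eps
        lower[i] -= eps
        grad.append((f(upper) - f(lower)) / (2.0 * eps))
    return grad


a = [1.0, -2.0, 3.0]
numeric = finite_difference_grad(loss, a)
analytic = [2.0 * x for x in a]  # d/da sum(a**2) = 2*a
print(numeric)
print(analytic)
```

The same comparison could be applied to the array returned by grad_fn, with a tolerance suited to single-precision arithmetic.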
Multi-device JAX support with jax.pmap()
Warp now properly supports jax.pmap() and jax.shard_map() for multi-device parallel execution, thanks to fixes contributed by @chaserileyroberts. Previously, device targeting issues prevented Warp callables from working correctly inside these JAX primitives: JAX would invoke callbacks from multiple threads targeting different devices, but Warp would always execute on the default device. The fix ensures proper device coordination by extracting device ordinals from XLA FFI and adding thread synchronization for concurrent callbacks, enabling efficient data-parallel workflows across multiple GPUs.
In-place BVH rebuilding with CUDA graph support
A new wp.Bvh.rebuild() method enables rebuilding BVH hierarchies in-place without allocating new memory. This complements the existing refit() method and is especially useful when primitive distributions change significantly.
CUDA graph capture: Unlike creating a new BVH, rebuild() reuses existing buffers, making it safe to capture in CUDA graphs. Previously captured graphs that include queries on the BVH remain valid after rebuilding, enabling high-performance repeated updates without graph re-capture overhead.
Construction algorithms: On CUDA devices, in-place rebuild supports "lbvh" only. On CPU, "sah" and "median" are supported. Defaults are chosen automatically based on the device.
Tile programming enhancements
The tile programming model has been enhanced with new capabilities to make tile-based computations more expressive and convenient:
Axis-specific reductions
The tile-reduction functions wp.tile_reduce() and wp.tile_sum() now support an optional axis parameter, enabling reductions along a specific dimension of a tile rather than reducing the entire tile to a single value. This enhancement brings NumPy-like axis semantics to tile operations.
import numpy as np
import warp as wp

@wp.kernel
def tile_reduce_axis(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
    a = wp.tile_load(x, shape=(4, 8), storage="shared")
    # Sum along axis 0, reducing shape from (4, 8) to (8,)
    b = wp.tile_sum(a, axis=0)
    wp.tile_store(y, b)

x = wp.array(np.arange(32).reshape(4, 8), dtype=float)
# x = [[ 0.  1.  2.  3.  4.  5.  6.  7.]
#      [ 8.  9. 10. 11. 12. 13. 14. 15.]
#      [16. 17. 18. 19. 20. 21. 22. 23.]
#      [24. 25. 26. 27. 28. 29. 30. 31.]]
y = wp.zeros(8, dtype=float)
wp.launch_tiled(tile_reduce_axis, dim=(1,), inputs=[x], outputs=[y], block_dim=32)
# y = [48. 52. 56. 60. 64. 68. 72. 76.] (column sums)
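As a plain-Python reference for what the axis argument means, the sketch below reduces the same 4×8 data along each axis. It is not Warp code: nested lists stand in for a tile, and the helper name `sum_along_axis` is made up for illustration.

```python
def sum_along_axis(rows, axis):
    """Sum a rectangular list of lists along axis 0 or axis 1."""
    if axis == 0:
        # Collapse the row dimension: one sum per column.
        return [sum(col) for col in zip(*rows)]
    if axis == 1:
        # Collapse the column dimension: one sum per row.
        return [sum(row) for row in rows]
    raise ValueError("axis must be 0 or 1 for a 2D input")


x = [[float(r * 8 + c) for c in range(8)] for r in range(4)]

col_sums = sum_along_axis(x, axis=0)  # shape (4, 8) -> (8,)
row_sums = sum_along_axis(x, axis=1)  # shape (4, 8) -> (4,)
print(col_sums)
print(row_sums)
```

Reducing along axis 0 removes the first dimension, which is why the kernel above stores its result into an array of length 8.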
Component-level indexing
Tiles of composite types (vectors, matrices, quaternions) now support component-level indexing and assignment. You can directly index into individual components using extended indexing syntax:
- Vector components: tile[i][1] extracts the second component of a vector at position i
- Matrix elements: tile[i][1, 1] accesses the element at row 1, column 1 of a matrix at position i
This provides more convenient and expressive syntax for working with structured data in tiles.
Creating tiles filled with a constant value
The new wp.tile_full() function provides a convenient way to create tiles initialized with a constant value, similar to NumPy’s np.full():
# Create an 8x8 tile filled with 3.14
tile = wp.tile_full(shape=(8, 8), value=3.14, dtype=float)
New example
The new example_tile_mcgp.py example demonstrates tile-based Monte Carlo methods by implementing a walk-on-spheres algorithm for solving Laplace’s equation on volumetric domains.
Performance improvements
Built-in function calls from Python
Calling Warp built-in functions from Python scope (e.g., wp.normalize(), wp.transform_identity(), matrix arithmetic like mat * mat) is now significantly faster due to optimizations in overload resolution. Previously, each function call would iterate through all overloads, attempt argument binding, and pack parameters into C types until finding a match. Now, Warp caches the resolved overload and parameter packing strategy based on argument types using @functools.lru_cache, eliminating redundant resolution overhead on subsequent calls.
In microbenchmarks, repeated wp.mat44 multiplication at Python scope is up to 70× faster (~570 μs → ~8 μs), while operations like wp.transform_identity() see 3-4× speedups (~100 μs → ~30 μs). The magnitude of improvement varies by operation complexity, with greater gains for operations requiring more expensive overload resolution.
Breaking change: As part of this optimization, support for passing lists, tuples, and other non-Warp array arguments to built-in functions has been removed. Calls like wp.normalize([1.0, 2.0, 3.0]) must now be written as wp.normalize(wp.vec3(1.0, 2.0, 3.0)). This simplifies the function call path and removes expensive sequence-flattening logic that was incompatible with efficient caching.
Configurable BVH leaf size
wp.Bvh and wp.Mesh now expose tunable leaf_size and bvh_leaf_size parameters, respectively, allowing users to control the number of primitives stored in each leaf node for performance optimization. The optimal leaf size depends on the query workload:
- Intersection queries (ray casting, AABB overlap): Smaller leaf sizes (e.g., 1) are generally optimal, reducing unnecessary primitive checks
- Closest point queries: Larger leaf sizes (e.g., 4-8) can improve performance by checking more primitives together and reducing traversal overhead
- Mixed workloads: Moderate values (e.g., 4) provide a balanced trade-off
Behavior change: The default leaf_size for wp.Bvh has changed from 4 (hardcoded) to 1, optimizing for intersection queries, which are more common. wp.Mesh retains a default bvh_leaf_size of 4 as a compromise between intersection and closest-point query performance. Users performing primarily closest-point queries may benefit from explicitly setting larger leaf sizes.
Sparse matrix operations with CUDA graphs
Sparse matrix operations in warp.sparse can now be captured in CUDA graphs for allocation-free execution. Operations like bsr_axpy(), bsr_assign(), and bsr_set_transpose() preserve matrix topology when using masked=True, while bsr_mm() adds a new max_new_nnz parameter that allows specifying an upper bound on new non-zero blocks for flexible graph capture when sparsity patterns vary within known bounds.
FEM operations with CUDA graphs
Constructing warp.fem geometry and function space partitions can now be captured in CUDA graphs by specifying upper bounds on partition sizes: max_cell_count and max_side_count for ExplicitGeometryPartition, and max_node_count for make_space_partition(). Additionally, constructing fields and restrictions is now synchronization-free by default.
Language enhancements
Array indexing and slicing improvements
Warp arrays now support negative indexing and improved slicing behavior, making array manipulation more intuitive and consistent with NumPy conventions.
Negative indexing: Access elements from the end of an array using negative indices:
@wp.kernel
def use_negative_indexing(arr: wp.array(dtype=float)):
    last = arr[-1]  # Last element
    second_last = arr[-2]  # Second-to-last element
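Under the NumPy convention, a negative index i refers to position length + i. The plain-Python sketch below spells that mapping out. It illustrates the convention only and does not describe how Warp implements it; the helper name `normalize_index` and the explicit bounds check are additions for illustration.

```python
def normalize_index(i, length):
    """Map a possibly negative index onto the range [0, length)."""
    j = i + length if i < 0 else i
    if not 0 <= j < length:
        raise IndexError(f"index {i} out of range for length {length}")
    return j


values = [10.0, 20.0, 30.0, 40.0]
last = values[normalize_index(-1, len(values))]
second_last = values[normalize_index(-2, len(values))]
print(last, second_last)
```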
Enhanced array slicing: Arrays now support more flexible slicing operations inside kernels, including stride-based access patterns. This works with both regular arrays and tile operations:
import numpy as np
import warp as wp

@wp.kernel
def tile_load_strided(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    # Load every other element from a 16x16 region into an 8x8 tile
    tile = wp.tile_load(input[::2, ::2], shape=(8, 8))
    wp.tile_store(output, tile)

input = wp.array(np.arange(256).reshape(16, 16), dtype=float)
output = wp.zeros((8, 8), dtype=float)
wp.launch_tiled(tile_load_strided, dim=(1,), inputs=[input, output], block_dim=32)
# output contains every other element from input:
# [[  0.   2.   4.   6.   8.  10.  12.  14.]
#  [ 32.  34.  36.  38.  40.  42.  44.  46.]
#  [ 64.  66.  68.  70.  72.  74.  76.  78.]
#  ...
#  [224. 226. 228. 230. 232. 234. 236. 238.]]
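For reference, the same stride-2 selection can be reproduced in plain Python. In this sketch, nested lists stand in for the 16×16 array and list slicing plays the role of input[::2, ::2]. It shows the expected element pattern and is not Warp code.

```python
# Nested lists stand in for a 16x16 array holding 0..255 in row-major order.
grid = [[float(r * 16 + c) for c in range(16)] for r in range(16)]

# Equivalent of input[::2, ::2]: every other row, then every other column.
strided = [row[::2] for row in grid[::2]]

print(len(strided), len(strided[0]))
print(strided[0])
print(strided[-1])
```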
New built-in functions
- Error functions: Added wp.erf(), wp.erfc(), wp.erfinv(), and wp.erfcinv() for error function computations
- Type casting: Added wp.cast() to reinterpret values as different types while preserving bit patterns (e.g., reinterpreting float bits as int)
- Atomic bitwise operations: Added wp.atomic_and(), wp.atomic_or(), and wp.atomic_xor() for thread-safe bitwise operations on integers, contributed by @j3soon
- Sparse matrix utilities: Added wp.sparse.bsr_row_index() and wp.sparse.bsr_block_index() as kernel-level functions to efficiently determine which row a given block belongs to without manually searching the compressed offset array
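Two of the ideas in this list can be illustrated in plain Python. The sketch below is not Warp code and both helper names are made up. `float_bits_as_int` shows what reinterpreting a 32-bit float's bit pattern as an integer means. `block_row` shows the kind of manual search over a compressed offset array that the sparse utilities are meant to replace; using binary search here is a choice made for this example.

```python
import bisect
import struct


def float_bits_as_int(x):
    """Reinterpret the 32-bit pattern of a float as a signed 32-bit int."""
    return struct.unpack("<i", struct.pack("<f", x))[0]


def block_row(offsets, block):
    """Find the row that owns `block` in a compressed row-offset array."""
    # Row r owns blocks offsets[r] .. offsets[r + 1] - 1.
    if not 0 <= block < offsets[-1]:
        raise IndexError("block index out of range")
    return bisect.bisect_right(offsets, block) - 1


bits = float_bits_as_int(1.0)
print(hex(bits))

# Three rows holding 2, 0, and 3 blocks respectively.
offsets = [0, 2, 2, 5]
rows = [block_row(offsets, b) for b in range(5)]
print(rows)
```

A value conversion would turn 1.0 into the integer 1, whereas the reinterpretation leaves the bits untouched and yields the IEEE 754 encoding of 1.0.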
Bug fixes
AArch64 CPU execution with tiles
Fixed segmentation faults when running tile-based kernels on AArch64 CPUs, affecting platforms including NVIDIA Jetson (Thor, Orin), DGX Spark, Grace Hopper, and Grace Blackwell systems. The fix uses stack memory allocation instead of static memory to work around limitations in LLVM’s JIT compiler.
This change is enabled by default on all CPU architectures and can be disabled if needed via wp.config.enable_tiles_in_stack_memory = False. If you encounter issues that are resolved by disabling this setting, please report them on our GitHub Issues page.
Note: This primarily affects CPU execution of tile operations, which is less common in Warp workflows but useful for debugging or scenarios where GPU memory transfer overhead outweighs compute advantages.
Native library version verification
Warp now performs runtime version checking to detect mismatches between the Python package and native libraries (e.g., warp.dll, warp.so). This helps diagnose issues where multiple Warp installations on the same system may cause the wrong native libraries to be loaded. When a mismatch is detected, a warning is issued but execution continues. If you see such warnings, make sure you’re loading Warp from the expected installation location and that your environment doesn’t have conflicting Warp versions.
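The warn-but-continue behavior described above can be sketched in a few lines. This is a generic illustration rather than Warp's actual check: the function name `check_native_version`, the version strings, and the choice of RuntimeWarning are all assumptions.

```python
import warnings


def check_native_version(python_version, native_version):
    """Warn, but do not raise, when the two version strings differ."""
    if python_version != native_version:
        warnings.warn(
            f"Python package version {python_version} does not match "
            f"native library version {native_version}",
            RuntimeWarning,
            stacklevel=2,
        )
        return False
    return True


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ok_match = check_native_version("1.10.0", "1.10.0")
    ok_mismatch = check_native_version("1.10.0", "1.9.1")

print(ok_match, ok_mismatch, len(caught))
```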
Announcements
Removal of warp.sim module
The warp.sim module has been removed in this release. This module was formally deprecated in Warp v1.8 (July 2025) and has been superseded by the Newton physics engine, an independent package managed as a Linux Foundation project with a redesigned API focused on robotics and robot learning.
Migration: Users relying on warp.sim should migrate to Newton. For guidance on transitioning from warp.sim to Newton, please consult the Newton migration guide. The original deprecation announcement and community discussion can be found in GitHub Discussion #735.
Questions and discussions about Newton should be directed to the Newton Discussions section. Existing issues in the Warp repository concerning warp.sim will be closed.
JAX FFI is now the default
The default implementation of jax_kernel() is now based on JAX’s Foreign Function Interface (FFI), which is required for JAX version 0.8 and newer. Most users should not need to change their code, because the FFI-based version has been available since Warp 1.7 and provides better performance through CUDA graph capture. The previous custom call implementation is still available as wp.jax_experimental.custom_call.jax_kernel() for users on older JAX versions, but it is deprecated and will not work with JAX version 0.8 or later.
Internal code reorganization: _src folder
As part of ongoing efforts to clarify Warp’s public API surface, internal implementation code has been reorganized into a warp._src subpackage. This change helps distinguish between public APIs that users should depend on and internal implementation details that may change without notice.
What this means for users:
- No immediate breaking changes: All existing imports continue to work. Modules like warp.context, warp.types, and warp.fem remain accessible at their current paths through compatibility shims.
- Visible in stack traces: You may see warp._src paths in error messages and stack traces (e.g., warp._src.context instead of warp.context).
- Future direction: In upcoming releases, we plan to define and formalize the public API surface. Once established, public modules will be updated to re-export all designated public symbols, after which compatibility shims will be removed. Code that imports from internal modules will need to be updated to use public APIs or explicitly import from warp._src.* (acknowledging the use of private APIs).
This reorganization is the first step in a multi-phase effort to establish a stable public API. If you encounter any issues introduced by this reorganization, please report them on our GitHub Issues page.
Upcoming removals
The following features will be removed in v1.11 (planned for January 2026):
- Constructing matrices from row vectors: The ability to construct matrices by passing row vectors to the matrix constructor (e.g., wp.mat22(wp.vec2(1, 2), wp.vec2(3, 4))). Use wp.matrix_from_rows() or wp.matrix_from_cols() instead. This deprecation was originally announced in v1.9 with a planned removal in v1.10, but has been extended by one release cycle. While kernel-scope usage has been emitting deprecation warnings since v1.9, it was discovered that Python-scope usage lacked proper warnings. Starting in v1.10, both contexts emit deprecation warnings.
- graph_compatible parameter in jax_callable(): The boolean graph_compatible parameter has been deprecated in favor of the new graph_mode parameter, which accepts GraphMode enum values (GraphMode.JAX, GraphMode.WARP, or GraphMode.NONE).
Platform support
- Python 3.14: Warp now supports Python 3.14, expanding compatibility beyond the previous maximum of Python 3.13.
- Intel-based macOS (x86_64): Support for Intel Macs has been removed in this release. Apple Silicon Macs (ARM64) continue to be fully supported with CPU execution. Users on Intel-based Macs can continue using Warp 1.9.x or earlier versions.
- Python 3.8: We plan to drop support for Python 3.8 (end-of-life since 2024-10-07) starting with the next minor release (#1019).
Acknowledgments
We also thank the following contributors from outside the core Warp development team:
- @j3soon for adding support for atomic bitwise operations
- @chaserileyroberts for fixing issues with JAX interop on multiple devices
- @mehdiataei and @jaro-sevcik for adding support for JAX automatic differentiation
- @liblaf for improving type annotations for struct() and overload() decorators
- @manuelkNVDA for adding support for multi-process compilation of the core library
- @boomanaiden154 for contributing fixes to handle upcoming removals in LLVM
