Implementation:Google deepmind Mujoco JAX Experimental Custom Call
| Knowledge Sources | |
|---|---|
| Domains | Foreign Function Interface, JAX Integration, CUDA, Custom Call API |
| Last Updated | 2026-02-15 04:00 GMT |
Overview
Legacy custom call implementation that wraps NVIDIA Warp kernels as JAX primitives using the XLA custom call API for JAX versions 0.4.25 through 0.7.x.
Description
custom_call.py provides the deprecated jax_kernel() function that creates a JAX primitive from a Warp kernel using XLA's custom call mechanism. It maintains a global registry of Warp kernels (_registered_kernels, _registered_kernel_to_id) and a custom callback (_cc_callback) to dispatch kernel launches from within JAX's execution model. This implementation requires all kernel arguments to be contiguous CUDA arrays with input arguments preceding output arguments in the kernel definition. It has been superseded by the FFI-based implementation for JAX >= 0.8.0.
Usage
This module is used when running MJX with JAX versions 0.4.25 through 0.7.x. On JAX >= 0.5.0 it emits a deprecation warning directing users to warp.jax_experimental.ffi.jax_kernel instead, unless quiet=True is passed. It is not compatible with JAX >= 0.8.0.
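The version ranges above can be summarized as a small decision function. This is a sketch under stated assumptions, not the module's actual check: `parse_version`, `custom_call_status`, and the returned labels are hypothetical, and treating versions below 0.4.25 as "too_old" is an inference from the stated lower bound.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Turn '0.7.2' or '0.7.2rc1' into (0, 7, 2); missing parts become 0."""
    numbers = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    while len(numbers) < 3:
        numbers.append(0)
    return (numbers[0], numbers[1], numbers[2])


def custom_call_status(jax_version: str) -> str:
    """Classify a JAX version against the ranges documented for this module."""
    version = parse_version(jax_version)
    if version < (0, 4, 25):
        return "too_old"        # assumption: below the documented lower bound
    if version >= (0, 8, 0):
        return "unsupported"    # the FFI-based implementation is required
    if version >= (0, 5, 0):
        return "deprecated"     # works, but a deprecation warning is emitted
    return "supported"
```

For example, `custom_call_status("0.7.2")` falls in the "deprecated" range, while `custom_call_status("0.8.0")` is "unsupported".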
Code Reference
Source Location
- Repository: Google_deepmind_Mujoco
- File: mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/custom_call.py
- Lines: 1-396
Key Functions
def jax_kernel(kernel, launch_dims=None, quiet=False):
    """Create a Jax primitive from a Warp kernel.

    Deprecated for JAX >= 0.5.0; not supported with JAX >= 0.8.0.
    Use warp.jax_experimental.ffi.jax_kernel instead.

    Args:
        kernel: The Warp kernel to be wrapped.
        launch_dims: Kernel launch dimensions (None = infer from first arg shape).
        quiet: If True, suppress deprecation warnings.
    """
Import
from mujoco.mjx.third_party.warp._src.jax_experimental.custom_call import jax_kernel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| kernel | wp.Kernel | Yes | The compiled Warp kernel to wrap as a JAX primitive |
| launch_dims | tuple or None | No | Explicit kernel launch dimensions; if None, inferred from first argument shape |
| quiet | bool | No | Suppress deprecation warnings when True (default: False) |
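The launch_dims fallback in the table can be sketched as follows. The helper name `resolve_launch_dims` and its signature are hypothetical; the sketch only mirrors the documented rule that an explicit value wins and that None falls back to the shape of the first argument.

```python
from typing import Optional, Sequence, Tuple


def resolve_launch_dims(
    launch_dims: Optional[Sequence[int]],
    first_arg_shape: Sequence[int],
) -> Tuple[int, ...]:
    """Pick explicit launch dimensions, or infer them from the first argument's shape."""
    if launch_dims is None:
        return tuple(first_arg_shape)
    return tuple(launch_dims)
```

With a first argument of shape (64, 3), `resolve_launch_dims(None, (64, 3))` yields a 64 x 3 launch grid, while `resolve_launch_dims((128,), (64, 3))` keeps the explicit one-dimensional launch.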
Outputs
| Name | Type | Description |
|---|---|---|
| jax_primitive | Callable | A JAX-callable function that dispatches to the Warp kernel via XLA custom call |
Related Pages
- Google_deepmind_Mujoco_JAX_Experimental_FFI - Newer FFI-based replacement for this module
- Google_deepmind_Mujoco_JAX_Experimental_XLA_FFI - XLA FFI ctypes bindings used by the newer implementation
- Google_deepmind_Mujoco_MJX_Warp_FFI - Higher-level MJX FFI layer that builds on these primitives