Implementation:Google deepmind Mujoco JAX Experimental Custom Call

From Leeroopedia
Knowledge Sources
Domains: Foreign Function Interface, JAX Integration, CUDA, Custom Call API
Last Updated: 2026-02-15 04:00 GMT

Overview

Legacy custom call implementation that wraps NVIDIA Warp kernels as JAX primitives using the XLA custom call API for JAX versions 0.4.25 through 0.7.x.

Description

custom_call.py provides the deprecated jax_kernel() function, which creates a JAX primitive from a Warp kernel using XLA's custom call mechanism. It maintains a global registry of Warp kernels (_registered_kernels, _registered_kernel_to_id) and a custom callback (_cc_callback) that dispatches kernel launches from within JAX's execution model. This implementation requires all kernel arguments to be contiguous CUDA arrays, and input arguments must precede output arguments in the kernel definition. It has been superseded by the FFI-based implementation (warp.jax_experimental.ffi.jax_kernel), which must be used with JAX >= 0.8.0.
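The registry described above can be pictured as a pair of lookups: one from an integer id to a kernel and one from a kernel back to its id, so that a callback only needs the id to find the kernel to launch. The following is a minimal sketch of that pattern, not the module's actual code. The two container names mirror the ones mentioned above, but their use as dicts, the helper names register_kernel and lookup_kernel, and the use of plain strings in place of Warp kernels are all assumptions made for illustration.

```python
# Illustrative sketch only: a two-way registry so a dispatch callback can
# recover a kernel object from a small integer id.
# The dict-based layout and both helper names are hypothetical.
_registered_kernels = {}        # id -> kernel
_registered_kernel_to_id = {}   # kernel -> id


def register_kernel(kernel):
    """Return a stable integer id for `kernel`, registering it on first use."""
    if kernel in _registered_kernel_to_id:
        return _registered_kernel_to_id[kernel]
    kernel_id = len(_registered_kernels)
    _registered_kernels[kernel_id] = kernel
    _registered_kernel_to_id[kernel] = kernel_id
    return kernel_id


def lookup_kernel(kernel_id):
    """Resolve an id back to the kernel, as a dispatch callback would."""
    return _registered_kernels[kernel_id]
```

Registering the same kernel twice returns the same id, which keeps the id embedded in a compiled computation valid across repeated calls.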

Usage

This module is used when running MJX with JAX versions 0.4.25 through 0.7.x. Within that range, on JAX >= 0.5.0 it emits deprecation warnings (unless quiet=True) directing users to warp.jax_experimental.ffi.jax_kernel instead. It is not compatible with JAX >= 0.8.0.
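The version ranges above amount to a three-way classification. The sketch below is one way to express it, assuming the installed version is available as a tuple of integers such as (0, 6, 2); the function name classify_jax_version and the returned labels are hypothetical and are not part of the module.

```python
def classify_jax_version(version_info):
    """Classify a JAX version tuple for the legacy custom call path.

    Returns "unsupported", "supported", or "deprecated".
    """
    version = tuple(version_info[:3])
    # Outside 0.4.25 .. 0.7.x the legacy implementation cannot be used.
    if version < (0, 4, 25) or version >= (0, 8, 0):
        return "unsupported"
    # From 0.5.0 onward it still works but warns about deprecation.
    if version >= (0, 5, 0):
        return "deprecated"
    return "supported"
```

Comparing tuples rather than version strings avoids lexicographic mistakes such as treating "0.4.9" as newer than "0.4.25".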

Code Reference

Source Location

Key Functions

def jax_kernel(kernel, launch_dims=None, quiet=False):
    """Create a Jax primitive from a Warp kernel.

    Deprecated for JAX >= 0.5.0; not supported with JAX >= 0.8.0.
    Use warp.jax_experimental.ffi.jax_kernel instead.

    Args:
        kernel: The Warp kernel to be wrapped.
        launch_dims: Kernel launch dimensions (None = infer from first arg shape).
        quiet: If True, suppress deprecation warnings.
    """

Import

from mujoco.mjx.third_party.warp._src.jax_experimental.custom_call import jax_kernel

I/O Contract

Inputs

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| kernel | wp.Kernel | Yes | The compiled Warp kernel to wrap as a JAX primitive |
| launch_dims | tuple or None | No | Explicit kernel launch dimensions; if None, inferred from the shape of the first argument |
| quiet | bool | No | Suppress deprecation warnings when True (default: False) |
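The launch_dims fallback in the table can be sketched as a small helper. This is an illustration of the documented rule only, under the assumption that the first argument's shape is available as a sequence of integers; resolve_launch_dims is a hypothetical name and does not appear in the module.

```python
def resolve_launch_dims(launch_dims, first_arg_shape):
    """Use explicit launch dimensions, else fall back to the first argument's shape."""
    if launch_dims is None:
        # Documented default: infer dimensions from the first argument.
        return tuple(first_arg_shape)
    return tuple(launch_dims)
```

For example, a first argument of shape (64, 3) with launch_dims=None gives a launch over (64, 3), while launch_dims=(128,) overrides the shape entirely.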

Outputs

| Name | Type | Description |
| --- | --- | --- |
| jax_primitive | Callable | A JAX-callable function that dispatches to the Warp kernel via XLA custom call |
