Implementation:Google deepmind Mujoco jax jit for MJX
| Knowledge Sources | |
|---|---|
| Domains | GPU_Computing, Compilation, JAX |
| Last Updated | 2026-02-15 06:00 GMT |
Overview
Wrapper documentation for using JAX JIT compilation with MJX physics functions for GPU-accelerated simulation.
Description
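jax.jit traces a Python function and compiles it into an XLA program. Compilation happens on the first call for each new combination of argument shapes and dtypes, and the cached executable is reused on later calls. In MJX, jax.jit is used to compile mjx.step for efficient GPU execution. The MJX viewer example shows the recommended pattern: donate_argnums=(1,) donates the buffers of the Data argument (the second argument of mjx.step) so XLA can reuse them for the output, and .lower().compile() performs ahead-of-time (AOT) compilation.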
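To make the "compile once, then reuse" behavior concrete, here is a minimal sketch that uses only JAX. `toy_step` is a hypothetical stand-in for `mjx.step` (it is not part of MJX). The Python list inside it is appended to only while JAX traces the function, so it records when a new compilation is triggered.

```python
import jax
import jax.numpy as jnp

# Python side effects run only during tracing, not when the cached executable runs.
traces = []

def toy_step(params, state):
    """Hypothetical stand-in for mjx.step(m, d): one explicit Euler update."""
    traces.append(state[0].shape)
    dt, gravity = params
    pos, vel = state
    vel = vel + dt * gravity
    pos = pos + dt * vel
    return pos, vel

step = jax.jit(toy_step)

params = (jnp.float32(0.01), jnp.float32(-9.81))
state = (jnp.zeros(3), jnp.zeros(3))

for _ in range(5):
    state = step(params, state)  # traced and compiled on the first call only

# A new input shape triggers a fresh trace and compilation.
state4 = step(params, (jnp.zeros(4), jnp.zeros(4)))
```

After this runs, `traces` holds two entries: one for shape `(3,)` and one for `(4,)`.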
Usage
Wrap mjx.step with jax.jit before the simulation loop. Use ahead-of-time compilation (.lower().compile()) so that tracing and compilation are paid for once, before the loop, instead of during the first simulation step.
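A minimal sketch of compiling ahead of time before a loop, assuming a hypothetical `toy_step` in place of `mjx.step`. The `jax.stages.Compiled` object returned by `.compile()` is tied to the shapes and dtypes it was lowered with, so a mismatched call is rejected rather than silently recompiled.

```python
import jax
import jax.numpy as jnp

def toy_step(params, state):
    """Hypothetical stand-in for mjx.step: damped velocity update."""
    damping, dt = params
    pos, vel = state
    vel = vel * (1.0 - damping * dt)
    return pos + dt * vel, vel

params = (jnp.float32(0.5), jnp.float32(0.01))
state = (jnp.zeros(3), jnp.ones(3))

# Trace, lower and compile before the loop; no loop iteration pays compile cost.
compiled_step = jax.jit(toy_step).lower(params, state).compile()

for _ in range(100):
    state = compiled_step(params, state)

# The executable is specialized to shape (3,); a different shape raises TypeError.
try:
    compiled_step(params, (jnp.zeros(4), jnp.ones(4)))
    shape_mismatch_rejected = False
except TypeError:
    shape_mismatch_rejected = True
```

With MJX, this is why `.lower(mx, dx)` should be called with the same model and data that the loop will pass in.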
External Reference
Code Reference
Source Location
- Repository: External (jax library); usage in mujoco
- File: mjx/mujoco/mjx/viewer.py
- Lines: 125
Signature
```python
# JAX library function (selected parameters)
jax.jit(fun, donate_argnums=None, keep_unused=False, ...) -> jitted callable
# MJX usage pattern
step_fn = jax.jit(mjx.step, donate_argnums=(1,), keep_unused=True)
step_fn = step_fn.lower(mx, dx).compile()  # AOT compile; returns jax.stages.Compiled
```
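The chained calls in the MJX pattern pass through JAX's staged-compilation objects: `.lower()` returns a `jax.stages.Lowered` and `.compile()` returns a `jax.stages.Compiled`. A minimal sketch with a hypothetical `toy_step` that only mirrors the `(model, data)` argument order of `mjx.step`:

```python
import jax
import jax.numpy as jnp

def toy_step(model_scale, data):
    """Hypothetical stand-in with mjx.step's (model, data) argument order."""
    return data * model_scale

x = jnp.arange(4.0)
jitted = jax.jit(toy_step)

lowered = jitted.lower(jnp.float32(2.0), x)  # traced and lowered, not yet compiled
hlo_text = lowered.as_text()                 # StableHLO module as a string
compiled = lowered.compile()                 # XLA executable for these shapes/dtypes

y = compiled(jnp.float32(2.0), x)
```

Inspecting `hlo_text` can help confirm what XLA will actually run before paying the compile cost.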
Import
```python
import jax
import mujoco.mjx as mjx
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
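| fun | Callable | Yes | Function to compile (e.g., mjx.step) |
| donate_argnums | int or tuple of int | No | Indices of positional arguments whose buffers XLA may reuse for outputs; a donated array must not be used after the call |
| keep_unused | bool | No | If False (default), arguments that fun does not use may be pruned from the compiled executable; if True, they are kept |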
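The two optional parameters can be sketched with a hypothetical `toy_step(model, data)` (not an MJX function), where `data` is donated and one model entry is intentionally unused:

```python
import jax
import jax.numpy as jnp

def toy_step(model, data):
    """Hypothetical stand-in: `model` holds constants, `data` is the evolving state.
    The second model entry is deliberately unused."""
    dt, _unused_damping = model
    return data + dt

model = (jnp.float32(0.1), jnp.float32(0.9))
data = jnp.zeros(1000)

# donate_argnums=(1,): XLA may write the output into the buffer of `data`.
# keep_unused=True: the unused model entry is not pruned from the executable.
# On backends without donation support, JAX warns and ignores the donation.
step = jax.jit(toy_step, donate_argnums=(1,), keep_unused=True)

for _ in range(10):
    # Rebind the name: after a donating call the old `data` array may be
    # deleted and must not be read again.
    data = step(model, data)
```

The rebinding idiom `data = step(model, data)` is the same one the MJX loop uses with `dx = step_fn(mx, dx)`.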
Outputs
| Name | Type | Description |
|---|---|---|
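| return | jitted callable; jax.stages.Compiled after .lower(...).compile() | Callable with the same signature as fun. The jitted callable compiles lazily per input shape/dtype; the Compiled object is specialized to the shapes and dtypes it was lowered with |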
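The compiled function returns the same pytree type that `fun` returns, which is what lets `dx = step_fn(mx, dx)` feed its output straight back in. A sketch with a hypothetical `ToyData` NamedTuple standing in for `mjx.Data` (which is also a JAX pytree):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyData(NamedTuple):
    """Hypothetical stand-in for mjx.Data."""
    time: jax.Array
    qpos: jax.Array

def toy_step(dt, d):
    return ToyData(time=d.time + dt, qpos=d.qpos + dt)

d0 = ToyData(time=jnp.float32(0.0), qpos=jnp.zeros(2))
compiled = jax.jit(toy_step).lower(jnp.float32(0.01), d0).compile()
d1 = compiled(jnp.float32(0.01), d0)  # same pytree structure as d0
```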
Usage Examples
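```python
import jax
import mujoco
import mujoco.mjx as mjx
# Small illustrative model: one free-floating sphere
m = mujoco.MjModel.from_xml_string(
    '<mujoco><worldbody><body><freejoint/><geom size=".1"/></body></worldbody></mujoco>')
mx = mjx.put_model(m)
dx = mjx.make_data(mx)
# JIT compile with buffer donation (argument 1 is the mjx.Data)
step_fn = jax.jit(mjx.step, donate_argnums=(1,), keep_unused=True)
# Ahead-of-time compilation (recommended)
step_fn = step_fn.lower(mx, dx).compile()
# Run compiled step; rebind dx each time because the donated input must not be reused
for _ in range(1000):
    dx = step_fn(mx, dx)
```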
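Because JAX dispatches work asynchronously, timing a loop like the one above requires waiting for the final result before stopping the clock. A minimal timing sketch, assuming a hypothetical `toy_step` in place of `mjx.step` so it runs without MuJoCo installed:

```python
import time

import jax
import jax.numpy as jnp

def toy_step(dt, qpos):
    """Hypothetical stand-in for mjx.step."""
    return qpos + dt * jnp.sin(qpos)

dt = jnp.float32(0.01)
qpos = jnp.linspace(0.0, 1.0, 1024)

t0 = time.perf_counter()
step_fn = jax.jit(toy_step, donate_argnums=(1,)).lower(dt, qpos).compile()
compile_seconds = time.perf_counter() - t0

t0 = time.perf_counter()
for _ in range(1000):
    qpos = step_fn(dt, qpos)
# Calls return before the device finishes; block on the last result.
qpos.block_until_ready()
loop_seconds = time.perf_counter() - t0
```

For MJX, the equivalent wait would be on an array field of the final state, e.g. `dx.qpos.block_until_ready()`.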
Related Pages
Implements Principle
Requires Environment
Uses Heuristic