Implementation:Google deepmind Mujoco jax jit for MJX

From Leeroopedia
Knowledge Sources
Domains: GPU_Computing, Compilation, JAX
Last Updated: 2026-02-15 06:00 GMT

Overview

Documentation for wrapping MJX physics functions (chiefly mjx.step) with jax.jit, so that they are compiled by XLA for GPU-accelerated simulation.

Description

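jax.jit traces a Python function and compiles it into an XLA program. In MJX, it is used to compile mjx.step for efficient GPU execution. The MJX viewer example shows the recommended pattern: donate_argnums=(1,) donates the second positional argument of mjx.step (the mjx.Data), so XLA may reuse its device buffers for the returned Data, and .lower(...).compile() performs ahead-of-time (AOT) compilation before the simulation loop starts.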
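The same pattern can be sketched without MuJoCo installed. This is a minimal sketch, assuming a hypothetical toy_step(params, state) that mirrors the (model, data) argument order of mjx.step, with hypothetical Params and State NamedTuples standing in for mjx.Model and mjx.Data; it is not MJX code.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    """Hypothetical stand-in for mjx.Model: read-only simulation constants."""
    dt: jax.Array
    gravity: jax.Array


class State(NamedTuple):
    """Hypothetical stand-in for mjx.Data: per-step simulation state."""
    pos: jax.Array
    vel: jax.Array


def toy_step(params: Params, state: State) -> State:
    """Semi-implicit Euler step for a falling point mass, shaped like mjx.step(m, d)."""
    vel = state.vel + params.dt * params.gravity
    pos = state.pos + params.dt * vel
    return State(pos=pos, vel=vel)


params = Params(dt=jnp.float32(0.01), gravity=jnp.float32(-9.81))
state = State(pos=jnp.zeros(3, jnp.float32), vel=jnp.zeros(3, jnp.float32))

# Same pattern as the MJX viewer: donate argument 1 (the state pytree).
step_fn = jax.jit(toy_step, donate_argnums=(1,), keep_unused=True)
step_fn = step_fn.lower(params, state).compile()  # trace + compile once, up front

for _ in range(100):
    # Rebind `state` every step: a donated input must not be used after the call.
    state = step_fn(params, state)

print(state.vel)  # approximately -9.81 after 100 steps of dt = 0.01
```

Because every leaf of the State pytree is donated and the output has identical shapes and dtypes, XLA can write the new state into the old buffers on backends that support donation.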

Usage

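Wrap mjx.step with jax.jit before the simulation loop. Use ahead-of-time compilation (.lower(...).compile()) so that tracing and XLA compilation finish before the loop starts; the first step then no longer pays the compilation cost, which makes its latency predictable.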
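To see where the compilation cost lands, the sketch below counts how often JAX traces a hypothetical step function (not MJX code). The counter is a Python side effect, and Python side effects run only while JAX traces the function, never when a compiled executable runs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented by a Python side effect inside the traced function


def step(params, state):
    global trace_count
    trace_count += 1  # executes only while JAX traces `step`
    return state + params * 0.01


params = jnp.float32(2.0)
state = jnp.zeros(3, dtype=jnp.float32)

jitted = jax.jit(step, donate_argnums=(1,))
compiled = jitted.lower(params, state).compile()  # tracing + XLA compilation happen here
print("traces after AOT compile:", trace_count)  # 1

for _ in range(5):
    state = compiled(params, state)  # runs the prebuilt executable; no retrace
print("traces after 5 steps:", trace_count)  # still 1
```

All of the tracing and compilation work happens at the .lower(...).compile() line, so every loop iteration, including the first, only executes the prebuilt program.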

External Reference

Code Reference

Source Location

  • Repository: External (jax library); usage in mujoco
  • File: mjx/mujoco/mjx/viewer.py
  • Lines: 125

Signature

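# JAX library function (abridged; other keyword arguments omitted)
jax.jit(fun, ..., donate_argnums=None, keep_unused=False, ...) -> jitted function
# jitted.lower(*args) -> jax.stages.Lowered; Lowered.compile() -> jax.stages.Compiled

# MJX usage pattern (MJX viewer)
step_fn = jax.jit(mjx.step, donate_argnums=(1,), keep_unused=True)
step_fn = step_fn.lower(mx, dx).compile()  # AOT compile; step_fn is now a jax.stages.Compiled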

Import

import jax
import mujoco.mjx as mjx

I/O Contract

Inputs

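Name | Type | Required | Description
fun | Callable | Yes | Function to compile (e.g., mjx.step)
donate_argnums | int or sequence of int | No | Positional-argument indices whose buffers XLA may reuse for outputs; a donated input must not be used after the call
keep_unused | bool | No | If True, arguments that fun does not use are kept in the compiled computation instead of being pruned (default False)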

Outputs

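Name | Type | Description
return | jitted function | Same call signature as fun; traces and compiles on the first call for each new combination of argument shapes, dtypes, and pytree structure, then reuses the cached executable
return of .lower(...).compile() | jax.stages.Compiled | Pre-compiled XLA executable; accepts only arguments matching the shapes, dtypes, and pytree structure it was lowered with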
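The difference between the two return values can be illustrated with a hypothetical scale function (not MJX code): the lazily jitted function retraces when it sees a new input shape, while the AOT-compiled object rejects a shape it was not lowered for. This is a minimal sketch; the trace counter again relies on Python side effects running only during tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def scale(state):
    global trace_count
    trace_count += 1  # runs only when JAX traces `scale`
    return state * 2.0


lazy = jax.jit(scale)
lazy(jnp.ones(3))  # first call with shape (3,): trace + compile
lazy(jnp.ones(3))  # same shape and dtype: cached executable, no retrace
lazy(jnp.ones(5))  # new shape: trace + compile again
lazy_traces = trace_count
print("lazy traces:", lazy_traces)  # 2

compiled = jax.jit(scale).lower(jnp.ones(3)).compile()
rejected = False
try:
    compiled(jnp.ones(5))  # shape differs from the one it was lowered with
except TypeError as err:
    rejected = True
    print("Compiled rejected shape (5,):", type(err).__name__)
print("rejected:", rejected)  # True
```

For MJX this means the compiled step_fn must be called with an mjx.Model and mjx.Data of the same structure used in .lower(mx, dx); a model with different array sizes needs a new lowering.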

Usage Examples

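import jax
import mujoco
import mujoco.mjx as mjx

# Minimal illustrative model: a free-falling sphere above a ground plane
XML = """
<mujoco>
  <worldbody>
    <geom type="plane" size="1 1 0.1"/>
    <body pos="0 0 1">
      <freejoint/>
      <geom type="sphere" size="0.1"/>
    </body>
  </worldbody>
</mujoco>
"""

m = mujoco.MjModel.from_xml_string(XML)  # CPU-side MuJoCo model
mx = mjx.put_model(m)                    # copy the model to the default JAX device
dx = mjx.make_data(mx)                   # initial mjx.Data on the device

# JIT compile with buffer donation: argument 1 (the mjx.Data) may be reused for the output
step_fn = jax.jit(mjx.step, donate_argnums=(1,), keep_unused=True)

# Ahead-of-time compilation (recommended)
step_fn = step_fn.lower(mx, dx).compile()

# Run compiled step; rebind dx every iteration because the donated input must not be reused
for _ in range(1000):
    dx = step_fn(mx, dx)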
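JAX dispatches device work asynchronously, so timing a loop like the one above requires waiting for the last result with jax.block_until_ready before stopping the clock. A minimal sketch, assuming a hypothetical toy step in place of mjx.step so it runs without a MuJoCo model:

```python
import time

import jax
import jax.numpy as jnp


def step(params, state):
    # Hypothetical stand-in for mjx.step: any pure function of (params, state).
    return jnp.tanh(state @ params)


params = jnp.eye(256, dtype=jnp.float32) * 0.5
state = jnp.ones(256, dtype=jnp.float32)

# Measure AOT compilation separately from stepping.
t0 = time.perf_counter()
step_fn = jax.jit(step, donate_argnums=(1,)).lower(params, state).compile()
compile_s = time.perf_counter() - t0

t0 = time.perf_counter()
for _ in range(1000):
    state = step_fn(params, state)
# Without this, the timer could stop while work is still queued on the device.
state = jax.block_until_ready(state)
loop_s = time.perf_counter() - t0

print(f"compile: {compile_s:.3f} s, 1000 steps: {loop_s:.3f} s")
```

Separating the two measurements makes the benefit of AOT compilation visible: the compile cost is paid once, outside the loop whose per-step time you care about.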

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
