Heuristic:Google deepmind Mujoco MJX Benchmarking Tips

From Leeroopedia
Knowledge Sources
Domains: Optimization, GPU_Computing
Last Updated: 2026-02-15 05:00 GMT

Overview

Benchmarking methodology for MJX: separate JIT compilation from execution, use `jax.block_until_ready()`, batch with `vmap`/`pmap`, and set `XLA_FLAGS` for GPU GEMM optimization.

Description

MJX benchmarking requires careful methodology because JAX dispatches work to the GPU/TPU asynchronously. A naive timer therefore measures dispatch overhead rather than the actual computation. The MJX test utilities demonstrate the correct pattern:

- Measure JIT compilation time separately.
- Call `jax.block_until_ready()` on the result before stopping the timer.
- Structure batches with `jax.pmap` (multi-device) wrapping `jax.vmap` (per-device vectorization).

The benchmark also injects the XLA flag `--xla_gpu_triton_gemm_any=True`. This flag lets XLA use Triton-based GEMM kernels for matrix operations, which can improve GPU throughput.

Usage

Use this heuristic when measuring MJX simulation performance, comparing solver configurations, or optimizing batch sizes for GPU/TPU workloads.

The Insight (Rule of Thumb)

  • Action 1: Always separate JIT compilation time from execution time. Call `fn.lower(*args).compile()` on a `jax.jit`- or `jax.pmap`-wrapped `fn` to compile explicitly, outside the timed region.
  • Action 2: Always call `jax.block_until_ready(result)` before stopping the timer. GPU operations are asynchronous.
  • Action 3: Set `XLA_FLAGS='--xla_gpu_triton_gemm_any=True'` before benchmarking on GPU.
  • Value: Default batch_size=1024, nstep=1000 for representative benchmarks.
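  • Trade-off: A higher `unroll_steps` (passed as `unroll=` to `jax.lax.scan`) increases compile time. In exchange it can improve runtime throughput, because it cuts per-iteration loop overhead and gives XLA a larger loop body to optimize.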
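
Because compile time is reported separately, throughput is usually derived from `run_time` alone. The sketch below uses a hypothetical `steps_per_second` helper (not part of MJX) with the default sizes above.

```python
# Hypothetical helper, not part of MJX: turn a measured run time (compile time
# excluded) into simulation throughput. Defaults mirror batch_size=1024, nstep=1000.


def steps_per_second(run_time: float, batch_size: int = 1024, nstep: int = 1000) -> float:
  """Returns environment steps per second for one timed unroll."""
  if run_time <= 0:
    raise ValueError("run_time must be positive")
  total_steps = batch_size * nstep  # every environment advances nstep times
  return total_steps / run_time


if __name__ == "__main__":
  # e.g. a 2.0 s unroll of 1024 environments x 1000 steps
  print(f"{steps_per_second(2.0):,.0f} steps/s")  # prints 512,000 steps/s
```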

Reasoning

JAX dispatches operations asynchronously to the accelerator. Without `block_until_ready()`, Python measures only the dispatch time (microseconds) rather than the actual compute time (milliseconds). The benchmark then reports numbers that are far too optimistic.
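
To see this effect, the sketch below times the same jitted function twice. The first measurement stops the timer as soon as the call returns; the second stops it after `jax.block_until_ready`. The function `heavy` and the matrix size are arbitrary choices for illustration. The gap between the two numbers depends on the backend and is typically largest on GPU/TPU.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
  # Illustrative workload: a few chained matmuls so the device has real work.
  for _ in range(4):
    x = jnp.tanh(x @ x)
  return x


x = jnp.ones((512, 512), dtype=jnp.float32)
jax.block_until_ready(heavy(x))  # compile and warm up outside any timer

# Naive: the call returns once the work is dispatched, not when it finishes.
beg = time.perf_counter()
y = heavy(x)
dispatch_time = time.perf_counter() - beg
jax.block_until_ready(y)  # drain the queue so it does not leak into the next timing

# Correct: block on the result before stopping the timer.
beg = time.perf_counter()
y = jax.block_until_ready(heavy(x))
compute_time = time.perf_counter() - beg

print(f"dispatch only:          {dispatch_time * 1e3:.3f} ms")
print(f"with block_until_ready: {compute_time * 1e3:.3f} ms")
```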

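The `pmap`/`vmap` nesting batches at two levels: `pmap` distributes the batch across devices (multi-GPU), while `vmap` vectorizes within each device. Each device simulates `batch_size // jax.device_count()` environments, so batch_size should be divisible by `jax.device_count()`. Otherwise the floor division silently drops the remainder.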
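
The divisibility requirement can be checked explicitly before launching a benchmark. `per_device_batch` is a hypothetical helper, not part of MJX. In practice, `device_count` would be `jax.device_count()`.

```python
def per_device_batch(batch_size: int, device_count: int) -> int:
  """Hypothetical helper: environments each device simulates under pmap(vmap(...)).

  Mirrors the `batch_size // jax.device_count()` split in the MJX benchmark,
  but fails loudly instead of silently dropping the remainder.
  """
  if device_count < 1:
    raise ValueError("device_count must be >= 1")
  if batch_size % device_count:
    raise ValueError(
        f"batch_size={batch_size} is not divisible by device_count={device_count}; "
        f"floor division would simulate only "
        f"{device_count * (batch_size // device_count)} environments")
  return batch_size // device_count


print(per_device_batch(1024, 8))  # prints 128
```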

The Triton GEMM flag lets XLA use its Triton-based matrix multiplication for any GEMM it supports. This is often faster than cuBLAS for the small matrix sizes encountered in physics simulation.

Code Evidence

Correct benchmark methodology from `mjx/mujoco/mjx/_src/test_util.py:34-48`:

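import time
from typing import Tuple

import jax


def _measure(fn, *args) -> Tuple[float, float]:
  """Reports jit time and op time for a function."""
  # Ahead-of-time compile: trace, lower and compile fn without running it.
  beg = time.perf_counter()
  compiled_fn = fn.lower(*args).compile()
  end = time.perf_counter()
  jit_time = end - beg

  # Run the compiled executable and wait for the device to finish.
  beg = time.perf_counter()
  result = compiled_fn(*args)
  jax.block_until_ready(result)
  end = time.perf_counter()
  run_time = end - beg

  return jit_time, run_time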
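
`_measure` times a single run. The variant below assumes the same `fn.lower(*args).compile()` API. It adds one untimed warm-up call to absorb one-time costs, then collects several timed runs so that a median can be reported. `measure_repeated`, `repeats` and the toy `matmul` are hypothetical names used only for illustration.

```python
import statistics
import time
from typing import List, Tuple

import jax
import jax.numpy as jnp


def measure_repeated(fn, *args, repeats: int = 5) -> Tuple[float, List[float]]:
  """Hypothetical variant of _measure: compile once, then time several runs."""
  beg = time.perf_counter()
  compiled_fn = fn.lower(*args).compile()  # fn must be jax.jit/jax.pmap-wrapped
  jit_time = time.perf_counter() - beg

  jax.block_until_ready(compiled_fn(*args))  # warm-up run, not timed

  run_times = []
  for _ in range(repeats):
    beg = time.perf_counter()
    jax.block_until_ready(compiled_fn(*args))
    run_times.append(time.perf_counter() - beg)
  return jit_time, run_times


matmul = jax.jit(lambda a: a @ a)
jit_time, run_times = measure_repeated(matmul, jnp.ones((256, 256)))
print(f"jit: {jit_time:.3f} s, median run: {statistics.median(run_times) * 1e3:.3f} ms")
```

Reporting the median rather than a single sample makes the result less sensitive to one-off stalls.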

XLA flag injection from `mjx/mujoco/mjx/_src/test_util.py:62-64`:

xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
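
The snippet above appends the flag unconditionally, so re-running it in the same process repeats the flag. The hypothetical helper `append_xla_flag` below avoids that while preserving any flags already set; it only deduplicates exact string matches. It is safest to call it before importing `jax`, since `XLA_FLAGS` is generally read when the XLA backend is first initialized.

```python
import os


def append_xla_flag(flag: str) -> str:
  """Hypothetical helper: add `flag` to XLA_FLAGS once, keeping existing flags."""
  existing = os.environ.get("XLA_FLAGS", "").split()
  if flag not in existing:  # exact-match deduplication only
    existing.append(flag)
  os.environ["XLA_FLAGS"] = " ".join(existing)
  return os.environ["XLA_FLAGS"]


# Set before `import jax` so the backend sees the flag when it initializes.
append_xla_flag("--xla_gpu_triton_gemm_any=True")
```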

pmap/vmap nesting from `mjx/mujoco/mjx/_src/test_util.py:74-100`:

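# Excerpt from inside the benchmark function: m (the MJX model), batch_size,
# nstep and unroll_steps come from the enclosing scope; io and forward are
# MJX internal modules (mujoco.mjx._src).
@jax.pmap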
def init(key):
    key = jax.random.split(key, batch_size // jax.device_count())
    @jax.vmap
    def random_init(key):
      d = io.make_data(m)
      qvel = 0.01 * jax.random.normal(key, shape=(m.nv,))
      d = d.replace(qvel=qvel)
      return d
    return random_init(key)

@jax.pmap
def unroll(d):
    @jax.vmap
    def step(d, _):
      d = forward.step(m, d)
      return d, None
    d, _ = jax.lax.scan(step, d, None, length=nstep, unroll=unroll_steps)
    return d
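
The excerpt above depends on MJX internals (`io.make_data`, `forward.step`) and on closure variables. The self-contained sketch below keeps the same `pmap(vmap(...))` plus `lax.scan` structure but makes two substitutions:

- A hypothetical toy damped oscillator (`toy_step`) replaces `forward.step`.
- Made-up values are used for `batch_size`, `nstep`, `unroll_steps` and `dt`.

With these changes it runs on a single CPU device.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the benchmark's closure variables.
batch_size = 8
nstep = 100
unroll_steps = 4
dt = 0.01

# pmap maps over local devices; on a single host this equals jax.device_count().
ndev = jax.local_device_count()
if batch_size % ndev:
  raise ValueError("batch_size must divide evenly across devices")


def toy_step(state):
  """Toy damped oscillator standing in for forward.step(m, d)."""
  q, v = state
  v = v - dt * (q + 0.1 * v)  # semi-implicit Euler: velocity first
  q = q + dt * v              # then position with the new velocity
  return q, v


@jax.pmap
def init(key):
  # One key per environment on this device.
  keys = jax.random.split(key, batch_size // ndev)

  @jax.vmap
  def random_init(k):
    v = 0.01 * jax.random.normal(k)  # small random initial velocity
    return jnp.zeros_like(v), v
  return random_init(keys)


@jax.pmap
def unroll(state):
  @jax.vmap
  def step(s, _):
    return toy_step(s), None

  # scan carries the batched state for nstep iterations; unroll trades
  # compile time for less loop overhead.
  state, _ = jax.lax.scan(step, state, None, length=nstep, unroll=unroll_steps)
  return state


keys = jax.random.split(jax.random.PRNGKey(0), ndev)  # one key per device
q, v = unroll(init(keys))
print(q.shape)  # (ndev, batch_size // ndev)
```

On one device the leading axis has size 1. With N devices, `pmap` runs N shards of `batch_size // N` environments each.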

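Memory monitoring from `sample/testspeed.cc:288-290`, the native MuJoCo C benchmark. It reports peak arena usage (`maxuse_arena`) as a percentage of the arena size (`narena`):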

std::printf(" Dynamic memory usage : %.1f%% of %s\n\n",
            100 * d[0]->maxuse_arena / (double)(d[0]->narena),
            mju_writeNumBytes(d[0]->narena));
