Heuristic: Google DeepMind MuJoCo MJX Benchmarking Tips
| Knowledge Sources | |
|---|---|
| Domains | Optimization, GPU_Computing |
| Last Updated | 2026-02-15 05:00 GMT |
Overview
Benchmarking methodology for MJX: separate JIT compilation from execution, use `jax.block_until_ready()`, batch with `vmap`/`pmap`, and set `XLA_FLAGS` for GPU GEMM optimization.
Description
MJX benchmarking requires careful methodology because JAX operates asynchronously on GPU/TPU. Naive timing will measure dispatch overhead rather than actual computation. The MJX test utilities demonstrate the correct pattern: separately measure JIT compilation time, use `jax.block_until_ready()` to synchronize before timing, and structure batches with `jax.pmap` (multi-device) wrapping `jax.vmap` (per-device vectorization).
The benchmark also injects the XLA flag `--xla_gpu_triton_gemm_any=True` to enable Triton GEMM for matrix operations, which can significantly improve GPU throughput.
Usage
Use this heuristic when measuring MJX simulation performance, comparing solver configurations, or optimizing batch sizes for GPU/TPU workloads.
The Insight (Rule of Thumb)
- Action 1: Always separate JIT compilation time from execution time. Use `fn.lower(*args).compile()` for explicit compilation.
- Action 2: Always call `jax.block_until_ready(result)` before stopping the timer. GPU operations are asynchronous.
- Action 3: Set `XLA_FLAGS='--xla_gpu_triton_gemm_any=True'` before benchmarking on GPU.
- Value: Defaults of `batch_size=1024` and `nstep=1000` give representative benchmarks.
- Trade-off: Higher `unroll_steps` increases compile time but improves runtime throughput via better GPU utilization.
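The default batch and step counts above imply a simple throughput metric. A minimal sketch in pure Python; `steps_per_second` is a hypothetical helper, and `run_time` stands in for a measured execution time rather than a real measurement:

```python
def steps_per_second(batch_size: int, nstep: int, run_time: float) -> float:
    """Total physics steps executed per second of wall-clock run time.

    Each of the batch_size parallel environments advances nstep steps,
    so the benchmark performs batch_size * nstep steps in total.
    """
    return batch_size * nstep / run_time

# With the default batch_size=1024 and nstep=1000, a 4-second run
# corresponds to 256,000 steps/second.
print(steps_per_second(1024, 1000, 4.0))  # → 256000.0
```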
Reasoning
JAX dispatches operations asynchronously to the accelerator. Without `block_until_ready()`, Python measures only the dispatch time (microseconds) rather than actual compute time (milliseconds). This leads to wildly incorrect benchmarks.
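The pitfall can be reproduced without a GPU. The sketch below uses a thread pool as a stand-in for the asynchronous accelerator (an analogy, not actual JAX APIs): submitting work returns almost immediately, and only waiting on the result, the analogue of `jax.block_until_ready()`, captures the real compute time.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def slow_op():
    time.sleep(0.2)  # stand-in for a 200 ms GPU kernel
    return 42

with ThreadPoolExecutor() as pool:
    beg = time.perf_counter()
    # "Naive" timing: submit returns immediately, like a JAX dispatch,
    # so dispatch_time is microseconds even though the op takes 200 ms.
    future = pool.submit(slow_op)
    dispatch_time = time.perf_counter() - beg

    # Correct timing: block on the result before stopping the clock,
    # the analogue of jax.block_until_ready().
    result = future.result()
    run_time = time.perf_counter() - beg
```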
The `pmap`/`vmap` nesting structure ensures optimal batching: `pmap` distributes across devices (multi-GPU), while `vmap` vectorizes within each device. The batch_size should be divisible by `jax.device_count()`.
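The shape bookkeeping behind this can be sketched in pure Python; `per_device_batch` is a hypothetical helper and `device_count` stands in for `jax.device_count()`:

```python
def per_device_batch(batch_size: int, device_count: int) -> int:
    """Split a global batch across devices for pmap-over-vmap nesting.

    pmap maps over the leading device axis; vmap then vectorizes the
    remaining per-device axis, so batch_size must divide evenly.
    """
    if batch_size % device_count != 0:
        raise ValueError(
            f'batch_size={batch_size} must be divisible by '
            f'device_count={device_count}')
    return batch_size // device_count

# Default benchmark batch on an 8-GPU host: 1024 envs -> 128 per device.
print(per_device_batch(1024, 8))  # → 128
```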
The Triton GEMM flag enables XLA's Triton-based matrix multiplication, which is often faster than cuBLAS for the matrix sizes encountered in physics simulation.
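One caveat worth making explicit: XLA typically reads `XLA_FLAGS` when the backend initializes, so the flag should be in the environment before the first JAX import or call. A minimal sketch of a safe pattern, appending rather than overwriting so any flags the user already set survive (the idempotency check is an addition, not part of the MJX source):

```python
import os

flag = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if flag not in xla_flags:
    # Append rather than overwrite, preserving pre-existing flags.
    os.environ['XLA_FLAGS'] = (xla_flags + ' ' + flag).strip()

# Import jax only after the environment is set; the backend snapshots
# XLA_FLAGS when it initializes.
# import jax
```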
Code Evidence
Correct benchmark methodology from `mjx/mujoco/mjx/_src/test_util.py:34-48`:
```python
def _measure(fn, *args) -> Tuple[float, float]:
  """Reports jit time and op time for a function."""
  beg = time.perf_counter()
  compiled_fn = fn.lower(*args).compile()
  end = time.perf_counter()
  jit_time = end - beg

  beg = time.perf_counter()
  result = compiled_fn(*args)
  jax.block_until_ready(result)
  end = time.perf_counter()
  run_time = end - beg

  return jit_time, run_time
```
XLA flag injection from `mjx/mujoco/mjx/_src/test_util.py:62-64`:
```python
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
```
pmap/vmap nesting from `mjx/mujoco/mjx/_src/test_util.py:74-100`:
```python
@jax.pmap
def init(key):
  key = jax.random.split(key, batch_size // jax.device_count())

  @jax.vmap
  def random_init(key):
    d = io.make_data(m)
    qvel = 0.01 * jax.random.normal(key, shape=(m.nv,))
    d = d.replace(qvel=qvel)
    return d

  return random_init(key)

@jax.pmap
def unroll(d):
  @jax.vmap
  def step(d, _):
    d = forward.step(m, d)
    return d, None

  d, _ = jax.lax.scan(step, d, None, length=nstep, unroll=unroll_steps)
  return d
```
Memory monitoring from `sample/testspeed.cc:288-290`:
```cpp
std::printf(" Dynamic memory usage : %.1f%% of %s\n\n",
            100 * d[0]->maxuse_arena / (double)(d[0]->narena),
            mju_writeNumBytes(d[0]->narena));
```