Implementation:Google deepmind Mujoco mjx benchmark
| Knowledge Sources | |
|---|---|
| Domains | Benchmarking, GPU_Computing, Performance |
| Last Updated | 2026-02-15 06:00 GMT |
Overview
A concrete tool, provided by the MJX test utilities, for benchmarking the throughput of MJX simulation on GPU.
Description
The benchmark function loads a model, transfers it to GPU, JIT-compiles a batched step function, runs a timed simulation loop, and reports performance metrics. It measures JIT compilation time separately from runtime, uses jax.lax.scan for efficient multi-step rollout, and supports configurable batch size, step count, solver, and iteration count.
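The compile-then-time pattern described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the MJX implementation: `toy_step`, `make_rollout`, and `time_rollout` are hypothetical names, and a damped point-mass update stands in for a real physics step. Compile time is isolated here with JAX's ahead-of-time `lower(...).compile()` path, which may or may not match what `benchmark` does internally.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(state):
    """Hypothetical stand-in for one physics step (not an MJX call)."""
    pos, vel = state
    vel = 0.99 * vel          # simple velocity damping
    pos = pos + 0.01 * vel    # explicit Euler position update
    return pos, vel


def make_rollout(nstep, unroll_steps=1):
    """Return a jitted function that steps a whole batch nstep times."""
    batched_step = jax.vmap(toy_step)  # one step for every batch element

    def body(carry, _):
        return batched_step(carry), None

    @jax.jit
    def rollout(state):
        final, _ = jax.lax.scan(
            body, state, None, length=nstep, unroll=unroll_steps
        )
        return final

    return rollout


def time_rollout(nstep=100, batch_size=8, dof=3):
    """Time compilation and execution separately; also return the step count."""
    key = jax.random.PRNGKey(0)
    pos = jnp.zeros((batch_size, dof))
    vel = 0.01 * jax.random.normal(key, (batch_size, dof))
    state = (pos, vel)
    rollout = make_rollout(nstep)

    start = time.perf_counter()
    compiled = rollout.lower(state).compile()  # compile only, no execution
    jit_time = time.perf_counter() - start

    start = time.perf_counter()
    result = compiled(state)
    jax.block_until_ready(result)  # wait for asynchronous dispatch to finish
    run_time = time.perf_counter() - start

    return jit_time, run_time, nstep * batch_size
```

Blocking on the result before stopping the clock matters because JAX dispatches work asynchronously; without it the measured run time would mostly reflect dispatch overhead rather than the rollout itself.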
Usage
Call from benchmark scripts (e.g., testspeed.py) to measure MJX performance for a given model and configuration.
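A benchmark script might sweep one parameter and compare throughput. The sketch below is one possible shape for such a script, not code from testspeed.py: `sweep_batch_sizes` and `fake_benchmark` are hypothetical names, and the benchmark callable is passed in as an argument so the example runs without MuJoCo or a GPU. The timings returned by `fake_benchmark` are placeholders, not measurements.

```python
def sweep_batch_sizes(bench_fn, model, batch_sizes, nstep=1000):
    """Run bench_fn once per batch size and collect steps per second.

    bench_fn is assumed to follow the benchmark contract: it returns
    (jit_time_secs, run_time_secs, total_steps).
    """
    results = {}
    for batch_size in batch_sizes:
        _, run_time, total_steps = bench_fn(
            model, nstep=nstep, batch_size=batch_size
        )
        results[batch_size] = total_steps / run_time
    return results


def fake_benchmark(model, nstep=1000, batch_size=1024):
    """Stand-in for benchmark with placeholder timings (not real data)."""
    return 1.0, 2.0, nstep * batch_size


if __name__ == "__main__":
    # With the real function, pass `benchmark` and a loaded mujoco.MjModel.
    for size, rate in sweep_batch_sizes(fake_benchmark, None, [256, 1024]).items():
        print(f"batch_size={size}: {rate:.0f} steps/sec")
```

Because each call recompiles for the new batch size, the JIT time is discarded here and only the run time feeds the throughput figure.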
Code Reference
Source Location
- Repository: mujoco
- File: mjx/mujoco/mjx/_src/test_util.py
- Lines: 51-105
Signature
def benchmark(
    m: mujoco.MjModel,
    nstep: int = 1000,
    batch_size: int = 1024,
    unroll_steps: int = 1,
    solver: str = 'newton',
    iterations: int = 1,
    ls_iterations: int = 4,
) -> Tuple[float, float, int]:
Import
from mujoco.mjx._src.test_util import benchmark
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| m | mujoco.MjModel | Yes | CPU model to benchmark |
| nstep | int | No | Simulation steps per rollout (default 1000) |
| batch_size | int | No | Parallel environments (default 1024) |
| unroll_steps | int | No | Steps to unroll in scan body (default 1) |
| solver | str | No | Constraint solver: 'newton' or 'cg' (default 'newton') |
| iterations | int | No | Solver iterations (default 1) |
| ls_iterations | int | No | Line search iterations (default 4) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | Tuple[float, float, int] | (jit_compile_time_secs, run_time_secs, total_steps), where total_steps is nstep × batch_size |
Usage Examples
import mujoco
from mujoco.mjx._src.test_util import benchmark

m = mujoco.MjModel.from_xml_path("humanoid.xml")

# Run benchmark
jit_time, run_time, total_steps = benchmark(
    m, nstep=1000, batch_size=1024
)
steps_per_sec = total_steps / run_time
print(f"JIT time: {jit_time:.2f}s")
print(f"Run time: {run_time:.2f}s")
print(f"Steps/sec: {steps_per_sec:.0f}")
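Beyond raw steps per second, the returned values can be turned into a few other common figures. The helper below is a hypothetical sketch (`summarize` is not part of MJX); it assumes the caller supplies the model timestep, for example `m.opt.timestep`, to express throughput as a multiple of real time.

```python
def summarize(run_time, total_steps, timestep):
    """Derive throughput figures from one benchmark result.

    run_time:    measured wall-clock seconds for the rollout
    total_steps: number of simulation steps across the whole batch
    timestep:    simulated seconds per step (e.g. m.opt.timestep)
    """
    return {
        "steps_per_sec": total_steps / run_time,
        "secs_per_step": run_time / total_steps,
        # Simulated seconds produced per wall-clock second.
        "realtime_factor": total_steps * timestep / run_time,
    }


if __name__ == "__main__":
    # Placeholder inputs for illustration only, not measured values.
    for name, value in summarize(2.0, 1000, 0.002).items():
        print(f"{name}: {value:g}")
```

A realtime factor above 1 means the batch as a whole simulates faster than wall-clock time.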
Related Pages
Implements Principle
Requires Environment
Uses Heuristic