Implementation:Google deepmind Mujoco mjx benchmark

From Leeroopedia
Knowledge Sources
Domains: Benchmarking, GPU_Computing, Performance
Last Updated: 2026-02-15 06:00 GMT

Overview

A concrete tool, provided by the MJX test utilities, for benchmarking MJX simulation throughput on GPU.

Description

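The benchmark function takes a CPU mujoco.MjModel, transfers it to the GPU, JIT-compiles a batched step function, runs a timed rollout, and returns the timing results. JIT compilation time is measured separately from run time, so one-off compilation cost does not distort the throughput figure. The rollout uses jax.lax.scan to advance all batched environments for nstep steps inside a single compiled call. Batch size, step count, scan unrolling, solver, solver iterations, and line search iterations are all configurable.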
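The measurement pattern described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the code in test_util.py: toy_step, rollout, and measure are hypothetical names, and a damped-spring update stands in for the MJX step function so the sketch runs without a MuJoCo model.

```python
import functools
import time

import jax
import jax.numpy as jnp


def toy_step(state, _):
    # Hypothetical stand-in for a physics step: semi-implicit Euler
    # integration of a damped spring. A real benchmark would step MJX here.
    pos, vel = state
    acc = -pos - 0.1 * vel
    vel = vel + 0.01 * acc
    pos = pos + 0.01 * vel
    return (pos, vel), None


def rollout(state, nstep, unroll_steps=1):
    # lax.scan keeps the whole multi-step loop inside one compiled program.
    final, _ = jax.lax.scan(
        toy_step, state, None, length=nstep, unroll=unroll_steps
    )
    return final


def measure(batch_size, nstep):
    # One state per environment; vmap batches the rollout over axis 0.
    state = (jnp.ones((batch_size,)), jnp.zeros((batch_size,)))
    batched = jax.jit(jax.vmap(functools.partial(rollout, nstep=nstep)))

    # Time compilation on its own by lowering and compiling ahead of time.
    start = time.perf_counter()
    compiled = batched.lower(state).compile()
    jit_time = time.perf_counter() - start

    # Time execution only; block until the device has actually finished.
    start = time.perf_counter()
    out = compiled(state)
    jax.block_until_ready(out)
    run_time = time.perf_counter() - start

    return jit_time, run_time, nstep * batch_size


jit_time, run_time, total_steps = measure(batch_size=8, nstep=100)
print(f"compile {jit_time:.3f}s, run {run_time:.4f}s, {total_steps} steps")
```

Blocking on the result matters because JAX dispatches work asynchronously; without it the timer would stop before the device has finished the rollout.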

Usage

Call benchmark from a benchmarking script (e.g., testspeed.py) to measure MJX performance for a given model and solver configuration.

Code Reference

Source Location

  • Repository: mujoco
  • File: mjx/mujoco/mjx/_src/test_util.py
  • Lines: 51-105

Signature

def benchmark(
    m: mujoco.MjModel,
    nstep: int = 1000,
    batch_size: int = 1024,
    unroll_steps: int = 1,
    solver: str = 'newton',
    iterations: int = 1,
    ls_iterations: int = 4,
) -> Tuple[float, float, int]:

Import

from mujoco.mjx._src.test_util import benchmark

I/O Contract

Inputs

Name Type Required Description
m mujoco.MjModel Yes CPU model to benchmark
nstep int No Simulation steps per rollout (default 1000)
batch_size int No Parallel environments (default 1024)
unroll_steps int No Steps to unroll in scan body (default 1)
solver str No Constraint solver: 'newton' or 'cg' (default 'newton')
iterations int No Solver iterations (default 1)
ls_iterations int No Line search iterations (default 4)

Outputs

Name Type Description
return Tuple[float, float, int] (jit_compile_time_secs, run_time_secs, total_steps), where total_steps = nstep * batch_size
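The returned tuple is raw timing data, so derived metrics have to be computed by the caller. The sketch below shows one way to do that. BenchmarkSummary and summarize are hypothetical helpers, not part of MJX, and timestep_s is assumed to be supplied by the caller (for a MuJoCo model it would typically come from m.opt.timestep).

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class BenchmarkSummary:
    jit_time_s: float
    run_time_s: float
    total_steps: int
    steps_per_sec: float
    realtime_factor: float


def summarize(
    result: Tuple[float, float, int], timestep_s: float
) -> BenchmarkSummary:
    """Derives throughput metrics from a (jit_time, run_time, steps) tuple."""
    jit_time, run_time, total_steps = result
    if run_time <= 0.0:
        raise ValueError("run_time must be positive to compute throughput")
    # Throughput counts every step of every batched environment.
    steps_per_sec = total_steps / run_time
    # Simulated seconds produced per wall-clock second.
    realtime_factor = total_steps * timestep_s / run_time
    return BenchmarkSummary(
        jit_time, run_time, total_steps, steps_per_sec, realtime_factor
    )


# Example with made-up timings, not measured results.
summary = summarize((12.5, 2.0, 1_024_000), timestep_s=0.005)
print(summary)
```

Keeping the JIT time out of both derived metrics mirrors the separation the function itself makes between compilation and run time.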

Usage Examples

import mujoco
from mujoco.mjx._src.test_util import benchmark

m = mujoco.MjModel.from_xml_path("humanoid.xml")

# Run benchmark
jit_time, run_time, total_steps = benchmark(
    m, nstep=1000, batch_size=1024
)

steps_per_sec = total_steps / run_time
print(f"JIT time: {jit_time:.2f}s")
print(f"Run time: {run_time:.2f}s")
print(f"Steps/sec: {steps_per_sec:.0f}")
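Because batch size and solver are plain keyword arguments, a configuration sweep is a natural extension of the example above. The following is a sketch only: sweep and fake_benchmark are hypothetical names, and fake_benchmark is a stand-in that returns fixed timings so the snippet runs without a GPU or a model file. In real use the benchmark function imported above would be passed in its place.

```python
from itertools import product
from typing import Any, Callable, Dict, List, Sequence, Tuple

BenchmarkFn = Callable[..., Tuple[float, float, int]]


def sweep(
    benchmark_fn: BenchmarkFn,
    model: Any,
    batch_sizes: Sequence[int],
    solvers: Sequence[str] = ("newton", "cg"),
) -> List[Dict[str, Any]]:
    """Runs benchmark_fn for every (batch_size, solver) combination."""
    rows = []
    for batch_size, solver in product(batch_sizes, solvers):
        jit_time, run_time, total_steps = benchmark_fn(
            model, batch_size=batch_size, solver=solver
        )
        rows.append({
            "batch_size": batch_size,
            "solver": solver,
            "jit_time_s": jit_time,
            "steps_per_sec": total_steps / run_time,
        })
    return rows


def fake_benchmark(m, nstep=1000, batch_size=1024, solver="newton", **kwargs):
    # Stand-in with fixed, made-up timings; it does not simulate anything.
    return 1.0, 2.0, nstep * batch_size


rows = sweep(fake_benchmark, None, batch_sizes=[256, 1024])
for row in rows:
    print(row)
```

Each configuration triggers its own JIT compilation, so a real sweep spends a noticeable share of its wall-clock time compiling rather than simulating.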

Related Pages

Implements Principle

Requires Environment

Uses Heuristic