

Principle:Google deepmind Mujoco MJX Physics Stepping

From Leeroopedia
Domains: GPU_Computing, Physics_Simulation, JAX, Reinforcement_Learning
Last Updated: 2026-02-15 06:00 GMT

Overview

A GPU- and TPU-accelerated physics simulation step, implemented in JAX, that advances the state of a MuJoCo model.

Description

MJX Physics Stepping is the JAX reimplementation of MuJoCo's mj_step function. It runs the same forward dynamics and integration pipeline, but is built from JAX operations that can be JIT-compiled (jit), vectorized (vmap), and differentiated (grad). This enables massively parallel simulation on GPUs and TPUs, which reinforcement learning exploits by stepping thousands of environments simultaneously.

mjx.step is a pure JAX function: it takes mjx.Model and mjx.Data pytrees and returns a new Data whose state has been advanced by one timestep, leaving its inputs unchanged. JAX's transformations (jit, vmap, grad) require this side-effect-free style.

Usage

Use mjx.step as the core simulation step in MJX pipelines on GPU or TPU. It is typically wrapped in jax.jit for compilation and in jax.vmap for batched execution over many environments.

Theoretical Basis

MJX solves the same equations of motion as the C engine, reimplemented in JAX:

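$$ M(q)\,\ddot{q} + c(q,\dot{q}) = \tau + J^{T} f $$

where $M(q)$ is the joint-space inertia matrix, $c(q,\dot{q})$ collects the bias forces (Coriolis, centrifugal and gravitational), $\tau$ is the applied generalized force, $J$ is the constraint Jacobian and $f$ is the vector of constraint forces.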

The key difference is the functional programming model: instead of modifying Data in place, mjx.step returns a new Data object. Because no state is mutated, JAX can apply automatic differentiation and vectorization to the step.
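To make the equation and the functional style concrete, here is a toy sketch in plain JAX. It is not MJX's actual pipeline: it models a hypothetical one-degree-of-freedom point-mass pendulum with no constraints (so $J^T f = 0$), solves for $\ddot{q}$, and integrates with semi-implicit Euler (velocity first, then position). The constants and function names are illustrative assumptions. Because each step returns a new state tuple, `jax.lax.scan` can roll it out and `jax.grad` can differentiate through the whole rollout.

```python
import jax
import jax.numpy as jnp

# Hypothetical constants for a point-mass pendulum (SI units).
MASS, LENGTH, GRAVITY, DT = 1.0, 0.5, 9.81, 0.002

def mass_matrix(q):
    # M(q): a 1x1 inertia matrix m * l^2, constant for this system.
    return jnp.array([[MASS * LENGTH**2]])

def bias_force(q, qdot):
    # c(q, qdot): only the gravity torque; one hinge has no Coriolis term.
    return MASS * GRAVITY * LENGTH * jnp.sin(q)

def toy_step(state, tau):
    """Returns a NEW (q, qdot) tuple; the input state is never mutated."""
    q, qdot = state
    # Solve M(q) qacc = tau - c(q, qdot); J^T f = 0 since there are no constraints.
    qacc = jnp.linalg.solve(mass_matrix(q), tau - bias_force(q, qdot))
    qdot_next = qdot + DT * qacc   # update velocity first...
    q_next = q + DT * qdot_next    # ...then position with the new velocity
    return q_next, qdot_next

def final_angle(q0, qdot0, num_steps=100):
    # Roll the pure step forward with lax.scan, applying zero torque each step.
    def body(state, _):
        return toy_step(state, jnp.zeros(1)), None
    (q, _), _ = jax.lax.scan(body, (q0, qdot0), None, length=num_steps)
    return q[0]

q0 = jnp.array([0.1])
# Sensitivity of the angle after 100 steps to the initial velocity.
dq_dqdot0 = jax.grad(final_angle, argnums=1)(q0, jnp.zeros(1))
print(dq_dqdot0)
```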

Supported integrators: Euler (mjINT_EULER), RK4 (mjINT_RK4), and ImplicitFast (mjINT_IMPLICITFAST).

