Implementation:Google DeepMind MuJoCo MJX step
| Property | Value |
|---|---|
| Knowledge Sources | |
| Domains | GPU_Computing, Physics_Simulation, JAX |
| Last Updated | 2026-02-15 06:00 GMT |
Overview
Concrete tool, provided by the MJX forward dynamics module, for advancing an MJX simulation by one timestep.
Description
The mjx.step function performs one complete physics timestep with JAX operations: forward kinematics, collision detection, constraint solving, and then time integration with the integrator selected in the model options. It dispatches to the appropriate backend (JAX or Warp) based on the implementation type of the model and data. The function is compatible with jax.jit, jax.vmap, and jax.grad.
Usage
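Wrap the call in jax.jit so it is traced and compiled once and reused on later calls. For batched simulation, use jax.vmap with in_axes=(None, 0): the model argument is shared across the batch (not mapped) and the data argument is mapped along its leading axis.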
Code Reference
Source Location
- Repository: mujoco
- File: mjx/mujoco/mjx/_src/forward.py
- Lines: 458-475
Signature
def step(m: Model, d: Data) -> Data:
"""Advance simulation."""
Import
import mujoco.mjx as mjx
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| m | types.Model | Yes | MJX model on device |
| d | types.Data | Yes | Current simulation state on device |
Outputs
| Name | Type | Description |
|---|---|---|
| return | types.Data | New Data with the state advanced by one step (d.time increases by m.opt.timestep) |
Usage Examples
Single Step
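import mujoco
import mujoco.mjx as mjx
import jax
# Load the model on the host ("humanoid.xml" is a placeholder for your own MJCF file)
m = mujoco.MjModel.from_xml_path("humanoid.xml")
# Copy the model to the device and allocate initial device data
mx = mjx.put_model(m)
dx = mjx.make_data(mx)
# JIT compile and step
step_fn = jax.jit(mjx.step)
dx = step_fn(mx, dx)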
Batched Simulation
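import jax
import jax.numpy as jnp
import mujoco.mjx as mjx
# Assumes mx from the Single Step example above.
# Vectorize over data only: the model is shared (in_axes=None), data is batched along axis 0
batch_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
# Create a batch of 1024 initial states; the mapped input is unused, so vmap
# broadcasts 1024 identical copies of the initial data
dx_batch = jax.vmap(lambda _: mjx.make_data(mx))(jnp.arange(1024))
# Step all environments in parallel
dx_batch = batch_step(mx, dx_batch)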
Related Pages
Implements Principle
Requires Environment
Uses Heuristic