Implementation:Google DeepMind MuJoCo jax.vmap for MJX
| Knowledge Sources | |
|---|---|
| Domains | GPU_Computing, Reinforcement_Learning, JAX |
| Last Updated | 2026-02-15 06:00 GMT |
Overview
Wrapper documentation for using JAX automatic vectorization with MJX to run batched parallel simulations on GPU.
Description
jax.vmap vectorizes a function over a batch axis. For MJX, it turns mjx.step, which advances a single simulation, into a function that advances a whole batch of simulations in parallel. The key setting is in_axes=(None, 0). None leaves the model unmapped, so a single mjx.Model is shared by every simulation. 0 maps over the leading axis of every array in the mjx.Data pytree.
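The `in_axes=(None, 0)` semantics can be seen without MuJoCo at all. The sketch below is illustrative only. It assumes a hypothetical `toy_step(model, data)` and two plain dicts that stand in for `mjx.Model` and `mjx.Data`. The first argument is passed through unmapped, and the second argument is sliced along axis 0.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins (not MJX types): "model" holds shared parameters,
# "data" holds per-simulation state with a leading batch axis.
model = {"gravity": jnp.float32(-9.81), "dt": jnp.float32(0.01)}

def toy_step(model, data):
    """One explicit-Euler step of a falling point mass (illustrative only)."""
    vel = data["vel"] + model["dt"] * model["gravity"]
    pos = data["pos"] + model["dt"] * vel
    return {"pos": pos, "vel": vel}

batch = 4
data_batch = {
    "pos": jnp.arange(batch, dtype=jnp.float32),  # shape (4,)
    "vel": jnp.zeros(batch, dtype=jnp.float32),   # shape (4,)
}

# None: model is not mapped (shared). 0: data is mapped over its leading axis.
batched_step = jax.vmap(toy_step, in_axes=(None, 0))
out = batched_step(model, data_batch)
print(out["pos"].shape)  # (4,)
```

Each batch element gets exactly what a per-element loop calling `toy_step` would produce. The difference is that vmap emits one vectorized computation instead of a Python loop.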
Usage
Wrap mjx.step with jax.vmap(mjx.step, in_axes=(None, 0)), then wrap the result in jax.jit so the batched step is compiled once and reused on every call.
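A minimal sketch of the jit-plus-vmap combination follows. It again uses a hypothetical `toy_step` in place of `mjx.step`. It shows the Python-loop pattern from this page, followed by an alternative that the source does not use: `jax.lax.scan` compiles the entire rollout into a single program. Both are assumed to produce the same trajectory.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy step standing in for mjx.step(model, data) -> data.
def toy_step(model, data):
    vel = data["vel"] + model["dt"] * model["gravity"]
    return {"pos": data["pos"] + model["dt"] * vel, "vel": vel}

model = {"gravity": jnp.float32(-9.81), "dt": jnp.float32(0.01)}
data_batch = {"pos": jnp.zeros(8), "vel": jnp.zeros(8)}

# jit(vmap(...)) is traced and compiled on the first call. Later calls with
# the same shapes and dtypes reuse the compiled program.
batch_step = jax.jit(jax.vmap(toy_step, in_axes=(None, 0)))

# Pattern used on this page: a Python loop, one compiled call per step.
d_loop = data_batch
for _ in range(10):
    d_loop = batch_step(model, d_loop)

# Alternative (not from the source): lax.scan fuses the 10-step rollout
# into one compiled program.
@jax.jit
def rollout(model, data):
    def body(carry, _):
        return batch_step(model, carry), None
    final, _ = jax.lax.scan(body, data, None, length=10)
    return final

d_scan = rollout(model, data_batch)
```

The loop form is simpler to instrument, for example to log state between steps. The scan form avoids per-step Python dispatch, which can matter when each step is cheap relative to launch overhead.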
External Reference
Code Reference
Source Location
- Repository: External (jax library); usage in mujoco
- File: mjx/mujoco/mjx/warp/forward_test.py
- Lines: 80
Signature
```python
# JAX library function, as called for MJX (in_axes defaults to 0 otherwise)
jax.vmap(fun, in_axes=(None, 0)) -> Callable
# MJX usage pattern
batch_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
```
Import
```python
import jax
import mujoco.mjx as mjx
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| fun | Callable | Yes | Function to vectorize (e.g., mjx.step) |
| in_axes | int, None, or tuple (pytree prefix of the arguments) | No (default 0) | Use (None, 0) for MJX: model shared, data batched |
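Because `in_axes` may be a pytree prefix of the arguments, it can choose per-leaf axes inside the first argument too. In the sketch below, the hypothetical dict-based "model" batches one parameter (`gravity`) per simulation and keeps `dt` shared. The same mechanism might be applied to an `mjx.Model` pytree to vary parameters across simulations, but the source does not show this, so treat it as an assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy step standing in for mjx.step(model, data) -> data.
def toy_step(model, data):
    vel = data["vel"] + model["dt"] * model["gravity"]
    return {"pos": data["pos"] + model["dt"] * vel, "vel": vel}

batch = 3
# gravity differs per simulation (leading axis of size 3); dt is shared.
model = {"gravity": jnp.array([-9.81, -3.71, -1.62]), "dt": jnp.float32(0.01)}
data_batch = {"pos": jnp.zeros(batch), "vel": jnp.zeros(batch)}

# The in_axes tree mirrors the model dict: map gravity over axis 0, share dt.
model_axes = {"gravity": 0, "dt": None}
step = jax.vmap(toy_step, in_axes=(model_axes, 0))
out = step(model, data_batch)
```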
Outputs
| Name | Type | Description |
|---|---|---|
| return | Callable | Vectorized function: (Model, Data[batch,...]) -> Data[batch,...] |
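The output contract says every leaf of the returned Data keeps the leading batch axis. The sketch below checks that on a hypothetical dict-based toy state. It also slices one simulation back out with `jax.tree_util.tree_map`, which works the same way on any pytree.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy step standing in for mjx.step(model, data) -> data.
def toy_step(model, data):
    vel = data["vel"] + model["dt"] * model["gravity"]
    return {"pos": data["pos"] + model["dt"] * vel, "vel": vel}

model = {"gravity": jnp.float32(-9.81), "dt": jnp.float32(0.01)}
batch = 5
data_batch = {"pos": jnp.zeros((batch, 3)), "vel": jnp.zeros((batch, 3))}

out = jax.vmap(toy_step, in_axes=(None, 0))(model, data_batch)

# Leading dimension of every output leaf (expected: the batch size).
leading = jax.tree_util.tree_map(lambda x: x.shape[0], out)

# Extract simulation index 2 from the batched pytree.
sim_2 = jax.tree_util.tree_map(lambda x: x[2], out)
```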
Usage Examples
```python
import jax
import jax.numpy as jnp
import mujoco
import mujoco.mjx as mjx

m = mujoco.MjModel.from_xml_path("humanoid.xml")
mx = mjx.put_model(m)

# Create a batch of 1024 identical initial states (the lambda ignores its
# input, so vmap broadcasts make_data's output along a new leading axis)
dx_batch = jax.vmap(lambda _: mjx.make_data(mx))(jnp.arange(1024))

# Compiled batched step: model shared, data mapped over axis 0
batch_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))

# Advance all 1024 simulations in parallel for 1000 steps
for _ in range(1000):
    dx_batch = batch_step(mx, dx_batch)
```
Related Pages
Implements Principle
Requires Environment
Uses Heuristic