

Implementation:Google deepmind Mujoco jax vmap for MJX

From Leeroopedia
Knowledge Sources
Domains GPU_Computing, Reinforcement_Learning, JAX
Last Updated 2026-02-15 06:00 GMT

Overview

Wrapper documentation for using JAX automatic vectorization with MJX to run batched parallel simulations on GPU.

Description

jax.vmap automatically vectorizes a function over a batch dimension. For MJX, it transforms mjx.step from a function that advances a single simulation into one that advances a batch of simulations in parallel. The key configuration is in_axes=(None, 0): None leaves the first argument (the model) unbatched and shared by every simulation, while 0 maps over the leading axis of every array in the second argument (the data).
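
To make the in_axes=(None, 0) convention concrete, here is a minimal sketch that does not use MJX at all. The function apply_matrix and the arrays W and xs are hypothetical stand-ins: W plays the role of the shared model and xs the role of the batched data.

```python
import jax
import jax.numpy as jnp

def apply_matrix(W, x):
    # Written for a single sample: W is (3, 3), x is (3,).
    return W @ x

W = jnp.arange(9.0).reshape(3, 3)   # shared argument, no batch axis
xs = jnp.ones((5, 3))               # batched argument, leading axis of size 5

# None: pass W unchanged to every call; 0: map over axis 0 of xs.
batched_apply = jax.vmap(apply_matrix, in_axes=(None, 0))
ys = batched_apply(W, xs)           # one output row per row of xs
```

The shared argument is passed through without gaining a batch axis, which is why a single model can drive many data instances.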

Usage

Wrap mjx.step as jax.vmap(mjx.step, in_axes=(None, 0)), then wrap the result in jax.jit for compiled batched execution.
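
The jit-plus-vmap composition can be checked on a toy problem. The sketch below assumes a hypothetical toy_step (a semi-implicit Euler update of a unit-mass spring) in place of mjx.step, and compares the batched result against calling the single-instance function once per element.

```python
import jax
import jax.numpy as jnp

def toy_step(dt, state):
    # Hypothetical stand-in for mjx.step: one integration step of a spring.
    pos, vel = state
    vel = vel - dt * pos
    pos = pos + dt * vel
    return pos, vel

dt = 0.01                               # shared, like the model
pos0 = jnp.linspace(-1.0, 1.0, 8)       # 8 independent instances
vel0 = jnp.zeros(8)

# Same composition as the MJX pattern: vectorize first, then compile.
batch_step = jax.jit(jax.vmap(toy_step, in_axes=(None, 0)))
pos1, vel1 = batch_step(dt, (pos0, vel0))

# Reference: call the unbatched function once per instance.
ref = [toy_step(dt, (pos0[i], vel0[i])) for i in range(8)]
ref_pos = jnp.stack([r[0] for r in ref])
ref_vel = jnp.stack([r[1] for r in ref])
```

The batched call should match the per-instance loop up to floating-point tolerance, while tracing and compiling the step only once.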

External Reference

Code Reference

Source Location

  • Repository: External (jax library); usage in mujoco
  • File: mjx/mujoco/mjx/warp/forward_test.py
  • Lines: 80

Signature

# JAX library function (in_axes defaults to 0; the MJX pattern passes (None, 0))
jax.vmap(fun, in_axes=(None, 0)) -> Callable

# MJX usage pattern
batch_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))

Import

import jax
import mujoco.mjx as mjx

I/O Contract

Inputs

Name | Type | Required | Description
fun | Callable | Yes | Function to vectorize (e.g., mjx.step)
in_axes | tuple | Yes for this pattern (jax.vmap itself defaults to 0) | (None, 0): model shared, data batched

Outputs

Name | Type | Description
return | Callable | Vectorized function: (Model, Data[batch,...]) -> Data[batch,...]
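
The Data[batch,...] notation means that every array inside the data structure carries a leading batch axis, both on input and on output. A small sketch with hypothetical ToyModel and ToyData containers (illustrative stand-ins, not the MJX classes) shows how vmap treats such structured arguments:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyModel(NamedTuple):   # hypothetical stand-in for the shared model
    dt: jax.Array

class ToyData(NamedTuple):    # hypothetical stand-in for per-simulation data
    qpos: jax.Array           # (nq,) for a single simulation
    qvel: jax.Array           # (nv,) for a single simulation

def toy_step(model, data):
    # Single-simulation update: advance positions by dt * velocity.
    return ToyData(qpos=data.qpos + model.dt * data.qvel, qvel=data.qvel)

model = ToyModel(dt=jnp.float32(0.002))
batch = ToyData(qpos=jnp.zeros((16, 3)), qvel=jnp.ones((16, 3)))

out = jax.vmap(toy_step, in_axes=(None, 0))(model, batch)
# Every leaf of the output keeps the leading batch axis of size 16.
leaf_shapes = jax.tree_util.tree_map(lambda x: x.shape, out)
```

An in_axes entry of 0 applies to every leaf of the corresponding argument, so no per-field axis specification is needed when all fields are batched the same way.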

Usage Examples

import jax
import jax.numpy as jnp
import mujoco
import mujoco.mjx as mjx

m = mujoco.MjModel.from_xml_path("humanoid.xml")
mx = mjx.put_model(m)

# Create a batch of 1024 initial states (identical here); every array gains a leading axis of size 1024
dx_batch = jax.vmap(lambda _: mjx.make_data(mx))(jnp.arange(1024))

# Compiled batched step
batch_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))

# Advance all 1024 simulations in parallel for 1000 steps
for _ in range(1000):
    dx_batch = batch_step(mx, dx_batch)
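
The batch above starts every simulation from the same state, so all 1024 trajectories would coincide unless the states or controls are varied. One common approach is to map a state-building function over a batch of random keys. The sketch below uses a hypothetical make_state that returns a plain dictionary rather than MJX data; with MJX, the analogous step would be perturbing fields such as qpos on each data instance, which is an assumption here rather than something shown in the source.

```python
import jax
import jax.numpy as jnp

def make_state(key, nq=3):
    # Hypothetical single-simulation initial state with a small random offset.
    qpos = 0.01 * jax.random.normal(key, (nq,))
    qvel = jnp.zeros(nq)
    return {"qpos": qpos, "qvel": qvel}

# One independent key per simulation.
keys = jax.random.split(jax.random.PRNGKey(0), 1024)

# vmap over the keys: each simulation gets its own perturbed state.
state_batch = jax.vmap(make_state)(keys)
```

Each entry of state_batch has a leading axis of size 1024, so it can be passed to a step function vectorized with in_axes=(None, 0).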

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
