Implementation:Google deepmind Mujoco MJX Scan
| Knowledge Sources | |
|---|---|
| Domains | Physics_Simulation, JAX, Tree_Traversal |
| Last Updated | 2026-02-15 04:00 GMT |
Overview
Provides scan (fold/map) operations that visit a model's bodies either grouped by joint type or in kinematic-tree order, enabling efficient, JAX-compatible traversal of the articulated body tree.
Description
This module implements two main scan patterns. flat() iterates over bodies grouped by their joint-type configuration, and body_tree() traverses the kinematic tree in parent-to-child order while threading a carry state. flat() groups bodies that share the same joint-type signature, applies a user function once per group, and concatenates the results. body_tree() performs a tree traversal in which each body receives the carry output of its parent; with reverse=True the traversal instead runs from the leaves toward the root. Both functions use optimized index slicing (contiguous integer indices are converted to slices for better XLA performance) and accept input/output type specifications as short string codes (e.g., 'q' for qpos-shaped, 'v' for dof-shaped, 'j' for joint-shaped arrays).
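The grouping idea behind flat() can be illustrated with a minimal pure-Python sketch. It assumes each body is described by a tuple of joint-type names (its "signature"); the names flat_sketch and scale_by_joint_count and this data layout are hypothetical, not the MJX API. Bodies that share a signature are handed to fn as one batch, which is what allows a single same-shaped computation per group, and the per-group results are then reassembled into one output indexed by body.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical body description: the tuple of joint types attached to a body.
Signature = Tuple[str, ...]


def flat_sketch(
    signatures: Sequence[Signature],
    values: Sequence[float],
    fn: Callable[[Signature, List[float]], List[float]],
) -> List[float]:
    """Call fn once per joint-type group, then restore body order.

    Simplified model of flat(): bodies with identical signatures form one
    group, so fn always sees a batch of same-shaped inputs.
    """
    groups: Dict[Signature, List[int]] = defaultdict(list)
    for body_id, sig in enumerate(signatures):
        groups[sig].append(body_id)

    out: List[float] = [0.0] * len(values)
    for sig, body_ids in groups.items():
        batch = [values[i] for i in body_ids]   # gather this group's inputs
        results = fn(sig, batch)                 # one call per group
        for i, r in zip(body_ids, results):      # write back by body id
            out[i] = r
    return out


def scale_by_joint_count(sig: Signature, batch: List[float]) -> List[float]:
    # Example per-group function: scale each value by the group's joint count.
    return [x * len(sig) for x in batch]


sigs = [(), ("hinge",), ("hinge", "slide"), ("hinge",)]
print(flat_sketch(sigs, [1.0, 2.0, 3.0, 4.0], scale_by_joint_count))
# [0.0, 2.0, 6.0, 4.0]
```

Bodies 1 and 3 share the ("hinge",) signature, so they are processed in the same call to fn.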
Usage
Used extensively by smooth.py (kinematics, com_pos, crb, rne) and passive.py to iterate over the kinematic tree during forward dynamics computation within mjx.step().
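The two traversal directions used by these callers can be sketched sequentially in pure Python. This assumes MuJoCo's body ordering, where body 0 is the world and every other body's parent index is smaller than its own. body_tree_forward and body_tree_reverse are hypothetical helpers: the forward pass composes 1-D offsets down the tree (a stand-in for kinematics), and the reverse pass sums values up the tree (a stand-in for subtree totals such as the mass needed by com_pos). The loops only illustrate the data dependencies, not how MJX batches the work for XLA.

```python
from typing import Callable, List, Optional, Sequence


def body_tree_forward(
    parent: Sequence[int],
    local: Sequence[float],
    fn: Callable[[Optional[float], float], float],
) -> List[float]:
    """Parent-to-child scan: fn receives the parent's output as carry.

    Assumes parent[b] < b for every body b > 0 (body 0 is the world).
    """
    out: List[float] = []
    for b, x in enumerate(local):
        carry = None if b == 0 else out[parent[b]]  # world has no parent carry
        out.append(fn(carry, x))
    return out


def body_tree_reverse(parent: Sequence[int], local: Sequence[float]) -> List[float]:
    """Child-to-parent scan: accumulate each body's subtree total."""
    total = list(local)
    for b in range(len(local) - 1, 0, -1):  # children before parents, world last
        total[parent[b]] += total[b]
    return total


def compose(carry: Optional[float], x: float) -> float:
    # Global position = parent's global position + local offset.
    return x if carry is None else carry + x


parent = [0, 0, 1, 1]  # world, torso, two limbs attached to the torso
print(body_tree_forward(parent, [0.0, 1.0, 0.5, -0.5], compose))
# [0.0, 1.0, 1.5, 0.5]
print(body_tree_reverse(parent, [0.0, 2.0, 1.0, 1.0]))
# [4.0, 4.0, 1.0, 1.0]
```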
Code Reference
Source Location
- Repository: Google_deepmind_Mujoco
- File: mjx/mujoco/mjx/_src/scan.py
- Lines: 1-497
Key Functions
```python
def _take(obj: Y, idx: np.ndarray) -> Y
def _q_bodyid(m: Model) -> np.ndarray
def _q_jointid(m: Model) -> np.ndarray
def flat(m: Model, fn: Callable, in_types: str, out_type: str, *args) -> Y
def body_tree(m: Model, fn: Callable, in_types: str, out_type: str, *args, reverse: bool = False) -> Y
```
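The slicing optimization attributed to _take can be modeled on plain Python lists. The helpers as_slice_if_contiguous and take below are hypothetical: they detect an index run such as [3, 4, 5] and replace it with a basic slice, falling back to element-wise indexing otherwise. In JAX, indexing with a slice can lower to a simple static slice instead of a general gather, which is the motivation the description gives for this conversion.

```python
from typing import List, Sequence, Union


def as_slice_if_contiguous(idx: Sequence[int]) -> Union[slice, List[int]]:
    """Return slice(start, stop) for a run like [3, 4, 5]; otherwise the indices."""
    idx = list(idx)
    if idx and all(b - a == 1 for a, b in zip(idx, idx[1:])):
        return slice(idx[0], idx[-1] + 1)
    return idx


def take(values: Sequence[float], idx: Sequence[int]) -> List[float]:
    """Select values[idx], using slicing when the indices are contiguous."""
    sel = as_slice_if_contiguous(idx)
    if isinstance(sel, slice):
        return list(values[sel])   # contiguous: one basic slice
    return [values[i] for i in sel]  # scattered: element-wise gather


print(as_slice_if_contiguous([3, 4, 5]))  # slice(3, 6, None)
print(take([10, 11, 12, 13], [3, 0]))     # [13, 10]
```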
Import
```python
from mujoco.mjx._src.scan import flat
from mujoco.mjx._src.scan import body_tree
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| m | mjx.Model | Yes | JAX model defining the kinematic tree and joint structure |
| fn | Callable | Yes | Function applied to each joint-type group (flat) or to each body together with the carry from its parent (body_tree) |
| in_types | str | Yes | String specifying input array types ('q'=qpos, 'v'=dof, 'j'=joint, 'b'=body, etc.) |
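| out_type | str | Yes | Type code(s) of the output(s), using the same codes as in_types |
| *args | jax.Array | Yes | Input arrays, one per character of in_types |
| reverse | bool | No | body_tree only: if True, traverse from the leaves to the root (default False) |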
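The 'q', 'v' and 'j' codes differ in how many entries each joint owns. The sketch below uses MuJoCo's standard per-joint sizes (free: 7 qpos / 6 dofs, ball: 4 / 3, slide and hinge: 1 / 1) to split a flat array per joint and to build a per-qpos joint-id map in the spirit of _q_jointid. The helper names and the string joint-type labels are hypothetical; MJX represents joint types with an enum, and its internal splitting logic may differ.

```python
from typing import Dict, List, Sequence

# MuJoCo's per-joint widths of qpos ('q') and dof ('v') arrays.
QPOS_WIDTH: Dict[str, int] = {"free": 7, "ball": 4, "slide": 1, "hinge": 1}
DOF_WIDTH: Dict[str, int] = {"free": 6, "ball": 3, "slide": 1, "hinge": 1}


def widths_for(code: str, jnt_types: Sequence[str]) -> List[int]:
    """Number of array entries each joint owns under a given type code."""
    if code == "q":
        return [QPOS_WIDTH[t] for t in jnt_types]
    if code == "v":
        return [DOF_WIDTH[t] for t in jnt_types]
    if code == "j":
        return [1] * len(jnt_types)  # one entry per joint
    raise ValueError(f"unsupported type code: {code!r}")


def split_by_type(flat: Sequence[float], jnt_types: Sequence[str], code: str) -> List[List[float]]:
    """Split a flat array into per-joint chunks according to a type code."""
    chunks: List[List[float]] = []
    start = 0
    for w in widths_for(code, jnt_types):
        chunks.append(list(flat[start:start + w]))
        start += w
    if start != len(flat):
        raise ValueError("array length does not match the joint layout")
    return chunks


def q_jointid(jnt_types: Sequence[str]) -> List[int]:
    """Joint id owning each qpos entry."""
    return [j for j, t in enumerate(jnt_types) for _ in range(QPOS_WIDTH[t])]


types = ["free", "hinge", "ball"]
print(q_jointid(types))  # [0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2]
print(split_by_type(list(range(10)), types, "v"))  # [[0, 1, 2, 3, 4, 5], [6], [7, 8, 9]]
```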
Outputs
| Name | Type | Description |
|---|---|---|
| result | jax.Array or pytree | Concatenated output from applying fn across all body groups or tree nodes |