

Implementation:Google deepmind Mujoco MJX Scan

From Leeroopedia
Knowledge Sources
Domains Physics_Simulation, JAX, Tree_Traversal
Last Updated 2026-02-15 04:00 GMT

Overview

Provides scan operations (map and fold) over bodies grouped by joint type and over the kinematic tree, enabling efficient, JAX-compatible traversal of the model's articulated body tree.

Description

This module implements two scan patterns.

flat() iterates over bodies grouped by joint-type configuration. Bodies that share the same joint type signature form a group; the user function is applied to each group, and the per-group results are concatenated.

body_tree() traverses the kinematic tree in parent-to-child order with a carry: each body receives the carry output of its parent. The reverse flag in its signature runs the traversal in the opposite direction (leaves to root).

Both functions optimize index selection: when a set of integer indices is contiguous, it is converted into a slice, which performs better under XLA than indexing with an integer array. Input and output types are declared with short string codes, one character per array (e.g. 'q' for qpos-shaped, 'v' for dof-shaped, 'j' for joint-shaped).
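The slice optimization can be illustrated in plain Python. This is a minimal sketch, not the module's _take: the helper names indices_to_slice and take are hypothetical, and Python lists stand in for JAX arrays. When the requested indices form one contiguous ascending run, the selection becomes a single slice; otherwise it falls back to an element-wise gather.

```python
from typing import List, Sequence, Union


def indices_to_slice(idx: Sequence[int]) -> Union[slice, List[int]]:
    """Return a slice if idx is one contiguous ascending run, else the index list."""
    idx = list(idx)
    if idx and all(b - a == 1 for a, b in zip(idx, idx[1:])):
        return slice(idx[0], idx[-1] + 1)
    return idx


def take(values: Sequence[float], idx: Sequence[int]) -> List[float]:
    """Select values[idx], using a single slice when the indices are contiguous."""
    sel = indices_to_slice(idx)
    if isinstance(sel, slice):
        return list(values[sel])  # contiguous: one slice
    return [values[i] for i in sel]  # non-contiguous: element-wise gather
```

In JAX, indexing with a static Python slice can generally be lowered to an XLA slice, whereas indexing with an integer array becomes a gather; that is the distinction this kind of conversion exploits.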

Usage

Used extensively by smooth.py (kinematics, com_pos, crb, rne) and passive.py to iterate over the kinematic tree during forward dynamics computation within mjx.step().
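The carry pattern behind body_tree can be sketched for the translation-only part of kinematics. This is a sequential, pure-Python illustration under stated assumptions: body 0 is the world body, parent[i] < i for every other body so parents are visited before children, and parent, local_pos and mass are hypothetical stand-ins for model fields. tree_scan passes each parent's world position down to its children; subtree_sum runs leaves to root, which is our assumption of the kind of pass reverse=True enables (for example, accumulating subtree mass). It does not reproduce how MJX batches bodies for JAX.

```python
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def tree_scan(parent: Sequence[int], local_pos: Sequence[Vec3]) -> List[Vec3]:
    """Root-to-leaves pass: world position = parent's world position + local offset."""
    world: List[Vec3] = [(0.0, 0.0, 0.0)] * len(parent)
    for i in range(1, len(parent)):  # body 0 (world) stays at the origin
        px, py, pz = world[parent[i]]  # carry produced by the parent body
        lx, ly, lz = local_pos[i]
        world[i] = (px + lx, py + ly, pz + lz)
    return world


def subtree_sum(parent: Sequence[int], mass: Sequence[float]) -> List[float]:
    """Leaves-to-root pass: each body's total includes all of its descendants."""
    total = list(mass)
    for i in range(len(parent) - 1, 0, -1):  # visit children before parents
        total[parent[i]] += total[i]
    return total
```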

Code Reference

Source Location

Key Functions

def _take(obj: Y, idx: np.ndarray) -> Y
def _q_bodyid(m: Model) -> np.ndarray
def _q_jointid(m: Model) -> np.ndarray
def flat(m: Model, fn: Callable, in_types: str, out_type: str, *args) -> Y
def body_tree(m: Model, fn: Callable, in_types: str, out_type: str, *args, reverse: bool = False) -> Y
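Judging by their names, _q_bodyid and _q_jointid map each qpos entry to the body or joint that owns it. The sketch below shows one way to build such maps; it is an illustration, not the module's code. It assumes MuJoCo's joint type numbering (free = 0, ball = 1, slide = 2, hinge = 3) with qpos widths of 7, 4, 1 and 1, and assumes qpos blocks appear in joint order. The names q_jointid, q_bodyid and QPOS_WIDTH are hypothetical, and jnt_type and jnt_bodyid are passed as plain lists rather than read from a Model.

```python
from typing import Dict, List, Sequence

# Assumed qpos width per MuJoCo joint type: free=0, ball=1, slide=2, hinge=3.
QPOS_WIDTH: Dict[int, int] = {0: 7, 1: 4, 2: 1, 3: 1}


def q_jointid(jnt_type: Sequence[int]) -> List[int]:
    """Joint id owning each qpos entry, assuming joints' qpos blocks are in joint order."""
    out: List[int] = []
    for j, t in enumerate(jnt_type):
        out.extend([j] * QPOS_WIDTH[t])
    return out


def q_bodyid(jnt_type: Sequence[int], jnt_bodyid: Sequence[int]) -> List[int]:
    """Body id owning each qpos entry, found through the joint that owns it."""
    return [jnt_bodyid[j] for j in q_jointid(jnt_type)]
```

For a free joint on body 1 followed by hinges on bodies 2 and 3, the first seven qpos entries map to body 1 and the last two to bodies 2 and 3.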

Import

from mujoco.mjx._src.scan import flat
from mujoco.mjx._src.scan import body_tree

I/O Contract

Inputs

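| Name | Type | Required | Description |
|---|---|---|---|
| m | mjx.Model | Yes | JAX model defining the kinematic tree and joint structure |
| fn | Callable | Yes | User function applied to each body group (flat) or tree node (body_tree) |
| in_types | str | Yes | One type code per input array: 'q' = qpos, 'v' = dof, 'j' = joint, 'b' = body, etc. |
| out_type | str | Yes | Type code(s) describing the output of fn |
| *args | jax.Array | Yes | Input arrays, one per character of in_types |
| reverse | bool | No | body_tree only: traverse in reverse (leaves-to-root) order; defaults to False |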
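The one-code-per-argument contract can be illustrated by selecting, for a single body, the rows of each argument that belong to it. In this hypothetical sketch, owners maps each type code to a list giving the owning body id of every row (for 'b' that is simply the row index); owners, rows_for_body and split_args are illustrative stand-ins, not part of the module.

```python
from typing import Dict, List, Sequence


def rows_for_body(owner: Sequence[int], body: int) -> List[int]:
    """Indices of the rows owned by the given body."""
    return [i for i, b in enumerate(owner) if b == body]


def split_args(
    in_types: str,
    args: Sequence[Sequence[float]],
    owners: Dict[str, Sequence[int]],
    body: int,
) -> List[List[float]]:
    """For one body, select the rows of each arg according to its type code."""
    if len(in_types) != len(args):
        raise ValueError("need exactly one type code per argument")
    out: List[List[float]] = []
    for code, arg in zip(in_types, args):
        rows = rows_for_body(owners[code], body)
        out.append([arg[i] for i in rows])
    return out
```

With a free joint on body 1 and a hinge on body 2, split_args("qv", [qpos, qvel], owners, 2) picks the single qpos entry and the single dof entry of the hinge.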

Outputs

| Name | Type | Description |
|---|---|---|
| result | jax.Array or pytree | Concatenated output of fn across all body groups (flat) or tree nodes (body_tree) |
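The group-then-concatenate behavior of flat can be sketched with bodies described by tuples of joint types. This is a simplified illustration, not the module's implementation: group_by_signature, flat_sketch and the per-body float inputs are hypothetical, and fn receives a whole group at once. Our reading is that grouping bodies with identical joint types gives every member the same input shapes, which is what lets a JAX implementation batch each group; the source does not state this explicitly. In this sketch, outputs are concatenated in group order.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Signature = Tuple[int, ...]  # joint types of one body, in order


def group_by_signature(body_jnt_types: Sequence[Signature]) -> Dict[Signature, List[int]]:
    """Group body ids that share the same joint-type signature (first-seen order)."""
    groups: Dict[Signature, List[int]] = {}
    for body, sig in enumerate(body_jnt_types):
        groups.setdefault(sig, []).append(body)
    return groups


def flat_sketch(
    body_jnt_types: Sequence[Signature],
    fn: Callable[[Signature, List[float]], List[float]],
    per_body: Sequence[float],
) -> List[float]:
    """Call fn once per group with that group's inputs and concatenate the results."""
    out: List[float] = []
    for sig, bodies in group_by_signature(body_jnt_types).items():
        out.extend(fn(sig, [per_body[b] for b in bodies]))
    return out
```

For joint types [(3,), (0,), (3,)], bodies 0 and 2 form one group and body 1 another, so the concatenated output lists body 0, body 2, then body 1.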
