Implementation:Google deepmind Mujoco MJX Dataclasses

From Leeroopedia
Knowledge Sources
Domains: Physics_Simulation, JAX, Data_Structures
Last Updated: 2026-02-15 04:00 GMT

Overview

Wrapper that automatically registers Python dataclasses as JAX PyTree nodes, providing the foundation for all MJX data structures to be compatible with JAX transformations.
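Below is a minimal sketch of that idea, not MJX's implementation. It uses the long-standing `jax.tree_util.register_pytree_node` API, and the names `pytree_dataclass`, `_is_array_field`, and `State` are hypothetical. Fields annotated as `jax.Array` are flattened into leaves, and every other field is carried in the static auxiliary data.

```python
import dataclasses
import typing

import jax
import jax.numpy as jnp


def _is_array_field(typ) -> bool:
    # Hypothetical helper: a field is a PyTree child if its annotation is
    # jax.Array or wraps it, e.g. Optional[jax.Array].
    if typ is jax.Array:
        return True
    return any(_is_array_field(t) for t in typing.get_args(typ))


def pytree_dataclass(cls):
    """Hypothetical stand-in for dataclass(): frozen dataclass + PyTree registration."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    data = [f.name for f in fields if _is_array_field(f.type)]
    meta = [f.name for f in fields if not _is_array_field(f.type)]

    def flatten(obj):
        children = tuple(getattr(obj, n) for n in data)  # traced by JAX
        aux = tuple(getattr(obj, n) for n in meta)  # static, stored in the treedef
        return children, aux

    def unflatten(aux, children):
        return cls(**dict(zip(meta, aux)), **dict(zip(data, children)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class State:
    qpos: jax.Array
    qvel: jax.Array
    nq: int  # not a jax.Array, so it is static metadata


state = State(qpos=jnp.zeros(3), qvel=jnp.ones(3), nq=3)
leaves = jax.tree_util.tree_leaves(state)  # only qpos and qvel


@jax.jit
def step(s: State) -> State:
    # s.nq is a plain Python int inside jit; qpos and qvel are tracers.
    return dataclasses.replace(s, qpos=s.qpos + s.qvel)
```

Because `nq` lives in the treedef, calling `step` with a different `nq` would trigger a retrace rather than being traced as a value.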

Description

This module provides the PyTreeNode base class and the dataclass() decorator, which turn ordinary Python classes into frozen dataclasses registered as JAX PyTree nodes. Fields annotated as jax.Array become dynamic PyTree children that JAX traces and transforms. All other fields become static metadata stored in the tree structure (the treedef), so they take part in treedef equality and hashing, for example when jax.jit looks up a cached compilation. The _tree_replace() utility performs functional updates of nested fields addressed by dot-separated path strings. The design is inspired by Flax's struct.dataclass, but it decides PyTree membership from type annotations rather than from field descriptors. Because static metadata must be hashable, NumPy arrays in metadata fields are serialized to raw bytes.
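Why raw bytes: a value placed in static metadata must be hashable and must compare equal or unequal with a single boolean, and a NumPy array is neither hashable nor does its `==` return a single bool. The sketch below shows one possible encoding. The exact format MJX uses is not described on this page, so the `(bytes, dtype, shape)` key and the `freeze_array`/`thaw_array` helpers are assumptions.

```python
import numpy as np


def freeze_array(arr: np.ndarray):
    """Hypothetical helper: encode an ndarray as a hashable, comparable key."""
    return (arr.tobytes(), arr.dtype.str, arr.shape)


def thaw_array(key) -> np.ndarray:
    """Hypothetical inverse of freeze_array: rebuild the array from raw bytes."""
    raw, dtype, shape = key
    return np.frombuffer(raw, dtype=np.dtype(dtype)).reshape(shape)


geom_type = np.array([0, 2, 6], dtype=np.int32)  # toy metadata array

try:
    hash(geom_type)
    ndarray_hashable = True
except TypeError:
    ndarray_hashable = False  # ndarray cannot be hashed directly

key = freeze_array(geom_type)
same = freeze_array(np.array([0, 2, 6], dtype=np.int32))
restored = thaw_array(key)  # read-only view over the stored bytes
```

Two arrays with identical contents, dtype, and shape produce equal keys with equal hashes, which is exactly what treedef comparison needs.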

Usage

Used as the base class for all MJX data structures (Model, Data, Contact, Option, Context, etc.). Any MJX type that needs to pass through jax.jit, jax.vmap, or jax.grad inherits from PyTreeNode.
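A base class can apply the decorator automatically through `__init_subclass__`, so subclassing is enough to get a frozen, registered dataclass. The sketch below assumes that pattern (Flax's `struct.PyTreeNode` works this way). `Node` and `Body` are hypothetical names, and only fields annotated exactly as `jax.Array` are treated as data here. The example then pushes a batch through `jax.jit` and `jax.vmap`.

```python
import dataclasses

import jax
import jax.numpy as jnp


class Node:
    """Hypothetical PyTreeNode-like base: subclasses become frozen, registered dataclasses."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        dataclasses.dataclass(frozen=True)(cls)
        fields = dataclasses.fields(cls)
        data = [f.name for f in fields if f.type is jax.Array]
        meta = [f.name for f in fields if f.type is not jax.Array]
        jax.tree_util.register_pytree_node(
            cls,
            # flatten: array fields are children, the rest is static aux data
            lambda o: (tuple(getattr(o, n) for n in data),
                       tuple(getattr(o, n) for n in meta)),
            # unflatten: rebuild an instance from aux data and children
            lambda aux, ch: cls(**dict(zip(meta, aux)), **dict(zip(data, ch))),
        )

    def replace(self, **kwargs):
        # Functional update: return a new instance instead of mutating self.
        return dataclasses.replace(self, **kwargs)


class Body(Node):
    pos: jax.Array
    vel: jax.Array
    dt: float = 0.01  # static metadata, not batched by vmap


def integrate(b: Body) -> Body:
    return b.replace(pos=b.pos + b.dt * b.vel)


batch = Body(pos=jnp.zeros((4, 3)), vel=jnp.ones((4, 3)))
out = jax.jit(jax.vmap(integrate))(batch)  # maps over the leading axis of pos and vel
```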

Code Reference

Source Location

Key Functions

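from typing import TypeVar

_T = TypeVar('_T')  # generic class type returned unchanged by dataclass()

def _jax_in_args(typ) -> bool: ...  # does a field annotation involve jax.Array?
def dataclass(clz: _T, register_as_pytree: bool) -> _T: ...  # frozen dataclass, optionally a PyTree

class PyTreeNode:
    """Base class for MJX dataclasses registered as JAX PyTrees."""
    def replace(self, **kwargs) -> 'PyTreeNode': ...  # copy with top-level fields replaced
    def tree_replace(self, params: dict) -> 'PyTreeNode': ...  # copy with nested fields replaced by dot path

def _tree_replace(base, params: dict): ...  # functional nested-field update by dot-separated path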
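The dot-path update performed by `_tree_replace` can be sketched with the standard library alone. This is an illustration under assumptions, not MJX's code: `tree_replace_sketch` is a hypothetical name, and `Option`/`Model` are toy stand-ins rather than the MJX types. Each path is split at its first dot, and the function recurses into the named field before rebuilding the parent with `dataclasses.replace`.

```python
import dataclasses
from typing import Any, Dict


def tree_replace_sketch(base: Any, params: Dict[str, Any]) -> Any:
    """Hypothetical sketch: functionally update nested dataclass fields by dotted path."""
    direct: Dict[str, Any] = {}
    nested: Dict[str, Dict[str, Any]] = {}
    for path, value in params.items():
        head, _, rest = path.partition('.')
        if rest:
            # 'opt.timestep' -> recurse into field 'opt' with {'timestep': value}
            nested.setdefault(head, {})[rest] = value
        else:
            direct[head] = value
    for head, sub_params in nested.items():
        # Start from a directly replaced value if one was given, else the current one.
        child = direct.get(head, getattr(base, head))
        direct[head] = tree_replace_sketch(child, sub_params)
    # dataclasses.replace builds a new instance; base itself is never mutated.
    return dataclasses.replace(base, **direct)


@dataclasses.dataclass(frozen=True)
class Option:  # toy stand-in, not mjx.Option
    timestep: float = 0.002
    iterations: int = 100


@dataclasses.dataclass(frozen=True)
class Model:  # toy stand-in, not mjx.Model
    opt: Option = dataclasses.field(default_factory=Option)
    nq: int = 7


model = Model()
updated = tree_replace_sketch(model, {'opt.timestep': 0.001, 'nq': 9})
```

Because every level is rebuilt with `dataclasses.replace`, untouched fields (here `opt.iterations`) are carried over by reference and the original `model` is left unchanged.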

Import

from mujoco.mjx._src.dataclasses import PyTreeNode
from mujoco.mjx._src.dataclasses import dataclass

I/O Contract

Inputs

Name                 Type   Required   Description
clz                  type   Yes        Python class to wrap as a JAX-compatible dataclass
register_as_pytree   bool   Yes        Whether to register the class with JAX's PyTree system

Outputs

Name       Type   Description
data_clz   type   Frozen dataclass registered as a JAX PyTree node, with replace() and tree_replace() methods
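The "frozen" part of this contract means instances cannot be mutated in place; updates go through `replace()`, which returns a new object, matching JAX's assumption that values are immutable. The toy `Contact` class below is hypothetical (not MJX's Contact) and hand-writes a `replace()` method to show the behavior with the standard library only.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Contact:  # toy stand-in, not mjx.Contact
    dist: float
    geom: tuple

    def replace(self, **kwargs) -> "Contact":
        # Return a modified copy; the original instance stays untouched.
        return dataclasses.replace(self, **kwargs)


c = Contact(dist=0.1, geom=(0, 1))
try:
    c.dist = 0.0  # frozen dataclasses reject attribute assignment
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

c2 = c.replace(dist=0.0)
```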
