Implementation:Google deepmind Mujoco MJX Dataclasses
| Knowledge Sources | |
|---|---|
| Domains | Physics_Simulation, JAX, Data_Structures |
| Last Updated | 2026-02-15 04:00 GMT |
Overview
Wrapper that automatically registers Python dataclasses as JAX PyTree nodes. It is the foundation that makes MJX data structures compatible with JAX transformations such as jax.jit, jax.vmap, and jax.grad.
Description
This module provides the PyTreeNode base class and the dataclass() decorator, which convert standard Python dataclasses into JAX-compatible PyTree nodes. Fields whose type annotation is, or contains, jax.Array (as checked by _jax_in_args()) are treated as dynamic PyTree children that JAX can trace. All other fields are treated as static metadata: they become part of the tree structure (the treedef), which JAX compares and hashes, so they must be hashable. Because np.ndarray is not hashable, NumPy arrays in metadata fields are serialized to raw bytes. The _tree_replace() utility, exposed through PyTreeNode.tree_replace(), performs functional updates of nested fields addressed by dot-separated path strings. The design is inspired by Flax's struct.dataclass, but PyTree membership is determined from type annotations rather than from field descriptors.
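A minimal sketch of the annotation-based split, assuming a JAX version that provides jax.tree_util.register_dataclass. The names jax_in_args, pytree_dataclass, State, and total are hypothetical: jax_in_args only mirrors the idea behind _jax_in_args (recursing into type arguments so that Optional[jax.Array] still counts as dynamic) and is not copied from the MJX source.

```python
import dataclasses
import typing
from typing import Optional

import jax
import jax.numpy as jnp


def jax_in_args(typ) -> bool:
  """Hypothetical helper: True if `typ` is jax.Array or nests it in its type args."""
  if typ is jax.Array:
    return True
  # typing.get_args returns () for plain types such as int, ending the recursion.
  return any(jax_in_args(arg) for arg in typing.get_args(typ))


def pytree_dataclass(cls):
  """Hypothetical decorator: frozen dataclass whose PyTree split follows annotations."""
  data_cls = dataclasses.dataclass(frozen=True)(cls)
  data_fields, meta_fields = [], []
  for field in dataclasses.fields(data_cls):
    target = data_fields if jax_in_args(field.type) else meta_fields
    target.append(field.name)
  # Data fields become traced leaves; meta fields are stored in the treedef.
  jax.tree_util.register_dataclass(
      data_cls, data_fields=data_fields, meta_fields=meta_fields
  )
  return data_cls


@pytree_dataclass
class State:
  qpos: jax.Array            # dynamic child
  qvel: Optional[jax.Array]  # still dynamic: the Union contains jax.Array
  nq: int                    # static metadata


@jax.jit
def total(s: State) -> jax.Array:
  # s.nq is a concrete Python int here because metadata is never traced.
  return jnp.sum(s.qpos[: s.nq]) + jnp.sum(s.qvel)


state = State(qpos=jnp.zeros(3), qvel=jnp.ones(3), nq=3)
leaves, treedef = jax.tree_util.tree_flatten(state)
print(len(leaves))  # 2 (qpos and qvel)
print(total(state))
```

Because nq lives in the treedef rather than among the leaves, changing it yields a different tree structure, so a jitted function sees a new input signature and retraces.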
Usage
Used as the base class for all MJX data structures (Model, Data, Contact, Option, Context, etc.). Any MJX type that needs to pass through jax.jit, jax.vmap, or jax.grad inherits from PyTreeNode.
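To illustrate the subclassing pattern, the sketch below defines a hypothetical PyTreeNodeSketch base class that applies the dataclass-and-register step in __init_subclass__, one common way to make inheritance alone sufficient; whether PyTreeNode works exactly this way is an assumption here. For brevity it uses a simpler rule than _jax_in_args (only bare jax.Array annotations count as data). Body and weighted_norm_sq are illustrative names.

```python
import dataclasses

import jax
import jax.numpy as jnp


class PyTreeNodeSketch:
  """Hypothetical base class: each subclass becomes a frozen, registered PyTree."""

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    dataclasses.dataclass(frozen=True)(cls)  # decorates cls in place
    names = [f.name for f in dataclasses.fields(cls)]
    # Simplified rule for this sketch: only bare jax.Array annotations are data.
    data = [f.name for f in dataclasses.fields(cls) if f.type is jax.Array]
    meta = [n for n in names if n not in data]
    jax.tree_util.register_dataclass(cls, data_fields=data, meta_fields=meta)

  def replace(self, **updates):
    """Returns a copy with the given fields replaced."""
    return dataclasses.replace(self, **updates)


class Body(PyTreeNodeSketch):
  pos: jax.Array
  mass: jax.Array
  name: str  # static metadata, shared by every element of the batch


def weighted_norm_sq(b: Body) -> jax.Array:
  return 0.5 * b.mass * jnp.sum(b.pos ** 2)


# A batch of 4 bodies stored as one Body whose leaves carry a leading axis.
batch = Body(pos=jnp.ones((4, 3)), mass=jnp.arange(4.0), name="box")
values = jax.jit(jax.vmap(weighted_norm_sq))(batch)
# grad returns a Body holding gradients for pos and mass; name is carried over.
grads = jax.grad(lambda b: jnp.sum(jax.vmap(weighted_norm_sq)(b)))(batch)
```

jit, vmap, and grad all accept the Body instance directly because only its array fields are leaves; the string metadata rides along in the treedef.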
Code Reference
Source Location
- Repository: Google_deepmind_Mujoco
- File: mjx/mujoco/mjx/_src/dataclasses.py
- Lines: 1-181
Key Functions
def _jax_in_args(typ) -> bool
def dataclass(clz: _T, register_as_pytree: bool) -> _T
class PyTreeNode:
  """Base class for MJX dataclasses registered as JAX PyTrees."""
  def replace(self, **kwargs) -> 'PyTreeNode'
  def tree_replace(self, params: dict) -> 'PyTreeNode'
def _tree_replace(base, params: dict)
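The dot-separated update performed by tree_replace() can be sketched with plain frozen dataclasses. replace_path and tree_replace here are hypothetical helpers that recurse one path segment at a time and rebuild each parent with dataclasses.replace; they assume every segment names a dataclass field and do not attempt to reproduce any special cases the real _tree_replace may handle. Option and Model are toy classes, not the MJX types.

```python
import dataclasses
from typing import Any, Dict, Sequence


def replace_path(base: Any, attrs: Sequence[str], val: Any) -> Any:
  """Hypothetical helper: functionally sets base.<a0>.<a1>...<an> = val."""
  if len(attrs) == 1:
    return dataclasses.replace(base, **{attrs[0]: val})
  child = getattr(base, attrs[0])
  # Rebuild the child first, then rebuild this level around the new child.
  new_child = replace_path(child, attrs[1:], val)
  return dataclasses.replace(base, **{attrs[0]: new_child})


def tree_replace(base: Any, params: Dict[str, Any]) -> Any:
  """Hypothetical helper: applies each 'a.b.c' -> value update in turn."""
  new = base
  for path, val in params.items():
    new = replace_path(new, path.split("."), val)
  return new


@dataclasses.dataclass(frozen=True)
class Option:
  timestep: float
  iterations: int


@dataclasses.dataclass(frozen=True)
class Model:
  opt: Option
  nq: int


m = Model(opt=Option(timestep=0.002, iterations=100), nq=7)
m2 = tree_replace(m, {"opt.timestep": 0.001, "nq": 9})
```

Because every level is rebuilt rather than mutated, m is left unchanged, which is what lets such updates coexist with JAX's functional style.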
Import
from mujoco.mjx._src.dataclasses import PyTreeNode
from mujoco.mjx._src.dataclasses import dataclass
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| clz | type | Yes | Python class to wrap as a JAX-compatible dataclass |
| register_as_pytree | bool | Yes | Whether to register the class with JAX's PyTree system |
Outputs
| Name | Type | Description |
|---|---|---|
| data_clz | type | Frozen dataclass with a replace() method, registered as a JAX PyTree node when register_as_pytree is true; classes that inherit from PyTreeNode also provide tree_replace() |
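The output contract can be checked on a hand-registered stand-in. ContactLike is a hypothetical class registered directly with jax.tree_util.register_dataclass (assuming a JAX version that provides it), mimicking the frozen, PyTree-registered class the decorator is described as returning; it is not the MJX Contact type.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ContactLike:
  """Hypothetical stand-in for a class returned by dataclass()."""
  dist: jax.Array  # data field: visited by tree_map
  dim: int         # metadata field: carried through unchanged


jax.tree_util.register_dataclass(
    ContactLike, data_fields=["dist"], meta_fields=["dim"]
)

c = ContactLike(dist=jnp.array([0.1, -0.2]), dim=3)

try:
  c.dim = 4  # frozen dataclasses reject in-place assignment
  mutated = True
except dataclasses.FrozenInstanceError:
  mutated = False

# tree_map rebuilds a ContactLike, transforming only the array leaves.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, c)
```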