Implementation:Google deepmind Mujoco MJX Math
| Knowledge Sources | |
|---|---|
| Domains | Physics_Simulation, JAX, Linear_Algebra |
| Last Updated | 2026-02-15 04:00 GMT |
Overview
Utility math functions for MJX: gradient-safe vector normalization and division, quaternion operations, spatial algebra, and small-matrix arithmetic tuned for JAX.
Description
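This module supplies the foundational mathematical primitives used throughout MJX. It includes gradient-safe implementations of vector normalization and division that avoid NaN gradients when the input (or denominator) is zero; quaternion inversion, multiplication, subtraction, and integration; and conversions from quaternions to rotation matrices and between quaternions and axis-angle representations. Spatial algebra helpers such as transform_motion, motion_cross, and motion_cross_force support rigid-body dynamics computations, while orthogonals, make_frame, closest_segment_point, and closest_segment_to_segment_points construct orthogonal frames and find closest points on line segments. The matmul_unroll function computes products of small matrices (for example 3x3 or 4x4) with explicit element-wise operations, which is faster than XLA's general matmul at those sizes.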
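The zero-gradient problem and the usual fix can be sketched in plain JAX. This is a minimal, single-vector illustration of the technique, not the MJX source: the names naive_norm, safe_norm, safe_normalize, and safe_div, the 1e-6 normalization guard, and the 1e-15 division epsilon are all assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp


def naive_norm(x):
    # Plain Euclidean norm. At x == 0 the derivative of sqrt is infinite,
    # and the chain rule multiplies it by 2 * x == 0, giving NaN.
    return jnp.sqrt(jnp.sum(x * x))


def safe_norm(x):
    # Hypothetical gradient-safe norm: if x is (close to) all zeros, evaluate
    # the norm on a dummy vector of ones so the backward pass never reaches
    # sqrt(0), then overwrite the forward result with 0.
    is_zero = jnp.allclose(x, 0.0)
    x_safe = jnp.where(is_zero, jnp.ones_like(x), x)
    n = jnp.linalg.norm(x_safe)
    return jnp.where(is_zero, 0.0, n)


def safe_normalize(x):
    # Divide by the norm, nudging an exactly-zero norm by a small epsilon
    # (assumed value) so the zero vector maps to zero instead of NaN.
    n = safe_norm(x)
    return x / (n + 1e-6 * (n == 0.0))


def safe_div(num, den, eps=1e-15):
    # Hypothetical guarded division: only exact zeros in den are shifted.
    return num / (den + (den == 0) * eps)


x0 = jnp.zeros(3)
print(jax.grad(naive_norm)(x0))  # NaN in every entry
print(jax.grad(safe_norm)(x0))   # finite: all zeros
print(safe_normalize(x0))        # zeros, not NaN
```

Both tricks only change behavior at (or extremely close to) zero; away from zero the results match the plain formulas.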
Usage
Imported by nearly every other MJX module. These functions are called pervasively during kinematics, dynamics, collision detection, and constraint construction within the mjx.step() pipeline.
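Because these helpers are pure jax.numpy functions, they should compose with jax.jit and jax.vmap in the same way as the rest of the step pipeline. The standalone sketch below batches a quaternion integration step over several bodies. quat_mul, axis_angle_to_quat, and quat_integrate here are local reimplementations written for illustration (not imports from MJX), and the plain norm inside quat_integrate is a forward-only simplification, not the gradient-safe version.

```python
import jax
import jax.numpy as jnp


def quat_mul(u, v):
    # Hamilton product of scalar-first (w, x, y, z) quaternions.
    return jnp.array([
        u[0] * v[0] - u[1] * v[1] - u[2] * v[2] - u[3] * v[3],
        u[0] * v[1] + u[1] * v[0] + u[2] * v[3] - u[3] * v[2],
        u[0] * v[2] - u[1] * v[3] + u[2] * v[0] + u[3] * v[1],
        u[0] * v[3] + u[1] * v[2] - u[2] * v[1] + u[3] * v[0],
    ])


def axis_angle_to_quat(axis, angle):
    # Unit quaternion for a rotation of `angle` radians about a unit `axis`.
    s, c = jnp.sin(angle * 0.5), jnp.cos(angle * 0.5)
    return jnp.concatenate([jnp.reshape(c, (1,)), axis * s])


def quat_integrate(q, omega, dt):
    # Advance orientation q by angular velocity omega (rad/s), held constant
    # over dt. Right-multiplying treats omega as a body-frame velocity.
    speed = jnp.linalg.norm(omega)  # forward-only; not gradient-safe at 0
    axis = omega / jnp.where(speed > 0.0, speed, 1.0)
    q_new = quat_mul(q, axis_angle_to_quat(axis, speed * dt))
    return q_new / jnp.linalg.norm(q_new)  # renormalize against drift


# Map over a batch of (q, omega) pairs that share one dt, then compile.
batched_step = jax.jit(jax.vmap(quat_integrate, in_axes=(0, 0, None)))

qs = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), (4, 1))    # 4 identity quats
omegas = jnp.tile(jnp.array([0.0, 0.0, jnp.pi]), (4, 1))  # pi rad/s about z
for _ in range(10):
    qs = batched_step(qs, omegas, 0.1)
# After 1 s at pi rad/s each quaternion is close to (0, 0, 0, 1),
# i.e. a half turn about z.
```

Right-multiplication corresponds to an angular velocity expressed in the body's local frame; a world-frame angular velocity would be left-multiplied instead.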
Code Reference
Source Location
- Repository: Google_deepmind_Mujoco
- File: mjx/mujoco/mjx/_src/math.py
- Lines: 1-394
Key Functions
```python
def safe_div(num, den) -> Union[float, jax.Array]
def matmul_unroll(a: jax.Array, b: jax.Array) -> jax.Array
def norm(x: jax.Array, axis=None) -> jax.Array
def normalize_with_norm(x: jax.Array, axis=None) -> Tuple[jax.Array, jax.Array]
def normalize(x: jax.Array, axis=None) -> jax.Array
def rotate(vec: jax.Array, quat: jax.Array) -> jax.Array
def quat_inv(q: jp.ndarray) -> jp.ndarray
def quat_sub(u: jax.Array, v: jax.Array) -> jax.Array
def quat_mul(u: jax.Array, v: jax.Array) -> jax.Array
def quat_mul_axis(q: jax.Array, axis: jax.Array) -> jax.Array
def quat_to_mat(q: jax.Array) -> jax.Array
def quat_to_axis_angle(q: jax.Array) -> Tuple[jax.Array, jax.Array]
def axis_angle_to_quat(axis: jax.Array, angle: jax.Array) -> jax.Array
def quat_integrate(q: jax.Array, v: jax.Array, dt: jax.Array) -> jax.Array
def inert_mul(i: jax.Array, v: jax.Array) -> jax.Array
def sign(x: jax.Array) -> jax.Array
def transform_motion(vel: jax.Array, offset: jax.Array, rotmat: jax.Array)
def motion_cross(u, v)
def motion_cross_force(v, f)
def orthogonals(a: jax.Array) -> Tuple[jax.Array, jax.Array]
def make_frame(a: jax.Array) -> jax.Array
def closest_segment_point(a, b, pt) -> jax.Array
def closest_segment_to_segment_points(a0, a1, b0, b1) -> Tuple
```
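The three spatial algebra helpers can be sketched from the standard 6D motion and force cross-product formulas, assuming MuJoCo's layout of spatial vectors as (angular, linear). This is a standalone reconstruction for illustration; the argument conventions (offset as the displacement to the new reference point, rotmat mapping local to world coordinates) are assumptions rather than documented MJX behavior.

```python
import jax.numpy as jnp


def transform_motion(vel, offset, rotmat):
    # Move a spatial motion (angular, linear) to a reference point displaced
    # by `offset`, then express it in the frame whose axes are rotmat's columns.
    ang, lin = vel[:3], vel[3:]
    lin = rotmat.T @ (lin - jnp.cross(offset, ang))  # lin + ang x offset
    ang = rotmat.T @ ang
    return jnp.concatenate([ang, lin])


def motion_cross(u, v):
    # Spatial cross product of two motions: crm(u) @ v.
    ang = jnp.cross(u[:3], v[:3])
    lin = jnp.cross(u[3:], v[:3]) + jnp.cross(u[:3], v[3:])
    return jnp.concatenate([ang, lin])


def motion_cross_force(v, f):
    # Spatial cross product of a motion with a force: crf(v) @ f.
    ang = jnp.cross(v[:3], f[:3]) + jnp.cross(v[3:], f[3:])
    lin = jnp.cross(v[:3], f[3:])
    return jnp.concatenate([ang, lin])


# Spinning about z at the origin: the point at x = 1 moves along +y.
spin = jnp.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
print(transform_motion(spin, jnp.array([1.0, 0.0, 0.0]), jnp.eye(3)))
```

A useful consistency check is that crf(v) is the negative transpose of crm(v), so dot(motion_cross(v, m), f) equals -dot(m, motion_cross_force(v, f)) for any motion m and force f.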
Import
```python
from mujoco.mjx._src.math import rotate
from mujoco.mjx._src.math import quat_mul
from mujoco.mjx._src.math import normalize
```
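The `_src` path is private by Python convention, so the standalone sketch below reimplements the three imported functions locally instead of importing them. It is an illustration, not the MJX code, and its normalize is a plain (not gradient-safe) division by the norm. It checks a property any correct pair of rotate and quat_mul should satisfy: rotating by quat_mul(q1, q2) equals rotating by q2 and then by q1.

```python
import jax.numpy as jnp


def normalize(x):
    # Plain normalization (no zero guard), enough for nonzero quaternions.
    return x / jnp.linalg.norm(x)


def quat_mul(u, v):
    # Hamilton product of scalar-first (w, x, y, z) quaternions.
    return jnp.array([
        u[0] * v[0] - u[1] * v[1] - u[2] * v[2] - u[3] * v[3],
        u[0] * v[1] + u[1] * v[0] + u[2] * v[3] - u[3] * v[2],
        u[0] * v[2] - u[1] * v[3] + u[2] * v[0] + u[3] * v[1],
        u[0] * v[3] + u[1] * v[2] - u[2] * v[1] + u[3] * v[0],
    ])


def rotate(vec, quat):
    # Rotate a 3-vector by a unit quaternion without building a matrix:
    # r = (s^2 - u.u) vec + 2 (u.vec) u + 2 s (u x vec), s = w, u = (x, y, z).
    s, u = quat[0], quat[1:]
    r = (s * s - jnp.dot(u, u)) * vec + 2.0 * jnp.dot(u, vec) * u
    return r + 2.0 * s * jnp.cross(u, vec)


q1 = normalize(jnp.array([0.9, 0.1, -0.3, 0.2]))
q2 = normalize(jnp.array([0.5, 0.5, 0.5, -0.5]))
v = jnp.array([1.0, 2.0, 3.0])

composed = rotate(v, quat_mul(q1, q2))
sequential = rotate(rotate(v, q2), q1)
print(jnp.allclose(composed, sequential, atol=1e-5))  # True
```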
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
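| x / vec / q / u / v | jax.Array | Yes | Input arrays whose shape depends on the function: 3D vectors, quaternions of shape (4,) in MuJoCo's scalar-first (w, x, y, z) order, spatial vectors of shape (6,) with the angular part before the linear part, or small matrices |
| axis | int or tuple of ints | No | Axis or axes along which norm, normalize, and normalize_with_norm operate; None (the default) uses the whole array |
| num / den | float or jax.Array | Yes (safe_div) | Numerator and denominator of a division that is guarded against a zero denominator |
| angle / dt | jax.Array | Yes (axis_angle_to_quat / quat_integrate) | Rotation angle in radians; integration timestep |
| offset / rotmat | jax.Array | Yes (transform_motion) | Translation of shape (3,) and rotation matrix of shape (3, 3) |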
Outputs
| Name | Type | Description |
|---|---|---|
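| result | jax.Array or tuple of jax.Array | Computed result (normalized vector, rotated vector, quaternion product, etc.); normalize_with_norm, quat_to_axis_angle, orthogonals, and closest_segment_to_segment_points return tuples |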
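Several outputs are tuples rather than single arrays; the signatures above annotate normalize_with_norm, quat_to_axis_angle, orthogonals, and closest_segment_to_segment_points with Tuple. The standalone sketch below illustrates two of these contracts: quat_to_axis_angle returning an (axis, angle) pair, and a quaternion-derived 3x3 rotation matrix multiplied with an unrolled matmul. All four functions are local reimplementations under assumed conventions (scalar-first quaternions, angles in radians wrapped to (-pi, pi]), not the MJX source.

```python
import jax.numpy as jnp


def matmul_unroll(a, b):
    # a @ b written as explicit scalar multiply-adds. Loop bounds come from
    # static shapes, so tracing emits a fixed sequence of scalar ops.
    rows = []
    for i in range(a.shape[0]):
        row = []
        for j in range(b.shape[1]):
            acc = a[i, 0] * b[0, j]
            for k in range(1, a.shape[1]):
                acc = acc + a[i, k] * b[k, j]
            row.append(acc)
        rows.append(row)
    return jnp.array(rows)


def quat_to_mat(q):
    # Rotation matrix of a unit quaternion (w, x, y, z).
    w, x, y, z = q
    return jnp.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])


def axis_angle_to_quat(axis, angle):
    # Unit quaternion for a rotation of `angle` radians about a unit `axis`.
    s, c = jnp.sin(angle * 0.5), jnp.cos(angle * 0.5)
    return jnp.concatenate([jnp.reshape(c, (1,)), axis * s])


def quat_to_axis_angle(q):
    # Returns a tuple (axis, angle) with angle wrapped into (-pi, pi].
    sin_half = jnp.linalg.norm(q[1:])
    axis = q[1:] / jnp.where(sin_half > 0.0, sin_half, 1.0)
    angle = 2.0 * jnp.arctan2(sin_half, q[0])
    angle = jnp.where(angle > jnp.pi, angle - 2.0 * jnp.pi, angle)
    return axis, angle


axis = jnp.array([1.0, 2.0, 2.0]) / 3.0  # unit axis
q = axis_angle_to_quat(axis, 0.8)
rot = quat_to_mat(q)
print(matmul_unroll(rot, rot.T))  # approximately the 3x3 identity
axis_out, angle_out = quat_to_axis_angle(q)  # unpack the tuple output
```

The loops in matmul_unroll run at trace time, so the approach only makes sense for small, static shapes; for larger matrices a regular a @ b is likely the better choice.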