Implementation:Google deepmind Mujoco MJX Math

From Leeroopedia
Knowledge Sources
Domains Physics_Simulation, JAX, Linear_Algebra
Last Updated 2026-02-15 04:00 GMT

Overview

Utility math functions for MJX, optimized for JAX. The module covers gradient-safe vector normalization and division, quaternion operations, spatial algebra, and small-matrix arithmetic.
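The small-matrix arithmetic refers to matmul_unroll. The sketch below assumes, as the name suggests, that the product is unrolled into explicit per-element multiply-adds at trace time. `matmul_unroll_sketch` is a hypothetical name, not the MJX function, and no timing claim is made for this version.

```python
import jax
import jax.numpy as jp


def matmul_unroll_sketch(a: jax.Array, b: jax.Array) -> jax.Array:
    """Computes a @ b as explicit scalar multiply-adds (hypothetical helper).

    The Python loops run once, at trace time, so the compiled program holds a
    fixed list of scalar operations instead of a general dot op.
    """
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes {a.shape} and {b.shape}")
    rows = []
    for i in range(a.shape[0]):
        row = []
        for j in range(b.shape[1]):
            acc = a[i, 0] * b[0, j]
            for k in range(1, a.shape[1]):
                acc = acc + a[i, k] * b[k, j]
            row.append(acc)
        rows.append(jp.stack(row))
    return jp.stack(rows)


# Usage: 4x4 inputs, compared against the general matmul.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (4, 4))
b = jax.random.normal(key_b, (4, 4))
c = jax.jit(matmul_unroll_sketch)(a, b)
```

The trade-off is that shapes must be static and the traced graph grows with rows × cols × inner dimension. This approach therefore only makes sense at the 3x3 and 4x4 sizes the module targets.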

Description

This module supplies the foundational mathematical primitives used throughout MJX. It includes gradient-safe vector normalization and division, which avoid NaN gradients at zero. It also provides quaternion multiplication, quaternion integration, conversion from quaternions to rotation matrices, and conversion between quaternions and axis-angle representations. Spatial algebra helpers such as transform_motion, motion_cross, and motion_cross_force support rigid-body dynamics computations. The matmul_unroll function multiplies small matrices (e.g. 3x3 or 4x4) faster than XLA's general matmul.
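The gradient safety described here can be illustrated with the double-`where` pattern common in JAX. The pattern swaps out the problematic input before the sensitive operation, then restores the exact result afterwards. This is a minimal sketch under that assumption, not the MJX source. `safe_norm`, `safe_normalize_with_norm`, and the `_TINY` floor are hypothetical stand-ins, and the zero test here is done per slice along `axis`.

```python
import jax
import jax.numpy as jp

# Assumed floor for exact-zero denominators (similar in spirit to MuJoCo's mjMINVAL).
_TINY = 1e-15


def safe_div(num, den):
    """num / den, with an exactly-zero denominator replaced by _TINY."""
    return num / jp.where(den == 0.0, _TINY, den)


def _norm_keepdims(x, axis):
    # d/dx sqrt(sum(x**2)) = x / ||x|| evaluates 0/0 at the origin. Swap zero
    # slices for ones *before* the sqrt so neither branch of the final where
    # produces NaN, then restore the exact 0 result.
    is_zero = jp.all(x == 0.0, axis=axis, keepdims=True)
    x_safe = jp.where(is_zero, 1.0, x)
    n = jp.sqrt(jp.sum(x_safe * x_safe, axis=axis, keepdims=True))
    return jp.where(is_zero, 0.0, n)


def safe_norm(x, axis=None):
    """Euclidean norm with a zero (not NaN) gradient at x == 0."""
    return jp.squeeze(_norm_keepdims(x, axis), axis=axis)


def safe_normalize_with_norm(x, axis=None):
    """Returns (x / ||x||, ||x||); zero slices map to zero vectors."""
    n = _norm_keepdims(x, axis)
    unit = x / jp.where(n == 0.0, 1.0, n)
    return unit, jp.squeeze(n, axis=axis)


# Usage: the gradient at the origin is 0 rather than NaN.
grad_at_zero = jax.grad(safe_norm)(jp.zeros(3))
unit, lengths = safe_normalize_with_norm(
    jp.array([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]]), axis=-1)
```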

Usage

The module is imported by nearly every other MJX module. Its functions are called throughout kinematics, dynamics, collision detection, and constraint construction in the mjx.step() pipeline.
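Here is a minimal sketch of the quaternion primitives that kinematics relies on, assuming MuJoCo's [w, x, y, z] component order. In this sketch, composing a parent orientation with a child's local orientation, in the style of forward kinematics, is a Hamilton product. The `*_sketch` names are hypothetical and were written from the standard formulas, not copied from MJX.

```python
import jax
import jax.numpy as jp


def quat_mul_sketch(u: jax.Array, v: jax.Array) -> jax.Array:
    """Hamilton product u * v; quaternions stored as [w, x, y, z]."""
    return jp.stack([
        u[0] * v[0] - u[1] * v[1] - u[2] * v[2] - u[3] * v[3],
        u[0] * v[1] + u[1] * v[0] + u[2] * v[3] - u[3] * v[2],
        u[0] * v[2] - u[1] * v[3] + u[2] * v[0] + u[3] * v[1],
        u[0] * v[3] + u[1] * v[2] - u[2] * v[1] + u[3] * v[0],
    ])


def rotate_sketch(vec: jax.Array, quat: jax.Array) -> jax.Array:
    """Rotates a 3-vector by a unit quaternion without building a matrix."""
    s, u = quat[0], quat[1:]
    return (2.0 * jp.dot(u, vec) * u
            + (s * s - jp.dot(u, u)) * vec
            + 2.0 * s * jp.cross(u, vec))


def quat_to_mat_sketch(q: jax.Array) -> jax.Array:
    """3x3 rotation matrix of a unit quaternion [w, x, y, z]."""
    w, x, y, z = q
    return jp.stack([
        jp.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]),
        jp.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]),
        jp.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]),
    ])


# Usage: parent rotated 90 deg about z, child rotated 90 deg about x locally.
h = jp.sqrt(0.5)
q_parent = jp.array([h, 0.0, 0.0, h])
q_local = jp.array([h, h, 0.0, 0.0])
q_world = quat_mul_sketch(q_parent, q_local)  # child orientation in the world frame
v = jp.array([0.0, 1.0, 0.0])
v_world = rotate_sketch(v, q_world)
```

Rotating by the product equals rotating by the child first and then the parent. The matrix form agrees with the direct rotation formula.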

Code Reference

Source Location

Key Functions

def safe_div(num, den) -> Union[float, jax.Array]
def matmul_unroll(a: jax.Array, b: jax.Array) -> jax.Array
def norm(x: jax.Array, axis=None) -> jax.Array
def normalize_with_norm(x: jax.Array, axis=None) -> Tuple[jax.Array, jax.Array]
def normalize(x: jax.Array, axis=None) -> jax.Array
def rotate(vec: jax.Array, quat: jax.Array) -> jax.Array
def quat_inv(q: jp.ndarray) -> jp.ndarray
def quat_sub(u: jax.Array, v: jax.Array) -> jax.Array
def quat_mul(u: jax.Array, v: jax.Array) -> jax.Array
def quat_mul_axis(q: jax.Array, axis: jax.Array) -> jax.Array
def quat_to_mat(q: jax.Array) -> jax.Array
def quat_to_axis_angle(q: jax.Array) -> Tuple[jax.Array, jax.Array]
def axis_angle_to_quat(axis: jax.Array, angle: jax.Array) -> jax.Array
def quat_integrate(q: jax.Array, v: jax.Array, dt: jax.Array) -> jax.Array
def inert_mul(i: jax.Array, v: jax.Array) -> jax.Array
def sign(x: jax.Array) -> jax.Array
def transform_motion(vel: jax.Array, offset: jax.Array, rotmat: jax.Array)
def motion_cross(u, v)
def motion_cross_force(v, f)
def orthogonals(a: jax.Array) -> Tuple[jax.Array, jax.Array]
def make_frame(a: jax.Array) -> jax.Array
def closest_segment_point(a, b, pt) -> jax.Array
def closest_segment_to_segment_points(a0, a1, b0, b1) -> Tuple
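The spatial cross products listed above can be sketched with Featherstone's crm and crf operators. The sketch assumes 6D motion vectors are ordered [angular; linear], which is how MuJoCo lays out com-based velocities, and that forces are ordered [moment; force]. The function names are hypothetical, written from the textbook operators rather than from the MJX source.

```python
import jax
import jax.numpy as jp


def motion_cross_sketch(u: jax.Array, v: jax.Array) -> jax.Array:
    """Spatial cross product of two motion vectors (crm(u) @ v)."""
    u_ang, u_lin = u[:3], u[3:]
    v_ang, v_lin = v[:3], v[3:]
    ang = jp.cross(u_ang, v_ang)
    lin = jp.cross(u_lin, v_ang) + jp.cross(u_ang, v_lin)
    return jp.concatenate([ang, lin])


def motion_cross_force_sketch(v: jax.Array, f: jax.Array) -> jax.Array:
    """Spatial cross product of a motion with a force (crf(v) @ f)."""
    v_ang, v_lin = v[:3], v[3:]
    f_ang, f_lin = f[:3], f[3:]  # moment, linear force
    ang = jp.cross(v_ang, f_ang) + jp.cross(v_lin, f_lin)
    lin = jp.cross(v_ang, f_lin)
    return jp.concatenate([ang, lin])


# Usage: random motions/forces, plus the 6x6 operator matrices via jacfwd.
keys = jax.random.split(jax.random.PRNGKey(1), 3)
v, u, f = (jax.random.normal(k, (6,)) for k in keys)
crm = jax.jacfwd(lambda m: motion_cross_sketch(v, m))(u)        # 6x6 matrix crm(v)
crf = jax.jacfwd(lambda g: motion_cross_force_sketch(v, g))(f)  # 6x6 matrix crf(v)
```

The two operators are related by crf(v) = -crm(v)ᵀ. This duality keeps power (motion · force) consistent when velocities and forces are propagated together.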

Import

from mujoco.mjx._src.math import rotate
from mujoco.mjx._src.math import quat_mul
from mujoco.mjx._src.math import normalize

I/O Contract

Inputs

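Name | Type | Required | Description
x / vec / q / u / v | jax.Array | Yes | Input arrays (vectors, quaternions, or matrices), depending on the function
axis (norm, normalize, normalize_with_norm) | int or Tuple | No | Axis or axes along which to compute the norm or normalize. axis_angle_to_quat and quat_mul_axis take a different axis argument: a jax.Array rotation axis
dt / angle / offset / rotmat | jax.Array | Yes, where present | Function-specific arguments: time step (quat_integrate), rotation angle (axis_angle_to_quat), frame offset and rotation matrix (transform_motion)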
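A short sketch of what the axis argument changes. It assumes axis follows the jax.numpy.linalg.norm convention: None reduces over the whole array, and an int reduces along that dimension. `normalize_sketch` is hypothetical and omits the zero handling of the real function.

```python
import jax
import jax.numpy as jp


def normalize_sketch(x: jax.Array, axis=None) -> jax.Array:
    """Divides x by its norm over `axis` (hypothetical; no zero handling)."""
    return x / jp.linalg.norm(x, axis=axis, keepdims=True)


points = jp.array([[3.0, 4.0, 0.0],
                   [0.0, 0.0, 2.0]])

per_row = normalize_sketch(points, axis=-1)  # each row becomes a unit vector
mapped = jax.vmap(normalize_sketch)(points)  # same result by mapping over rows
whole = normalize_sketch(points)             # axis=None: one norm for the whole array
```

For a batch of vectors, axis=-1 and jax.vmap give the same per-row result. Leaving axis as None treats the whole batch as a single vector.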

Outputs

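Name | Type | Description
result | jax.Array, or a Tuple of jax.Array | Computed math result (normalized vector, rotated vector, quaternion product, etc.). normalize_with_norm, quat_to_axis_angle, orthogonals, and closest_segment_to_segment_points return tuples.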
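The tuple outputs can be seen in a round trip between quaternions and (axis, angle) pairs. The sketch also integrates a constant angular velocity and reads back the accumulated rotation. Several choices here are assumptions: angles are wrapped into (-pi, pi], omega is treated as a body-frame angular velocity, and the quaternion is renormalized after each step. All names are hypothetical, not the MJX functions.

```python
import jax
import jax.numpy as jp


def quat_mul_sketch(u, v):
    """Hamilton product for [w, x, y, z] quaternions."""
    return jp.stack([
        u[0] * v[0] - u[1] * v[1] - u[2] * v[2] - u[3] * v[3],
        u[0] * v[1] + u[1] * v[0] + u[2] * v[3] - u[3] * v[2],
        u[0] * v[2] - u[1] * v[3] + u[2] * v[0] + u[3] * v[1],
        u[0] * v[3] + u[1] * v[2] - u[2] * v[1] + u[3] * v[0],
    ])


def axis_angle_to_quat_sketch(axis, angle):
    """Unit quaternion for a rotation of `angle` radians about unit `axis`."""
    half = 0.5 * angle
    return jp.concatenate([jp.reshape(jp.cos(half), (1,)), jp.sin(half) * axis])


def quat_to_axis_angle_sketch(q):
    """Returns the (axis, angle) tuple for a unit quaternion."""
    sin_half = jp.linalg.norm(q[1:])
    axis = q[1:] / jp.where(sin_half == 0.0, 1.0, sin_half)
    angle = 2.0 * jp.arctan2(sin_half, q[0])
    angle = jp.where(angle > jp.pi, angle - 2.0 * jp.pi, angle)  # wrap to (-pi, pi]
    return axis, angle


def quat_integrate_sketch(q, omega, dt):
    """Advances q by body-frame angular velocity omega over dt."""
    speed = jp.linalg.norm(omega)
    axis = omega / jp.where(speed == 0.0, 1.0, speed)
    q_next = quat_mul_sketch(q, axis_angle_to_quat_sketch(axis, speed * dt))
    return q_next / jp.linalg.norm(q_next)  # guard against drift off the unit sphere


# Usage 1: round trip for a 1.2 rad rotation about a unit axis.
axis_in = jp.array([1.0, 2.0, 2.0]) / 3.0
axis_out, angle_out = quat_to_axis_angle_sketch(axis_angle_to_quat_sketch(axis_in, 1.2))

# Usage 2: spin at 0.5 rad/s about z for ten 0.1 s steps (0.5 rad total).
q = jp.array([1.0, 0.0, 0.0, 0.0])
for _ in range(10):
    q = quat_integrate_sketch(q, jp.array([0.0, 0.0, 0.5]), 0.1)
spin_axis, spin_angle = quat_to_axis_angle_sketch(q)
```

Right-multiplying by the increment is what makes omega a body-frame velocity. A world-frame velocity would left-multiply instead.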
