Implementation:Google deepmind Mujoco MJX Warp Collision Driver
| Knowledge Sources | |
|---|---|
| Domains | Physics Simulation, Collision Detection, GPU Computing, JAX-Warp Bridge |
| Last Updated | 2026-02-15 04:00 GMT |
Overview
Auto-generated bridge module that dispatches MuJoCo collision detection computations from JAX to NVIDIA Warp GPU kernels via the MJX FFI layer.
Description
collision_driver.py is an auto-generated shim between the JAX-based MJX physics pipeline and the MuJoCo Warp collision detection backend. It initializes placeholder instances of the key MuJoCo Warp dataclasses (Model, Data, Option, Statistic, Contact, Constraint) and defines _collision_shim, a function decorated with @ffi.format_args_for_warp. The shim accepts the model geometry parameters (geom types, sizes, AABBs, friction, margins, mesh data, heightfield data) together with the simulation data arrays needed for broad-phase and narrow-phase collision detection.
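Broad-phase collision detection prunes geom pairs whose bounding boxes cannot touch, leaving only candidate pairs for the more expensive narrow phase. The sketch below is a brute-force all-pairs AABB test in plain Python, shown only to illustrate the idea. It is not necessarily the broad-phase strategy MuJoCo Warp uses. It assumes each AABB is already in world coordinates as a center plus half-extents, and that a pair uses the larger of its two geom margins. The function names are hypothetical.

```python
from itertools import combinations

Vec3 = tuple[float, float, float]


def aabb_overlap(c1: Vec3, h1: Vec3, c2: Vec3, h2: Vec3, margin: float = 0.0) -> bool:
    """Return True if two center/half-extent AABBs, inflated by margin, overlap on every axis."""
    return all(
        abs(a - b) <= ha + hb + margin for a, b, ha, hb in zip(c1, c2, h1, h2)
    )


def broad_phase(
    centers: list[Vec3], halves: list[Vec3], margins: list[float]
) -> list[tuple[int, int]]:
    """Hypothetical O(n^2) broad phase: return candidate geom index pairs."""
    pairs = []
    for i, j in combinations(range(len(centers)), 2):
        # Assumption: the pair uses the larger of the two geom margins.
        margin = max(margins[i], margins[j])
        if aabb_overlap(centers[i], halves[i], centers[j], halves[j], margin):
            pairs.append((i, j))
    return pairs


# Two overlapping boxes (half-extent 1) and one distant box.
centers = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (10.0, 0.0, 0.0)]
halves = [(1.0, 1.0, 1.0)] * 3
margins = [0.0, 0.0, 0.0]
print(broad_phase(centers, halves, margins))  # [(0, 1)]
```

A GPU backend would more plausibly replace the quadratic loop with a parallel scheme and evaluate it independently for each of the nworld worlds.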
Usage
This module is invoked internally by the MJX pipeline whenever a forward simulation step requires collision detection. It maps JAX array inputs onto Warp kernel launches, so collision computations run on the GPU inside the JAX-traced simulation step.
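The idea of a shim that takes one flat parameter per array, while callers hold those arrays on model and data objects, can be illustrated with a toy decorator that fills each parameter from a same-named attribute. Everything in this sketch (format_args_from_objects, ToyModel, ToyData, toy_shim) is hypothetical. This page does not describe the internals of ffi.format_args_for_warp, which may work differently.

```python
import functools
import inspect
from dataclasses import dataclass


def format_args_from_objects(fn):
    """Hypothetical decorator: fill fn's flat parameters from same-named attributes.

    Each parameter name is looked up on the passed objects in order; the first
    object that has the attribute supplies the value. This is NOT the MJX
    implementation of ffi.format_args_for_warp.
    """
    params = list(inspect.signature(fn).parameters)

    @functools.wraps(fn)
    def wrapper(*objects):
        kwargs = {}
        for name in params:
            for obj in objects:
                if hasattr(obj, name):
                    kwargs[name] = getattr(obj, name)
                    break
            else:  # no object provided this parameter
                raise TypeError(f"no object provides argument {name!r}")
        return fn(**kwargs)

    return wrapper


@dataclass
class ToyModel:
    nworld: int
    geom_margin: list[float]


@dataclass
class ToyData:
    ncon: int


@format_args_from_objects
def toy_shim(nworld: int, geom_margin: list[float], ncon: int) -> str:
    return f"{nworld} worlds, {len(geom_margin)} geoms, {ncon} contacts"


print(toy_shim(ToyModel(nworld=2, geom_margin=[0.0, 0.01]), ToyData(ncon=0)))
# prints "2 worlds, 2 geoms, 0 contacts"
```

Keeping the shim's signature flat means each argument maps one-to-one onto an array, which is convenient when every parameter must become a separate buffer for a kernel launch.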
Code Reference
Source Location
- Repository: Google_deepmind_Mujoco
- File: mjx/mujoco/mjx/warp/collision_driver.py
- Lines: 1-410
Key Functions
```python
@ffi.format_args_for_warp
def _collision_shim(
    # Model
    nworld: int,
    block_dim: mjwp_types.BlockDim,
    geom_aabb: wp.array3d(dtype=wp.vec3),
    geom_condim: wp.array(dtype=int),
    geom_dataid: wp.array(dtype=int),
    geom_friction: wp.array2d(dtype=wp.vec3),
    geom_gap: wp.array2d(dtype=float),
    geom_margin: wp.array2d(dtype=float),
    geom_pair_type_count: tuple[int, ...],
    # ...
    # Data (contact outputs)
):
    ...
```
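The geom_pair_type_count parameter is a tuple of integers rather than an array. One plausible reading, assumed here and not confirmed by this page, is a tally of geom pairs per (type1, type2) combination, which would let a driver skip narrow-phase kernels for combinations absent from the model. The sketch flattens unordered type pairs into an upper-triangular index. The helper names and the NUM_GEOM_TYPES value are hypothetical, and the sketch ignores the pair filtering (contype/conaffinity, parent-child exclusion) a real model applies.

```python
from itertools import combinations

NUM_GEOM_TYPES = 8  # hypothetical number of geom type ids


def pair_type_index(t1: int, t2: int, n: int = NUM_GEOM_TYPES) -> int:
    """Flatten an unordered type pair into an upper-triangular index."""
    if t1 > t2:
        t1, t2 = t2, t1
    # Rows 0..t1-1 hold n, n-1, ..., n-t1+1 entries; then offset within row t1.
    return t1 * n - t1 * (t1 - 1) // 2 + (t2 - t1)


def pair_type_count(geom_type: list[int], n: int = NUM_GEOM_TYPES) -> tuple[int, ...]:
    """Count geom pairs per unordered (type1, type2) combination."""
    counts = [0] * (n * (n + 1) // 2)
    for i, j in combinations(range(len(geom_type)), 2):
        counts[pair_type_index(geom_type[i], geom_type[j], n)] += 1
    return tuple(counts)


# Toy model with 3 type ids: one geom of type 0 and two of type 2.
counts = pair_type_count([0, 2, 2], n=3)
active = [k for k, c in enumerate(counts) if c > 0]
print(counts, active)  # (0, 0, 2, 0, 0, 1) [2, 5]
```

Returning a plain tuple keeps the counts as static Python values, which could be used to decide which kernels to launch before any GPU work starts.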
Import
```python
from mujoco.mjx.warp import collision_driver
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| nworld | int | Yes | Number of parallel simulation worlds |
| block_dim | mjwp_types.BlockDim | Yes | Block dimension configuration for tiled kernel launches |
| geom_aabb | wp.array3d(vec3) | Yes | Axis-aligned bounding boxes for all geoms |
| geom_type | wp.array(int) | Yes | Geom type identifiers (sphere, capsule, box, mesh, etc.) |
| geom_size | wp.array2d(vec3) | Yes | Geom size parameters |
| geom_friction | wp.array2d(vec3) | Yes | Friction coefficients per geom |
| mesh_face | wp.array(vec3i) | Yes | Mesh face indices for mesh-type geoms |
| hfield_data | wp.array(float) | Yes | Heightfield elevation data |
Outputs
| Name | Type | Description |
|---|---|---|
| contact (via Data) | Contact dataclass fields | Detected collision contacts including positions, normals, depths, and friction parameters |
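To make the contact fields concrete, the sketch below computes one contact between two spheres, the simplest narrow-phase case. ToyContact and sphere_sphere are hypothetical and are not the MuJoCo Warp Contact type. The sketch assumes MuJoCo's conventions that a negative dist means penetration, that the contact point lies midway between the two surfaces, and that friction for equal-priority geoms is combined by element-wise maximum. A contact is reported only when dist falls below the margin.

```python
import math
from dataclasses import dataclass
from typing import Optional

Vec3 = tuple[float, float, float]


@dataclass
class ToyContact:
    pos: Vec3       # contact point, midway between the two surfaces
    normal: Vec3    # unit vector from sphere 1 toward sphere 2
    dist: float     # signed distance; negative means penetration
    friction: Vec3  # combined friction coefficients


def sphere_sphere(
    c1: Vec3, r1: float, f1: Vec3, c2: Vec3, r2: float, f2: Vec3, margin: float = 0.0
) -> Optional[ToyContact]:
    """Return a ToyContact if the spheres are closer than margin, else None."""
    delta = [b - a for a, b in zip(c1, c2)]
    d = math.sqrt(sum(x * x for x in delta))
    dist = d - r1 - r2
    if dist >= margin:
        return None
    # Coincident centers have no defined direction; pick an arbitrary axis.
    normal = tuple(x / d for x in delta) if d > 0 else (1.0, 0.0, 0.0)
    # Step from center 1 past its surface by half the signed gap.
    pos = tuple(a + n * (r1 + 0.5 * dist) for a, n in zip(c1, normal))
    # Assumption: element-wise max, as for equal-priority geoms in MuJoCo.
    friction = tuple(max(a, b) for a, b in zip(f1, f2))
    return ToyContact(pos, normal, dist, friction)


c = sphere_sphere((0.0, 0.0, 0.0), 1.0, (1.0, 0.005, 0.0001),
                  (1.5, 0.0, 0.0), 1.0, (0.5, 0.01, 0.0001))
print(c.pos, c.dist, c.friction)  # (0.75, 0.0, 0.0) -0.5 (1.0, 0.01, 0.0001)
```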
Related Pages
- Google_deepmind_Mujoco_MJX_Warp_FFI - FFI helper functions used by the collision driver
- Google_deepmind_Mujoco_MJX_Warp_Forward - Forward dynamics pipeline that invokes collision detection
- Google_deepmind_Mujoco_MJX_Warp_Types - Type definitions for Model, Data, Contact, and Constraint