Implementation:Google deepmind Mujoco JAX Experimental XLA FFI
| Knowledge Sources | |
|---|---|
| Domains | XLA, Foreign Function Interface, ctypes Bindings, CUDA, Low-level API |
| Last Updated | 2026-02-15 04:00 GMT |
Overview
Low-level ctypes binding module that provides Python access to XLA's C FFI API structures and enumerations for dispatching custom GPU operations.
Description
xla_ffi.py implements Python ctypes wrappers for XLA's Foreign Function Interface C API (as defined in xla/ffi/api/c_api.h). It defines ctypes Structure classes that mirror the C structures, such as XLA_FFI_Extension_Base, and IntEnum classes that mirror the C enumerations, such as XLA_FFI_ExecutionStage (INSTANTIATE, PREPARE, INITIALIZE, EXECUTE) and XLA_FFI_DataType (covering PRED, S8-S64, U8-U64, F16, F32, and other numeric types). It also maps JAX/NumPy data types to XLA FFI data type codes and provides Warp-to-XLA type conversion utilities.
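The dtype-to-code mapping described above can be pictured as a lookup table. The following is a minimal sketch, assuming dtypes are identified by their NumPy names (for example `str(np.dtype("float32"))`); the table `_DTYPE_TO_XLA` and the function `dtype_to_xla_code` are hypothetical names, not the module's actual API, and only the XLA_FFI_DataType codes listed on this page are used.

```python
import enum


class XLA_FFI_DataType(enum.IntEnum):
    # Subset of the XLA FFI data type codes; other members are omitted here.
    INVALID = 0
    PRED = 1
    S8 = 2
    S16 = 3
    S32 = 4
    S64 = 5
    U8 = 6
    U16 = 7
    U32 = 8
    U64 = 9
    F16 = 10
    F32 = 11


# Hypothetical lookup table keyed by NumPy dtype names.
_DTYPE_TO_XLA = {
    "bool": XLA_FFI_DataType.PRED,
    "int8": XLA_FFI_DataType.S8,
    "int16": XLA_FFI_DataType.S16,
    "int32": XLA_FFI_DataType.S32,
    "int64": XLA_FFI_DataType.S64,
    "uint8": XLA_FFI_DataType.U8,
    "uint16": XLA_FFI_DataType.U16,
    "uint32": XLA_FFI_DataType.U32,
    "uint64": XLA_FFI_DataType.U64,
    "float16": XLA_FFI_DataType.F16,
    "float32": XLA_FFI_DataType.F32,
}


def dtype_to_xla_code(dtype_name: str) -> XLA_FFI_DataType:
    """Return the XLA FFI code for a dtype name, or INVALID if it is not mapped."""
    return _DTYPE_TO_XLA.get(dtype_name, XLA_FFI_DataType.INVALID)
```

Returning INVALID for unmapped names is a design choice of this sketch; a stricter variant could raise instead, so unsupported dtypes fail before a kernel is registered.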
Usage
This module is imported by ffi.py (via from .xla_ffi import *) to provide the low-level type definitions needed when registering Warp kernels as XLA FFI handlers. It is not typically used directly by end users but forms the foundation of the JAX-Warp interoperability layer.
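Registering a Python callback as an XLA FFI handler requires exposing it as a C function pointer. The sketch below shows the general ctypes technique only; it is not taken from ffi.py. It assumes the XLA C FFI handler shape `XLA_FFI_Error* handler(XLA_FFI_CallFrame*)`, treats both structures as opaque `void*`, and uses the hypothetical names `FFI_HANDLER` and `_python_handler`. Handing the resulting address to XLA through JAX's FFI registration API is omitted.

```python
import ctypes

# Assumed handler shape, modeled on the XLA C FFI convention:
#   XLA_FFI_Error* handler(XLA_FFI_CallFrame* call_frame)
# Both structures are treated as opaque void pointers in this sketch.
FFI_HANDLER = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p)

seen_frames = []


def _python_handler(call_frame):
    # A real handler would unpack buffers and attributes from the call frame;
    # this sketch only records the address it was given.
    seen_frames.append(call_frame)
    return None  # a NULL XLA_FFI_Error* signals success


# Keep a reference to the CFUNCTYPE object: if it is garbage collected,
# the raw C function pointer below becomes dangling.
handler = FFI_HANDLER(_python_handler)
handler_address = ctypes.cast(handler, ctypes.c_void_p).value
```

To mimic a call from C, rebuild a function object from the raw address with `FFI_HANDLER(handler_address)` and call it with a dummy frame address; the callback receives that address as a Python int and the call returns `None` (a NULL pointer).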
Code Reference
Source Location
- Repository: Google_deepmind_Mujoco
- File: mjx/mujoco/mjx/third_party/warp/_src/jax_experimental/xla_ffi.py
- Lines: 1-658
Key Functions
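```python
import ctypes
import enum


class XLA_FFI_Extension_Type(enum.IntEnum):
    Metadata = 1


# Declare the structure first: a class body cannot refer to its own class name,
# so the self-referential `next` pointer is added after the class exists.
class XLA_FFI_Extension_Base(ctypes.Structure):
    pass


XLA_FFI_Extension_Base._fields_ = [
    ("struct_size", ctypes.c_size_t),
    ("type", ctypes.c_int),  # holds an XLA_FFI_Extension_Type value
    ("next", ctypes.POINTER(XLA_FFI_Extension_Base)),  # next extension in the chain
]


class XLA_FFI_ExecutionStage(enum.IntEnum):
    INSTANTIATE = 0
    PREPARE = 1
    INITIALIZE = 2
    EXECUTE = 3


class XLA_FFI_DataType(enum.IntEnum):
    INVALID = 0
    PRED = 1
    S8 = 2; S16 = 3; S32 = 4; S64 = 5
    U8 = 6; U16 = 7; U32 = 8; U64 = 9
    F16 = 10; F32 = 11
    ...  # further data types elided in this excerpt
```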
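XLA can invoke a handler at several execution stages. As a minimal sketch, assuming a handler that only needs to launch work at execution time, the other stages can be treated as no-ops. The helper `run_stage` and its `launch_kernel` callback are hypothetical and not part of this module.

```python
import enum
from typing import Callable


class XLA_FFI_ExecutionStage(enum.IntEnum):
    INSTANTIATE = 0
    PREPARE = 1
    INITIALIZE = 2
    EXECUTE = 3


def run_stage(stage: int, launch_kernel: Callable[[], None]) -> bool:
    """Hypothetical dispatch helper: launch the kernel only in the EXECUTE stage.

    Returns True if the kernel was launched. Unknown stage codes raise
    ValueError when converted to XLA_FFI_ExecutionStage.
    """
    stage = XLA_FFI_ExecutionStage(stage)
    if stage is not XLA_FFI_ExecutionStage.EXECUTE:
        return False  # INSTANTIATE, PREPARE, INITIALIZE: nothing to do here
    launch_kernel()
    return True
```

Converting the raw integer through the IntEnum makes an unexpected stage code fail loudly instead of being silently ignored.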
Import
```python
from mujoco.mjx.third_party.warp._src.jax_experimental.xla_ffi import *
from mujoco.mjx.third_party.warp._src.jax_experimental import xla_ffi
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| (module-level definitions) | N/A | N/A | No runtime inputs; provides type definitions and ctypes structures |
Outputs
| Name | Type | Description |
|---|---|---|
| XLA_FFI_Extension_Base | ctypes.Structure | Base extension structure for XLA FFI call chain |
| XLA_FFI_ExecutionStage | IntEnum | Execution stage constants (INSTANTIATE, PREPARE, INITIALIZE, EXECUTE) |
| XLA_FFI_DataType | IntEnum | XLA data type codes mapping to JAX/NumPy dtypes |
| FfiKernel | ctypes.Structure | Kernel registration structure for XLA FFI dispatch |
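XLA_FFI_Extension_Base is a singly linked list node: each extension points to the next through `next`. The sketch below walks such a chain and returns the first node of a requested type. It redefines the structure locally so it runs on its own; the function `find_extension` and the extension type code `99` used in the example are hypothetical.

```python
import ctypes
import enum


class XLA_FFI_Extension_Type(enum.IntEnum):
    Metadata = 1


class XLA_FFI_Extension_Base(ctypes.Structure):
    pass


XLA_FFI_Extension_Base._fields_ = [
    ("struct_size", ctypes.c_size_t),
    ("type", ctypes.c_int),
    ("next", ctypes.POINTER(XLA_FFI_Extension_Base)),
]


def find_extension(head, ext_type):
    """Hypothetical helper: return the first node in the chain whose type matches."""
    node = head
    while node:  # a NULL ctypes pointer is falsy
        if node.contents.type == ext_type:
            return node
        node = node.contents.next
    return None


# Build a two-node chain: an unrelated extension (code 99) followed by Metadata.
size = ctypes.sizeof(XLA_FFI_Extension_Base)
second = XLA_FFI_Extension_Base(size, XLA_FFI_Extension_Type.Metadata, None)
first = XLA_FFI_Extension_Base(size, 99, ctypes.pointer(second))
found = find_extension(ctypes.pointer(first), XLA_FFI_Extension_Type.Metadata)
```

Filling in `struct_size` mirrors the C API convention of versioning structures by their size.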
Related Pages
- Google_deepmind_Mujoco_JAX_Experimental_FFI - Higher-level FFI module that imports all definitions from this module
- Google_deepmind_Mujoco_JAX_Experimental_Custom_Call - Legacy custom call API that this module's approach replaces