Environment:Farama Foundation Gymnasium JAX Functional Backend
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Functional_Programming |
| Last Updated | 2026-02-15 03:00 GMT |
Overview
JAX runtime with Flax and array-api-compat for functional (stateless) environment implementations and for the JAX-to-NumPy and JAX-to-PyTorch interop wrappers.
Description
This environment provides the JAX ecosystem dependencies required by Gymnasium's functional (stateless) environments and JAX interoperability wrappers. Functional environments (`FunctionalJaxEnv`, `FunctionalJaxVectorEnv`) express environment dynamics as pure JAX functions, so the dynamics can be JIT-compiled and run on hardware accelerators. The wrappers (`JaxToNumpy`, `JaxToTorch`) convert JAX arrays to and from NumPy arrays and PyTorch tensors. Flax supplies the immutable, pytree-compatible dataclasses (via `flax.struct`) that functional environments use, for example for their environment parameters. NumPy >= 2.1 is required for Array API compatibility.
Usage
Use this environment when working with functional JAX-based environments (CartPole-Jax, Pendulum-Jax, Blackjack-Jax, CliffWalking-Jax) or when using JAX interoperability wrappers (`JaxToNumpy`, `JaxToTorch`). Also required for the `gymnasium[array-api]` functionality when JAX is the array backend.
System Requirements
| Category | Requirement | Notes |
|---|---|---|
| OS | Linux (recommended), macOS | JAX's CUDA (GPU) builds are available only on Linux |
| Hardware | CPU or GPU | JAX supports CPU, CUDA GPU, and TPU |
| Python | >= 3.10 | Matches Gymnasium requirement |
| NumPy | >= 2.1 | Required for Array API compatibility |
Dependencies
Python Packages
- `jax` >= 0.4.16
- `jaxlib` >= 0.4.16
- `flax` >= 0.5.0
- `array-api-compat` >= 1.11.0
- `numpy` >= 2.1
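To check an existing environment against the minimum versions listed above, a small standard-library script can compare the installed distributions. `check_requirements`, `meets_minimum`, and `version_tuple` are hypothetical helpers; the parser compares only the leading numeric release segment (so a pre-release such as `2.1.0rc1` counts as `2.1.0`), a simplification of the full PEP 440 ordering that `packaging.version` implements.

```python
import re
from importlib import metadata

# Minimum versions from the dependency list above.
JAX_EXTRA_MINIMUMS = {
    "jax": "0.4.16",
    "jaxlib": "0.4.16",
    "flax": "0.5.0",
    "array-api-compat": "1.11.0",
    "numpy": "2.1",
}


def version_tuple(version: str) -> tuple[int, ...]:
    """Parse the leading numeric release segment, e.g. '2.1.3rc1' -> (2, 1, 3)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare release segments, padding the shorter tuple with zeros."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_requirements(minimums: dict[str, str]) -> dict[str, str]:
    """Return 'ok', 'missing', or 'too old (X < Y)' for each package."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if meets_minimum(installed, minimum):
            report[name] = "ok"
        else:
            report[name] = f"too old ({installed} < {minimum})"
    return report


if __name__ == "__main__":
    for pkg, status in check_requirements(JAX_EXTRA_MINIMUMS).items():
        print(f"{pkg:>18}: {status}")
```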
Credentials
No credentials required.
Quick Install
```bash
# Install Gymnasium with JAX extras
pip install "gymnasium[jax]"
# For GPU-accelerated JAX (CUDA)
pip install "gymnasium[jax]" "jax[cuda12]"
```
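After installing, a quick smoke test can confirm that `jaxlib` loads, show which backend JAX selected, and check that JIT compilation works. This is only a suggested check, not part of Gymnasium; on a correctly configured CUDA install the backend should report `gpu`.

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # "cpu", "gpu", or "tpu"
print("Devices:", jax.devices())

# A tiny jitted computation confirms that jaxlib is importable and working.
x = jnp.arange(4.0)
total = jax.jit(lambda a: (a * a).sum())(x)
print("Sum of squares:", float(total))  # 0 + 1 + 4 + 9 = 14.0
```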
Code Evidence
JAX import guard from `gymnasium/wrappers/jax_to_numpy.py:18-23`:
```python
try:
    import jax.numpy as jnp
except ImportError as e:
    raise DependencyNotInstalled(
        'Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install "gymnasium[jax]"`'
    ) from e
```
Flax usage in functional environments from `gymnasium/envs/phys2d/cartpole.py:10`:
```python
from flax import struct
```
Unguarded JAX imports in `gymnasium/envs/functional_jax_env.py:7-9` (the module assumes JAX is installed):
```python
import jax
import jax.numpy as jnp
```
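Because these modules import JAX unconditionally, code that only optionally depends on them can probe for JAX before importing them. A minimal standard-library sketch; `optional_import` is a hypothetical helper that mimics the install hint of the guard above but raises a plain `ImportError` rather than Gymnasium's `DependencyNotInstalled`.

```python
import importlib
import importlib.util
from types import ModuleType


def optional_import(module_name: str, extra: str) -> ModuleType:
    """Import a top-level module, or raise ImportError with a Gymnasium install hint.

    `extra` names the Gymnasium extra that provides the module, e.g. "jax".
    Pass only top-level names: find_spec("a.b") imports package "a" first.
    """
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f'{module_name} is not installed, run `pip install "gymnasium[{extra}]"`'
        )
    return importlib.import_module(module_name)


if __name__ == "__main__":
    try:
        jax = optional_import("jax", "jax")
        print("JAX available:", jax.__version__)
    except ImportError as err:
        print(err)
```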
NumPy version check for Array API from `gymnasium/wrappers/array_conversion.py:44-45`:
```python
if Version(np.__version__) < Version("2.1.0"):
    raise DependencyNotInstalled("Array API functionality requires numpy >= 2.1.0")
```
Common Errors
| Error Message | Cause | Solution |
|---|---|---|
| `DependencyNotInstalled: Jax is not installed` | jax/jaxlib missing | `pip install "gymnasium[jax]"` |
| `ModuleNotFoundError: No module named 'flax'` | flax not installed | `pip install "gymnasium[jax]"` (includes flax) |
| `Array API functionality requires numpy >= 2.1.0` | NumPy too old | `pip install "numpy>=2.1"` |
Compatibility Notes
- Functional envs assume JAX: Unlike other optional dependencies, JAX functional environment files import JAX directly without try/except guards. They will fail at import time if JAX is not installed.
- NumPy 2.1 requirement: The JAX extra requires a newer NumPy (>= 2.1) than the core Gymnasium requirement (>= 1.21.0) for Array API compliance.
- GPU acceleration: The default `jax`/`jaxlib` wheels are CPU-only. For GPU support, install a CUDA-enabled build (e.g., `jax[cuda12]`); JAX then places computations on the GPU automatically.