Environment:Farama Foundation Gymnasium JAX Functional Backend

From Leeroopedia
Knowledge Sources
Domains: Reinforcement_Learning, Functional_Programming
Last Updated: 2026-02-15 03:00 GMT

Overview

A JAX runtime, together with Flax and array-api-compat, for Gymnasium's functional (stateless) environment implementations and its JAX-to-NumPy and JAX-to-PyTorch interop wrappers.

Description

This environment provides the JAX ecosystem dependencies required by Gymnasium's functional (stateless) environments and its JAX interoperability wrappers. The functional environments (`FunctionalJaxEnv`, `FunctionalJaxVectorEnv`) express environment dynamics as pure JAX functions, so the dynamics can be JIT-compiled and run on hardware accelerators. The conversion wrappers (`JaxToNumpy`, `JaxToTorch`) convert data passing between a JAX environment and NumPy or PyTorch code; `JaxToTorch` additionally requires PyTorch. Flax's `flax.struct` module supplies the immutable dataclasses (for example, environment parameters) used by the functional environments. NumPy >= 2.1 is required for Array API compatibility.
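To illustrate what "stateless, JIT-compilable dynamics" means in practice, here is a minimal sketch of a pure transition function. The toy point-mass dynamics and the names `transition`, `state`, and `dt` are hypothetical and are not Gymnasium's actual CartPole physics; the point is only that the state is passed in and returned rather than stored on an object, which is what lets `jax.jit` compile a step and `jax.vmap` batch it across many environments.

```python
import jax
import jax.numpy as jnp


def transition(state: jax.Array, action: jax.Array, dt: float = 0.1) -> jax.Array:
    """Hypothetical point-mass dynamics with state = [position, velocity].

    The function is pure: it reads its inputs and returns a new state
    instead of mutating anything, so JAX can trace and compile it.
    """
    position, velocity = state[0], state[1]
    force = jnp.where(action == 1, 1.0, -1.0)  # discrete action -> +/- unit force
    new_velocity = velocity + force * dt
    new_position = position + new_velocity * dt
    return jnp.stack([new_position, new_velocity])


# Compile a single step, then vectorise the same function over a batch of envs.
step = jax.jit(transition)
batched_step = jax.jit(jax.vmap(transition))

state = jnp.zeros(2)
next_state = step(state, jnp.array(1))

batch_states = jnp.zeros((8, 2))
batch_actions = jnp.arange(8) % 2  # alternating actions 0, 1, 0, 1, ...
next_batch = batched_step(batch_states, batch_actions)
print(next_batch.shape)  # (8, 2)
```

Because nothing is hidden in object state, the same `transition` serves both the single-environment and the vectorised case, which mirrors the split between `FunctionalJaxEnv` and `FunctionalJaxVectorEnv`.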

Usage

Use this environment when working with the functional JAX-based environments (CartPole-Jax, Pendulum-Jax, Blackjack-Jax, CliffWalking-Jax) or the JAX interoperability wrappers (`JaxToNumpy`, `JaxToTorch`). It is also required when using the `gymnasium[array-api]` functionality with JAX as the array backend.

System Requirements

| Category | Requirement | Notes |
|---|---|---|
| OS | Linux (recommended), macOS | JAX's pre-built CUDA GPU wheels target Linux |
| Hardware | CPU or GPU | JAX supports CPU, CUDA GPUs, and TPUs |
| Python | >= 3.10 | Matches the Gymnasium requirement |
| NumPy | >= 2.1 | Required for Array API compatibility |

Dependencies

Python Packages

  • `jax` >= 0.4.16
  • `jaxlib` >= 0.4.16
  • `flax` >= 0.5.0
  • `array-api-compat` >= 1.11.0
  • `numpy` >= 2.1
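A quick way to confirm that an existing environment meets these minimums is to read the installed versions with the standard-library `importlib.metadata`. The helpers below (`release_tuple`, `meets_minimum`, `check_environment`) are hypothetical and not part of Gymnasium. They compare only the numeric release segment and ignore pre-release suffixes, which simplifies full PEP 440 ordering; use `packaging.version.Version` if that distinction matters.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions copied from the dependency list above (PyPI distribution names).
MINIMUMS = {
    "jax": "0.4.16",
    "jaxlib": "0.4.16",
    "flax": "0.5.0",
    "array-api-compat": "1.11.0",
    "numpy": "2.1",
}


def release_tuple(v: str) -> tuple[int, ...]:
    """Extract the leading numeric release segment, e.g. '2.1.3rc1' -> (2, 1, 3)."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"Unrecognised version string: {v!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare release segments, padding the shorter tuple with zeros."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment() -> dict[str, str]:
    """Return a status string for Python and each required distribution."""
    report = {"python": "ok" if sys.version_info >= (3, 10) else "too old"}
    for dist, minimum in MINIMUMS.items():
        try:
            installed = version(dist)
        except PackageNotFoundError:
            report[dist] = "missing"
            continue
        report[dist] = "ok" if meets_minimum(installed, minimum) else f"too old ({installed})"
    return report


if __name__ == "__main__":
    for name, status in check_environment().items():
        print(f"{name:>18}: {status}")
```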

Credentials

No credentials required.

Quick Install

# Install Gymnasium with JAX extras
pip install "gymnasium[jax]"

# For GPU-accelerated JAX (CUDA)
pip install "gymnasium[jax]" "jax[cuda12]"
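After installing, a short Python check shows which backend JAX picked up. If the CUDA variant installed correctly and a compatible NVIDIA driver is present, the default backend should report `gpu` rather than `cpu`. This uses only the public `jax.default_backend()` and `jax.devices()` functions.

```python
import jax
import jax.numpy as jnp

# Which platform JAX selected (e.g. "cpu", "gpu" or "tpu") and the devices it sees.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())

# A tiny computation to confirm the installation actually executes.
x = jnp.arange(4.0)
print(float(jnp.dot(x, x)))  # 14.0
```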

Code Evidence

JAX import guard from `gymnasium/wrappers/jax_to_numpy.py:18-23`:

from gymnasium.error import DependencyNotInstalled  # imported at the top of the module

try:
    import jax.numpy as jnp
except ImportError as e:
    raise DependencyNotInstalled(
        'Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install "gymnasium[jax]"`'
    ) from e
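The same guard pattern can be applied to optional dependencies in user code: attempt the import, then convert the `ImportError` into an error that names the install command. The helper `require` and the exception class `MissingDependency` below are hypothetical illustrations; Gymnasium itself raises `gymnasium.error.DependencyNotInstalled`, as in the excerpt above.

```python
import importlib
from types import ModuleType


class MissingDependency(ImportError):
    """Hypothetical stand-in for gymnasium.error.DependencyNotInstalled."""


def require(module_name: str, extra: str) -> ModuleType:
    """Import `module_name`, or fail with an actionable install hint."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise MissingDependency(
            f'{module_name} is not installed, run `pip install "gymnasium[{extra}]"`'
        ) from e


try:
    jnp = require("jax.numpy", "jax")
    print("jax.numpy loaded")
except MissingDependency as err:
    print(err)
```

Chaining with `from e` keeps the original `ImportError` in the traceback, so the underlying cause (for example a broken `jaxlib` build rather than a missing package) stays visible.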

Flax usage in functional environments from `gymnasium/envs/phys2d/cartpole.py:10`:

from flax import struct

JAX functional environments import JAX directly, without guards (they assume JAX is available):

# gymnasium/envs/functional_jax_env.py:7-9
import jax
import jax.numpy as jnp
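Because these modules fail on import when JAX is absent, code that should degrade gracefully (for example a test suite or a plugin loader) can check availability first. A minimal sketch using the standard library's `importlib.util.find_spec`; the module path `gymnasium.envs.functional_jax_env` comes from the excerpt above, while the helper `jax_stack_available` and the skip logic are illustrative assumptions.

```python
import importlib.util


def jax_stack_available() -> bool:
    """True if jax, jaxlib and flax can all be located without importing them."""
    return all(
        importlib.util.find_spec(name) is not None for name in ("jax", "jaxlib", "flax")
    )


if jax_stack_available():
    # Expected to succeed: the unguarded `import jax` in this module can now resolve.
    from gymnasium.envs.functional_jax_env import FunctionalJaxEnv

    print("functional JAX envs available:", FunctionalJaxEnv.__name__)
else:
    print("JAX extras missing; skipping functional environments")
```

`find_spec` only confirms that a package can be located; a version mismatch between `jax` and `jaxlib` can still fail at import time, so a `try`/`except ImportError` around the import remains a reasonable fallback.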

NumPy version check for Array API from `gymnasium/wrappers/array_conversion.py:44-45`:

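import numpy as np
from packaging.version import Version  # assumed source of `Version`
from gymnasium.error import DependencyNotInstalled

if Version(np.__version__) < Version("2.1.0"):
    raise DependencyNotInstalled("Array API functionality requires numpy >= 2.1.0")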

Common Errors

| Error message | Cause | Solution |
|---|---|---|
| `DependencyNotInstalled: Jax is not installed` | `jax`/`jaxlib` missing | `pip install "gymnasium[jax]"` |
| `ModuleNotFoundError: No module named 'flax'` | `flax` not installed | `pip install "gymnasium[jax]"` (includes flax) |
| `Array API functionality requires numpy >= 2.1.0` | NumPy older than 2.1 | `pip install "numpy>=2.1"` |

Compatibility Notes

  • Functional envs assume JAX: Unlike modules that depend on other optional packages, the JAX functional environment files import JAX directly, without try/except guards, so importing them raises `ModuleNotFoundError` when JAX is not installed.
  • NumPy 2.1 requirement: The JAX extra requires a newer NumPy (>= 2.1) than the core Gymnasium requirement (>= 1.21.0) for Array API compliance.
  • GPU acceleration: The default `jax` wheel on PyPI is CPU-only. For GPU support, install a CUDA-enabled variant such as `jax[cuda12]`; when a supported GPU is present, JAX then uses it as the default backend.
