Implementation:Farama Foundation Gymnasium FunctionalJaxEnv
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, JAX_Acceleration |
| Last Updated | 2026-02-15 03:00 GMT |
Overview
Adapter classes that wrap a functional FuncEnv into standard Gymnasium Env and VectorEnv interfaces, providing JAX-based environment execution with PRNG key management.
Description
The functional_jax_env module provides two adapter classes that bridge the functional environment API (FuncEnv) with the standard Gymnasium API:
FunctionalJaxEnv wraps a single FuncEnv instance as a gymnasium.Env. It manages internal state and JAX PRNG keys, splitting the key on each reset() and step() call. The reset(seed=...) method re-initializes the PRNG key when a seed is provided, then calls func_env.initial(). The step(action) method calls func_env.transition(), func_env.observation(), func_env.reward(), and func_env.terminal(), converting the JAX outputs to Python float and bool for the reward and terminated signals. Truncated is always False. Rendering delegates to the functional environment's render_image() method when render_mode="rgb_array".
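The reset/step cycle of FunctionalJaxEnv can be illustrated without Gymnasium. The following is a minimal sketch, not the library's code: `ToyFunctional` (a noisy 1-D walk) and `MiniJaxEnv` are hypothetical names, and the fixed default seed of 0 is an assumption standing in for the random seed the real class draws at construction. It shows the two behaviours described above: one key split per call, and conversion of the reward and terminated signals to Python types.

```python
import jax.numpy as jnp
import jax.random as jrng


class ToyFunctional:
    """Hypothetical functional env: a noisy 1-D walk that ends when |x| >= 3."""

    def initial(self, rng):
        # Start at the origin; rng is unused but kept to match the functional signature.
        return jnp.array(0.0)

    def transition(self, state, action, rng):
        # Move by the action plus a little Gaussian noise drawn from this step's key.
        return state + action + 0.1 * jrng.normal(rng)

    def observation(self, state, rng):
        return state

    def reward(self, state, action, next_state, rng):
        return -jnp.abs(next_state)

    def terminal(self, state, rng):
        return jnp.abs(state) >= 3.0


class MiniJaxEnv:
    """Hypothetical adapter mirroring the reset/step pattern described above."""

    def __init__(self, func_env, seed=0):
        self.func_env = func_env
        # Assumption: a fixed default seed instead of one drawn at random.
        self.rng = jrng.PRNGKey(seed)
        self.state = None

    def reset(self, seed=None):
        if seed is not None:
            # Re-create the key only when an explicit seed is given.
            self.rng = jrng.PRNGKey(seed)
        # Split: one subkey is consumed now, the other is kept for later calls.
        rng, self.rng = jrng.split(self.rng)
        self.state = self.func_env.initial(rng)
        return self.func_env.observation(self.state, rng), {}

    def step(self, action):
        rng, self.rng = jrng.split(self.rng)
        next_state = self.func_env.transition(self.state, action, rng)
        obs = self.func_env.observation(next_state, rng)
        reward = self.func_env.reward(self.state, action, next_state, rng)
        terminated = self.func_env.terminal(next_state, rng)
        self.state = next_state
        # JAX scalars become Python float/bool; truncated is always False.
        return obs, float(reward), bool(terminated), False, {}


env = MiniJaxEnv(ToyFunctional())
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(1.0)
```

Because the key is split rather than reused, resetting with the same seed replays the same sequence of noise draws.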
FunctionalJaxVectorEnv wraps a FuncEnv as a gymnasium.vector.VectorEnv that runs num_envs environment instances in parallel. On initialization it vectorizes every functional method with func_env.transform(jax.vmap), and it builds the batched observation and action spaces with batch_space. It tracks a per-environment step count and, when max_episode_steps is greater than 0, reports truncated=True for any environment whose count reaches that limit. Auto-reset happens on the step after an episode ends. Environments flagged in prev_done (terminated or truncated on the previous step) are re-initialized with func_env.initial(), and their entries are overwritten using JAX functional array updates (.at[to_reset].set()).
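The batched step with next-step auto-reset can be sketched with plain JAX functions. This is an illustrative sketch under stated assumptions, not the library's implementation: the per-environment functions (`initial`, `transition`, `terminal`), `TIME_LIMIT`, and `vector_step` are hypothetical, each environment's state is a single scalar, and `done` is computed after the reset overrides are applied.

```python
import jax
import jax.numpy as jnp
import jax.random as jrng

TIME_LIMIT = 3  # hypothetical max_episode_steps; 0 would mean "no limit"


# Hypothetical single-environment functional pieces (a counter that ends at 2).
def initial(rng):
    return jrng.uniform(rng, minval=0.0, maxval=0.5)


def transition(state, action, rng):
    return state + action


def terminal(state, rng):
    return state >= 2.0


# Batch every function over the leading axis, as func_env.transform(jax.vmap) would.
v_initial = jax.vmap(initial)
v_transition = jax.vmap(transition)
v_terminal = jax.vmap(terminal)


def vector_step(state, steps, prev_done, action, key):
    """One batched step; envs that finished on the previous step are reset."""
    num_envs = state.shape[0]
    key, step_key, reset_key = jrng.split(key, 3)
    rngs = jrng.split(step_key, num_envs)

    steps = steps + 1
    next_state = v_transition(state, action, rngs)
    terminated = v_terminal(next_state, rngs)
    if TIME_LIMIT > 0:
        truncated = steps >= TIME_LIMIT
    else:
        truncated = jnp.zeros_like(terminated)

    if jnp.any(prev_done):
        # Indices of envs whose episode ended on the previous step.
        to_reset = jnp.where(prev_done)[0]
        new_initials = v_initial(jrng.split(reset_key, to_reset.shape[0]))
        # Functional updates: .at[...].set(...) returns new arrays.
        next_state = next_state.at[to_reset].set(new_initials)
        steps = steps.at[to_reset].set(0)
        terminated = terminated.at[to_reset].set(False)
        truncated = truncated.at[to_reset].set(False)

    done = jnp.logical_or(terminated, truncated)
    return next_state, steps, done, terminated, truncated, key


state = jnp.zeros(3)
steps = jnp.zeros(3, dtype=jnp.int32)
done = jnp.zeros(3, dtype=bool)
key = jrng.PRNGKey(0)
action = jnp.array([1.0, 0.0, 2.0])
for _ in range(2):
    state, steps, done, terminated, truncated, key = vector_step(
        state, steps, done, action, key
    )
```

In this sketch, environment 2 finishes on the first step and receives a fresh initial state on the second; the action passed for it on that second step is effectively discarded.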
At construction, both classes create their internal PRNG key from a random seed drawn via seeding.np_random(). Calling reset(seed=...) replaces that key with one derived from the given seed.
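The seeding fallback can be sketched as a small helper. This is a hedged sketch: `make_key` is a hypothetical name, an unseeded `numpy.random.default_rng()` stands in for `gymnasium.utils.seeding.np_random()`, and the `[0, 2**32 - 1)` range for the drawn `uint32` seed is an assumption.

```python
import jax.numpy as jnp
import jax.random as jrng
import numpy as np


def make_key(seed=None):
    """Return a JAX PRNG key, drawing a random seed when none is given."""
    if seed is None:
        # Stand-in for seeding.np_random(): an unseeded NumPy Generator.
        np_random = np.random.default_rng()
        seed = np_random.integers(0, 2**32 - 1, dtype=np.uint32)
    return jrng.PRNGKey(seed)


key_a = make_key(42)  # explicit seed: reproducible
key_b = make_key(42)
fresh = make_key()    # no seed: a new random key each time
print(jnp.array_equal(key_a, key_b))
```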
Usage
Use FunctionalJaxEnv and FunctionalJaxVectorEnv as base classes for concrete JAX-accelerated environments (e.g., CartPoleJaxEnv, BlackJackJaxEnv). End users rarely instantiate them directly; the concrete subclasses are normally created through gymnasium.make() with registered environment IDs such as "phys2d/CartPole-v1".
Code Reference
Source Location
- Repository: Farama_Foundation_Gymnasium
- File: gymnasium/envs/functional_jax_env.py
Signature
```python
class FunctionalJaxEnv(gym.Env, Generic[StateType]):
    def __init__(self, func_env: FuncEnv, metadata: dict | None = None,
                 render_mode: str | None = None, spec: EnvSpec | None = None): ...
    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[ObsType, dict]: ...
    def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]: ...
    def render(self) -> np.ndarray: ...
    def close(self) -> None: ...

class FunctionalJaxVectorEnv(gym.vector.VectorEnv, Generic[ObsType, ActType, StateType]):
    def __init__(self, func_env: FuncEnv, num_envs: int, max_episode_steps: int = 0,
                 metadata: dict | None = None, render_mode: str | None = None,
                 spec: EnvSpec | None = None): ...
    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[ObsType, dict]: ...
    def step(self, action: ActType) -> tuple[ObsType, Any, Any, Any, dict]: ...
    def render(self) -> np.ndarray: ...
    def close(self) -> None: ...
```
Import
```python
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| func_env | FuncEnv | Yes | The functional environment to wrap |
| num_envs | int | Yes (vector only) | Number of parallel environments |
| max_episode_steps | int | No | Maximum steps before truncation (0 = no limit, vector only) |
| metadata | dict or None | No | Environment metadata (if None, FunctionalJaxEnv uses a default that includes jax=True) |
| render_mode | str or None | No | Render mode ("rgb_array" or None) |
| seed | int or None | No | Seed for PRNG key initialization in reset() |
| action | ActType | Yes (step) | Action to execute (for the vector env, a batch with leading dimension num_envs) |
Outputs
| Name | Type | Description |
|---|---|---|
| obs | ObsType | Observation from the functional environment (batched along the first axis for the vector env) |
| reward | float (single) / jax.Array (vector) | Reward for the transition |
| terminated | bool (single) / jax.Array (vector) | Whether the episode terminated |
| truncated | bool (single) / jax.Array (vector) | Whether the episode was truncated |
| info | dict | Transition info from the functional environment |
Usage Examples
```python
import jax
import jax.numpy as jnp
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional

# Single environment: JIT-compile the functional methods, then wrap them.
func_env = CartPoleFunctional()
func_env.transform(jax.jit)
env = FunctionalJaxEnv(func_env, render_mode="rgb_array")
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(1)  # reward: float, terminated: bool
env.close()

# Vectorized environment: the wrapper additionally applies jax.vmap to the methods.
func_env = CartPoleFunctional()
func_env.transform(jax.jit)
vec_env = FunctionalJaxVectorEnv(func_env, num_envs=8, max_episode_steps=200)
obs, info = vec_env.reset(seed=42)
obs, rewards, terminated, truncated, info = vec_env.step(jnp.ones(8, dtype=jnp.int32))
vec_env.close()
```
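In the vectorized example, transform(jax.jit) is applied first and FunctionalJaxVectorEnv then applies transform(jax.vmap), so each batched method is a vmap of a jitted function. A minimal sketch of that composition, assuming a hypothetical `ToyFunc` class whose `transform()` simply replaces a method with `func(method)`; only one method is shown, whereas the wrapper transforms all functional methods.

```python
import jax
import jax.numpy as jnp


class ToyFunc:
    """Hypothetical stand-in for a FuncEnv exposing a transform() method."""

    def transition(self, state, action, rng):
        return state + action

    def transform(self, func):
        # Replace the method with its transformed version (e.g. jit, then vmap).
        self.transition = func(self.transition)


env = ToyFunc()
env.transform(jax.jit)   # as in the usage example
env.transform(jax.vmap)  # as the vector wrapper does on construction

# Batched call: one state, action, and (unused) key per environment.
out = env.transition(jnp.zeros(4), jnp.ones(4), jnp.zeros((4, 2), dtype=jnp.uint32))
print(out)
```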