Implementation:Farama Foundation Gymnasium FunctionalJaxEnv

From Leeroopedia
Knowledge Sources
Domains: Reinforcement_Learning, JAX_Acceleration
Last Updated: 2026-02-15 03:00 GMT

Overview

Adapter classes that wrap a functional FuncEnv into standard Gymnasium Env and VectorEnv interfaces, providing JAX-based environment execution with PRNG key management.

Description

The functional_jax_env module provides two adapter classes that bridge the functional environment API (FuncEnv) with the standard Gymnasium API:

FunctionalJaxEnv wraps a single FuncEnv instance as a gymnasium.Env. It manages internal state and JAX PRNG keys, splitting the key on each reset() and step() call. The reset(seed=...) method re-initializes the PRNG key when a seed is provided, then calls func_env.initial(). The step(action) method calls func_env.transition(), func_env.observation(), func_env.reward(), and func_env.terminal(), converting the JAX outputs to Python float and bool for the reward and terminated signals. Truncated is always False. Rendering delegates to the functional environment's render_image() method when render_mode="rgb_array".
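The single-environment flow described above can be sketched with a toy stand-in. This is a minimal illustration, not the Gymnasium source: ToyFuncEnv and ToyAdapter are hypothetical names, the dynamics are invented, and info dictionaries and rendering are omitted. It only shows the pattern of splitting the key on every call and casting the reward and terminated outputs to Python types.

```python
import jax
import jax.numpy as jnp


class ToyFuncEnv:
    """Hypothetical stateless environment: a 1-D walk with a noisy start."""

    def initial(self, rng):
        return jax.random.uniform(rng, (), minval=-0.1, maxval=0.1)

    def transition(self, state, action, rng):
        noise = 0.01 * jax.random.normal(rng, ())
        return state + jnp.where(action == 1, 0.5, -0.5) + noise

    def observation(self, state, rng):
        return state

    def reward(self, state, action, next_state, rng):
        return jnp.asarray(1.0)

    def terminal(self, state, rng):
        return jnp.abs(state) > 2.0


class ToyAdapter:
    """Hypothetical stateful wrapper that owns the state and the PRNG key."""

    def __init__(self, func_env, seed=0):
        self.func_env = func_env
        self.rng = jax.random.PRNGKey(seed)
        self.state = None

    def reset(self, seed=None):
        # Re-initialize the key only when a seed is supplied.
        if seed is not None:
            self.rng = jax.random.PRNGKey(seed)
        # Consume one subkey and keep the other for later calls.
        rng, self.rng = jax.random.split(self.rng)
        self.state = self.func_env.initial(rng)
        return self.func_env.observation(self.state, rng), {}

    def step(self, action):
        rng, self.rng = jax.random.split(self.rng)
        next_state = self.func_env.transition(self.state, action, rng)
        obs = self.func_env.observation(next_state, rng)
        reward = self.func_env.reward(self.state, action, next_state, rng)
        terminated = self.func_env.terminal(next_state, rng)
        self.state = next_state
        # JAX scalars become Python float/bool; truncated is always False.
        return obs, float(reward), bool(terminated), False, {}


env = ToyAdapter(ToyFuncEnv())
first_obs, _ = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(1)
second_obs, _ = env.reset(seed=42)  # same seed, same initial observation
```

Because all randomness flows through the key held by the wrapper, resetting with the same seed reproduces the same initial observation.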

FunctionalJaxVectorEnv wraps a FuncEnv as a gymnasium.vector.VectorEnv so that multiple environment instances run in parallel under jax.vmap. On initialization, it vectorizes all functional methods by calling func_env.transform(jax.vmap). It tracks per-environment step counts and sets truncated once an environment's step count reaches max_episode_steps (a value of 0 disables the limit). Auto-reset is implemented by checking prev_done, which flags the environments that finished on the previous step, and re-initializing only those environments through JAX indexed updates (.at[to_reset].set()). The observation and action spaces are batched using batch_space.
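The vectorized bookkeeping can be sketched as follows. This is an illustrative approximation under stated assumptions, not the library code: the functions initial, transition, and terminal, the constants NUM_ENVS and MAX_EPISODE_STEPS, and the helper vector_step are all hypothetical, and each function is passed to jax.vmap directly rather than through func_env.transform. The sketch runs eagerly (without jax.jit) because the number of environments to reset varies from step to step.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 4
MAX_EPISODE_STEPS = 3  # hypothetical time limit


def initial(rng):
    return jax.random.uniform(rng, (), minval=-0.05, maxval=0.05)


def transition(state, action, rng):
    return state + jnp.where(action == 1, 0.5, -0.5)


def terminal(state, rng):
    return jnp.abs(state) > 2.0


# Vectorize the per-environment functions over the leading axis.
v_initial = jax.vmap(initial)
v_transition = jax.vmap(transition)
v_terminal = jax.vmap(terminal)


def vector_step(state, steps, prev_done, action, rng):
    rng, step_key = jax.random.split(rng)
    keys = jax.random.split(step_key, NUM_ENVS)  # one key per environment
    next_state = v_transition(state, action, keys)
    steps = steps + 1
    terminated = v_terminal(next_state, keys)
    truncated = steps >= MAX_EPISODE_STEPS

    if bool(jnp.any(prev_done)):
        # Re-initialize only the environments that finished last step.
        to_reset = jnp.where(prev_done)[0]
        rng, reset_key = jax.random.split(rng)
        fresh = v_initial(jax.random.split(reset_key, to_reset.shape[0]))
        next_state = next_state.at[to_reset].set(fresh)
        steps = steps.at[to_reset].set(0)
        terminated = terminated.at[to_reset].set(False)
        truncated = truncated.at[to_reset].set(False)

    done = terminated | truncated
    return next_state, steps, done, terminated, truncated, rng


rng = jax.random.PRNGKey(0)
rng, init_key = jax.random.split(rng)
state = v_initial(jax.random.split(init_key, NUM_ENVS))
steps = jnp.zeros(NUM_ENVS, dtype=jnp.int32)
prev_done = jnp.zeros(NUM_ENVS, dtype=jnp.bool_)
action = jnp.ones(NUM_ENVS, dtype=jnp.int32)

for _ in range(MAX_EPISODE_STEPS):
    state, steps, prev_done, terminated, truncated, rng = vector_step(
        state, steps, prev_done, action, rng
    )
truncated_at_limit = truncated  # every environment hit the time limit

# The next call sees prev_done set everywhere and resets all environments.
state, steps, prev_done, terminated, truncated, rng = vector_step(
    state, steps, prev_done, action, rng
)
```

Indexed updates with .at[...].set(...) return new arrays instead of mutating in place, which is why each reset reassigns the affected variables.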

Both classes initialize their PRNG keys from a random seed generated by seeding.np_random() when no explicit seed is provided.
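The fallback seeding can be approximated as below. This is a sketch under assumptions: make_key is a hypothetical helper, np.random.default_rng() stands in for Gymnasium's seeding.np_random(), and the seed range is narrowed to non-negative 31-bit integers for simplicity.

```python
import jax
import numpy as np


def make_key(seed=None):
    """Return a JAX PRNG key, drawing a random seed when none is given."""
    if seed is None:
        # Stand-in for seeding.np_random(): a generator seeded from the OS.
        np_random = np.random.default_rng()
        seed = int(np_random.integers(0, 2**31 - 1))
    return jax.random.PRNGKey(seed)


fallback_key = make_key()  # differs from run to run
key_a = make_key(123)      # explicit seeds are reproducible
key_b = make_key(123)
```

An explicit seed therefore yields identical keys across runs, while the fallback path gives each new environment its own random starting key.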

Usage

Use FunctionalJaxEnv and FunctionalJaxVectorEnv as base classes for concrete JAX-accelerated environments (e.g., CartPoleJaxEnv, BlackJackJaxEnv). End users typically do not instantiate these adapters directly; instead, they create the concrete environments through gymnasium.make() with a registered environment ID such as "phys2d/CartPole-v1".

Code Reference

Source Location

Signature

class FunctionalJaxEnv(gym.Env, Generic[StateType]):
    def __init__(self, func_env: FuncEnv, metadata: dict | None = None,
                 render_mode: str | None = None, spec: EnvSpec | None = None)
    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[ObsType, dict]
    def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]
    def render(self) -> np.ndarray
    def close(self) -> None

class FunctionalJaxVectorEnv(gym.vector.VectorEnv, Generic[ObsType, ActType, StateType]):
    def __init__(self, func_env: FuncEnv, num_envs: int, max_episode_steps: int = 0,
                 metadata: dict | None = None, render_mode: str | None = None, spec: EnvSpec | None = None)
    def reset(self, *, seed: int | None = None, options: dict | None = None) -> tuple[ObsType, dict]
    def step(self, action: ActType) -> tuple[ObsType, Any, Any, Any, dict]
    def render(self) -> np.ndarray
    def close(self) -> None

Import

from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| func_env | FuncEnv | Yes | The functional environment to wrap |
| num_envs | int | Yes (vector only) | Number of parallel environments |
| max_episode_steps | int | No | Maximum steps before truncation (0 = no limit, vector only) |
| metadata | dict or None | No | Environment metadata (default includes jax=True) |
| render_mode | str or None | No | Render mode ("rgb_array" or None) |
| seed | int or None | No | Seed for PRNG key initialization in reset() |
| action | ActType | Yes (step) | Action to execute |

Outputs

| Name | Type | Description |
|---|---|---|
| obs | ObsType | Observation from the functional environment |
| reward | float (single) / jax.Array (vector) | Reward for the transition |
| terminated | bool (single) / jax.Array (vector) | Whether the episode terminated |
| truncated | bool (single) / jax.Array (vector) | Whether the episode was truncated |
| info | dict | Transition info from the functional environment |

Usage Examples

import jax
import jax.numpy as jnp  # needed for the batched action array below
from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional

# Single environment
func_env = CartPoleFunctional()
func_env.transform(jax.jit)  # JIT-compile the functional methods
env = FunctionalJaxEnv(func_env, render_mode="rgb_array")
obs, info = env.reset(seed=42)
# reward is a Python float; terminated and truncated are Python bools
obs, reward, terminated, truncated, info = env.step(1)

# Vectorized environment
func_env = CartPoleFunctional()
func_env.transform(jax.jit)
vec_env = FunctionalJaxVectorEnv(func_env, num_envs=8, max_episode_steps=200)
obs, info = vec_env.reset(seed=42)
# One action per environment; outputs are batched arrays of length 8
actions = jnp.ones(8, dtype=jnp.int32)
obs, rewards, terminated, truncated, info = vec_env.step(actions)
