

Implementation:Farama Foundation Gymnasium FuncEnv

From Leeroopedia
Knowledge Sources
Domains Reinforcement_Learning, Functional_Programming
Last Updated 2026-02-15 03:00 GMT

Overview

An abstract base class defining the functional environment API where environment state is passed explicitly rather than stored internally, enabling hardware acceleration and functional transformations like JAX's jit and vmap.

Description

The FuncEnv class is a generic base class parameterized by seven type variables: StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, and Params. It defines a stateless POMDP interface: the environment state is passed as an explicit argument to every method that reads it, rather than being stored on the instance.

Core methods that subclasses must override (the base implementations raise NotImplementedError):

  • initial(rng, params) -- Generates the initial state using a random number generator.
  • transition(state, action, rng, params) -- Computes the next state given current state and action.
  • observation(state, rng, params) -- Generates an observation from the current state.
  • reward(state, action, next_state, rng, params) -- Computes the reward for a transition.
  • terminal(state, rng, params) -- Determines if a state is terminal.

Optional methods with default implementations:

  • state_info(state, params) -- Returns an info dict about a state (default: empty dict).
  • transition_info(state, action, next_state, params) -- Returns an info dict about a transition (default: empty dict).

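The class also supports rendering through render_init, render_image, and render_close, and provides a transform method that applies a functional transformation (e.g., jax.jit or jax.vmap) to the core methods in place: the instance's initial, transition, observation, reward, terminal, and state_info attributes are replaced by their transformed versions. The rendering methods are not transformed.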

Environment parameters can be customized through the Params type; get_default_params() supplies the defaults (the base implementation returns None).

Usage

Use FuncEnv as the base class when implementing environments that need to be compatible with JAX transformations (JIT compilation, automatic vectorization via vmap, automatic differentiation). Concrete implementations include BlackjackFunctional, CartPoleFunctional, PendulumFunctional, and CliffWalkingFunctional.

Code Reference

Source Location

Signature

class FuncEnv(
    Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params]
):
    observation_space: Space
    action_space: Space

    def __init__(self, options: dict[str, Any] | None = None)
    def initial(self, rng: Any, params: Params | None = None) -> StateType
    def transition(self, state: StateType, action: ActType, rng: Any, params: Params | None = None) -> StateType
    def observation(self, state: StateType, rng: Any, params: Params | None = None) -> ObsType
    def reward(self, state: StateType, action: ActType, next_state: StateType, rng: Any, params: Params | None = None) -> RewardType
    def terminal(self, state: StateType, rng: Any, params: Params | None = None) -> TerminalType
    def state_info(self, state: StateType, params: Params | None = None) -> dict
    def transition_info(self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None) -> dict
    def transform(self, func: Callable[[Callable], Callable])
    def render_image(self, state: StateType, render_state: RenderStateType, params: Params | None = None) -> tuple[RenderStateType, np.ndarray]
    def render_init(self, params: Params | None = None, **kwargs) -> RenderStateType
    def render_close(self, render_state: RenderStateType, params: Params | None = None)
    def get_default_params(self, **kwargs) -> Params | None

Import

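from gymnasium.functional import FuncEnv

This is the module path in Gymnasium 1.x. Earlier 0.x releases exposed an experimental version under gymnasium.experimental.functional, whose method signatures may differ from those listed above.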

I/O Contract

Inputs

  • options (dict or None, optional) -- Dictionary of options merged into the instance's __dict__ by __init__.
  • rng (Any; required by initial, transition, observation, reward, terminal) -- Random number generator, typically a JAX PRNG key.
  • state (StateType; required by transition, observation, reward, terminal, state_info, transition_info, render_image) -- The current environment state.
  • action (ActType; required by transition, reward, transition_info) -- The action taken.
  • next_state (StateType; required by reward, transition_info) -- The state resulting from the action.
  • params (Params or None, optional) -- Environment parameters; defaults are used if None.

Outputs

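  • state (StateType) -- The environment state returned by initial or transition.
  • observation (ObsType) -- The observation generated from a state.
  • reward (RewardType) -- The reward for a transition.
  • terminal (TerminalType) -- Whether the state is terminal.
  • info (dict) -- State or transition information returned by state_info or transition_info.
  • render output (tuple[RenderStateType, np.ndarray]) -- The updated render state and a rendered frame, returned by render_image.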

Usage Examples

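import jax
import jax.numpy as jnp
import numpy as np
from gymnasium import spaces
from gymnasium.functional import FuncEnv


class SimpleFuncEnv(FuncEnv):
    """Toy environment: a 2-D point nudged by +/-0.1 per step."""

    observation_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
    action_space = spaces.Discrete(2)

    def initial(self, rng, params=None):
        # Sample a starting point uniformly in [-1, 1]^2.
        return jax.random.uniform(rng, shape=(2,), minval=-1.0, maxval=1.0)

    def transition(self, state, action, rng, params=None):
        # Action 0 moves both coordinates down by 0.1, action 1 moves them up.
        delta = jnp.where(action == 0, -0.1, 0.1)
        return jnp.clip(state + delta, -1.0, 1.0)

    def observation(self, state, rng, params=None):
        # Fully observable: the observation is the state itself.
        return state

    def reward(self, state, action, next_state, rng, params=None):
        return jnp.sum(next_state)

    def terminal(self, state, rng, params=None):
        # Never terminates; return an array so the result is traceable.
        return jnp.array(False)


# Apply JAX JIT compilation to the core methods of this instance
env = SimpleFuncEnv()
env.transform(jax.jit)

# Thread the state explicitly through one step
key = jax.random.PRNGKey(0)
init_key, step_key = jax.random.split(key)
state = env.initial(init_key)
next_state = env.transition(state, 1, step_key)
reward = env.reward(state, 1, next_state, step_key)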
