Implementation:Farama Foundation Gymnasium FuncEnv
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Functional_Programming |
| Last Updated | 2026-02-15 03:00 GMT |
Overview
FuncEnv is an abstract base class that defines the functional environment API: environment state is passed explicitly rather than stored internally, which enables hardware acceleration and functional transformations such as JAX's jit and vmap.
Description
The FuncEnv class is a generic base class parameterized by seven type variables: StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, and Params. It defines a stateless POMDP interface where the environment state is passed as an explicit argument to all methods rather than being stored internally.
The core methods that must be overridden:
- initial(rng, params) -- Generates the initial state using a random number generator.
- transition(state, action, rng, params) -- Computes the next state given current state and action.
- observation(state, rng, params) -- Generates an observation from the current state.
- reward(state, action, next_state, rng, params) -- Computes the reward for a transition.
- terminal(state, rng, params) -- Determines if a state is terminal.
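As a rough illustration of how these five methods compose into one environment step, the sketch below uses a dependency-free stand-in rather than FuncEnv itself. CounterEnv, step, and the use of a plain integer as rng are assumptions made for illustration; with JAX, rng would typically be a PRNGKey that is split before each call.

```python
import random


class CounterEnv:
    """Hypothetical stand-in exposing the same five core methods as FuncEnv."""

    def initial(self, rng, params=None):
        # The state is a plain integer derived from the seed; nothing is stored on self.
        return random.Random(rng).randint(0, 3)

    def transition(self, state, action, rng, params=None):
        # Action 1 moves up by one, any other action moves down by one.
        return state + (1 if action == 1 else -1)

    def observation(self, state, rng, params=None):
        # Fully observable in this sketch: the observation is the state.
        return state

    def reward(self, state, action, next_state, rng, params=None):
        return float(next_state - state)

    def terminal(self, state, rng, params=None):
        return state >= 5 or state <= -5


def step(env, state, action, rng):
    """Compose the core methods into a single step; the caller owns the state."""
    next_state = env.transition(state, action, rng)
    obs = env.observation(next_state, rng)
    rew = env.reward(state, action, next_state, rng)
    done = env.terminal(next_state, rng)
    return next_state, obs, rew, done


env = CounterEnv()
state = env.initial(rng=0)
trajectory = []
for t in range(20):
    state, obs, rew, done = step(env, state, action=1, rng=t)
    trajectory.append((obs, rew, done))
    if done:
        break
```

Because the state is threaded through the loop by the caller, the same env object can serve any number of independent rollouts.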
Optional methods with default implementations:
- state_info(state, params) -- Returns an info dict about a state (default: empty dict).
- transition_info(state, action, next_state, params) -- Returns an info dict about a transition (default: empty dict).
The class also supports rendering via the render_init, render_image, and render_close methods, and provides a transform method that applies a functional transformation (e.g., jax.jit or jax.vmap) to the core methods in a single call.
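One way to picture transform is as a loop that rebinds each method on the instance to its transformed version. The sketch below is a guess at that pattern, not Gymnasium's implementation: TinyEnv is a hypothetical stand-in, the set of rebound methods is assumed to be the five core ones, and traced is a stand-in for jax.jit or jax.vmap (any function that takes a callable and returns a callable).

```python
import functools


class TinyEnv:
    """Hypothetical stand-in; only the method names mirror FuncEnv."""

    def initial(self, rng, params=None):
        return 0

    def transition(self, state, action, rng, params=None):
        return state + action

    def observation(self, state, rng, params=None):
        return state

    def reward(self, state, action, next_state, rng, params=None):
        return float(next_state)

    def terminal(self, state, rng, params=None):
        return state >= 3

    def transform(self, func):
        # Rebind each core method on this instance to its transformed version.
        for name in ("initial", "transition", "observation", "reward", "terminal"):
            setattr(self, name, func(getattr(self, name)))


calls = []


def traced(fn):
    """Stand-in transformation that records which method was invoked."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        calls.append(fn.__name__)
        return fn(*args, **kwargs)

    return wrapper


env = TinyEnv()
env.transform(traced)
state = env.initial(rng=0)
next_state = env.transition(state, 1, rng=0)
```

Since the transformation is applied per instance, other instances of the same class keep their untransformed methods.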
Environment parameters can be customized via the Params type, with defaults provided by get_default_params().
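The fallback from params=None to the defaults might look like the following sketch. DriftParams, its step_size field, DriftEnv, and the default_params attribute set in the constructor are all hypothetical names chosen for illustration; only get_default_params and the params argument come from the API described here.

```python
from typing import NamedTuple, Optional


class DriftParams(NamedTuple):
    """Hypothetical parameter container; the field name is illustrative."""

    step_size: float = 0.1


class DriftEnv:
    """Hypothetical stand-in showing the params fallback, not FuncEnv itself."""

    def __init__(self):
        # Assumption of this sketch: defaults are computed once and cached.
        self.default_params = self.get_default_params()

    def get_default_params(self, **kwargs):
        return DriftParams(**kwargs)

    def transition(self, state, action, rng, params: Optional[DriftParams] = None):
        # Fall back to the defaults when the caller passes no params.
        params = self.default_params if params is None else params
        direction = 1.0 if action == 1 else -1.0
        return state + direction * params.step_size


env = DriftEnv()
default_next = env.transition(0.0, 1, rng=None)
custom_next = env.transition(0.0, 1, rng=None, params=DriftParams(step_size=0.5))
```

Passing params explicitly on each call keeps the methods free of hidden state, so one environment object can be evaluated under several parameter settings.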
Usage
Use FuncEnv as the base class when implementing environments that need to be compatible with JAX transformations (JIT compilation, automatic vectorization via vmap, automatic differentiation). Concrete implementations include BlackjackFunctional, CartPoleFunctional, PendulumFunctional, and CliffWalkingFunctional.
Code Reference
Source Location
- Repository: Farama_Foundation_Gymnasium
- File: gymnasium/experimental/functional.py
Signature
class FuncEnv(
    Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params]
):
    observation_space: Space
    action_space: Space
    def __init__(self, options: dict[str, Any] | None = None)
    def initial(self, rng: Any, params: Params | None = None) -> StateType
    def transition(self, state: StateType, action: ActType, rng: Any, params: Params | None = None) -> StateType
    def observation(self, state: StateType, rng: Any, params: Params | None = None) -> ObsType
    def reward(self, state: StateType, action: ActType, next_state: StateType, rng: Any, params: Params | None = None) -> RewardType
    def terminal(self, state: StateType, rng: Any, params: Params | None = None) -> TerminalType
    def state_info(self, state: StateType, params: Params | None = None) -> dict
    def transition_info(self, state: StateType, action: ActType, next_state: StateType, params: Params | None = None) -> dict
    def transform(self, func: Callable[[Callable], Callable])
    def render_image(self, state: StateType, render_state: RenderStateType, params: Params | None = None) -> tuple[RenderStateType, np.ndarray]
    def render_init(self, params: Params | None = None, **kwargs) -> RenderStateType
    def render_close(self, render_state: RenderStateType, params: Params | None = None)
    def get_default_params(self, **kwargs) -> Params | None
Import
from gymnasium.experimental.functional import FuncEnv
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| options | dict or None | No | Dictionary of options merged into the instance's __dict__ |
| rng | Any | Yes (for initial, transition, observation, reward, terminal) | Random number generator (typically a JAX PRNGKey) |
| state | StateType | Yes (for transition, observation, reward, terminal, state_info, transition_info, render_image) | The current environment state |
| action | ActType | Yes (for transition, reward, transition_info) | The action taken |
| next_state | StateType | Yes (for reward, transition_info) | The resulting state after the action |
| params | Params or None | No | Environment parameters (defaults used if None) |
Outputs
| Name | Type | Description |
|---|---|---|
| state | StateType | The environment state (from initial or transition) |
| observation | ObsType | The observation generated from a state |
| reward | RewardType | The reward for a transition |
| terminal | TerminalType | Whether the state is terminal |
| info | dict | State or transition information |
Usage Examples
import jax
import jax.numpy as jnp
import numpy as np
from gymnasium import spaces
from gymnasium.experimental.functional import FuncEnv


class SimpleFuncEnv(FuncEnv):
    observation_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
    action_space = spaces.Discrete(2)

    def initial(self, rng, params=None):
        # Sample a 2-D starting state uniformly from [-1, 1).
        return jax.random.uniform(rng, shape=(2,), minval=-1.0, maxval=1.0)

    def transition(self, state, action, rng, params=None):
        # Action 0 shifts both coordinates by -0.1, any other action by +0.1.
        delta = jnp.where(action == 0, -0.1, 0.1)
        return jnp.clip(state + delta, -1.0, 1.0)

    def observation(self, state, rng, params=None):
        # The state is fully observable.
        return state

    def reward(self, state, action, next_state, rng, params=None):
        return jnp.sum(next_state)

    def terminal(self, state, rng, params=None):
        # This toy environment never terminates.
        return False


# Instantiate the environment, then JIT-compile its methods via transform
env = SimpleFuncEnv()
env.transform(jax.jit)
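To illustrate why explicit state matters for vectorization, the sketch below batches the same initial and transition logic with jax.vmap. It is written against free functions so that it runs without Gymnasium; the free-function form, the batch size of 4, and the seed are assumptions, and the same pattern would presumably apply to the methods of an environment after env.transform(jax.vmap).

```python
import jax
import jax.numpy as jnp


def initial(rng):
    # Same sampling as SimpleFuncEnv.initial above.
    return jax.random.uniform(rng, shape=(2,), minval=-1.0, maxval=1.0)


def transition(state, action):
    # Same dynamics as SimpleFuncEnv.transition above.
    delta = jnp.where(action == 0, -0.1, 0.1)
    return jnp.clip(state + delta, -1.0, 1.0)


# One key per environment copy; vmap maps each function over the leading axis.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
batched_initial = jax.jit(jax.vmap(initial))
batched_transition = jax.jit(jax.vmap(transition))

states = batched_initial(keys)  # one 2-D state per key
actions = jnp.array([0, 1, 0, 1])  # one action per environment copy
next_states = batched_transition(states, actions)
```

Each copy receives its own key and its own action, and no Python-level loop over environments is needed.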