Principle:Farama Foundation Gymnasium Functional Environment API

From Leeroopedia
Revision as of 17:25, 16 February 2026 by Admin (talk | contribs) (Auto-imported from principles/Farama_Foundation_Gymnasium_Functional_Environment_API.md)
Knowledge Sources
Domains: Reinforcement_Learning, Functional_Programming
Last Updated: 2026-02-15 03:00 GMT

Overview

A stateless functional environment interface decomposes environment logic into pure functions over explicit state, enabling JIT compilation, automatic differentiation, and hardware-accelerated simulation.

Description

The functional environment API provides an alternative to the traditional stateful environment interface by decomposing the environment into a set of pure functions: initial state generation, state transition, observation extraction, reward computation, and terminal detection. Rather than maintaining hidden internal state as in the standard Env class, the functional API passes state explicitly as an argument and return value. This design aligns with functional programming principles and unlocks powerful capabilities in modern ML frameworks.
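
To make the decomposition concrete, here is a minimal sketch of a toy environment written as five pure functions in plain JAX. The point-mass dynamics, the constants, and the layout of the state array are assumptions made up for illustration; the sketch does not subclass the library's FuncEnv and is not taken from Gymnasium's own environments.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy task: a point on a line should stay near the origin.
# State is a length-2 array [position, velocity]; the action is a scalar force.

def initial(rng):
    # Sample a start position in [-1, 1) with zero velocity.
    pos = jax.random.uniform(rng, (), minval=-1.0, maxval=1.0)
    return jnp.stack([pos, jnp.zeros(())])

def transition(state, action, rng):
    # Euler step with dt = 0.1; deterministic, so rng is accepted but unused.
    vel = state[1] + 0.1 * action
    pos = state[0] + 0.1 * vel
    return jnp.stack([pos, vel])

def observation(state, rng):
    # Fully observable: the observation is the state itself.
    return state

def reward(state, action, next_state, rng):
    # Negative squared distance of the new position from the origin.
    return -(next_state[0] ** 2)

def terminal(state, rng):
    # The episode ends when the point leaves the interval [-2, 2].
    return jnp.abs(state[0]) > 2.0

key = jax.random.PRNGKey(0)
s0 = initial(key)
s1 = transition(s0, 1.0, key)
```

Nothing is stored between calls: the caller holds `s0` and `s1`, so the same inputs always produce the same outputs.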

The key advantage of the functional approach is compatibility with JAX transformations. Because each function is pure (no side effects, no hidden state), the environment logic can be JIT-compiled with jax.jit, vectorized across many environment instances with jax.vmap, and differentiated with JAX's automatic differentiation (for example jax.grad). This enables gradient-based optimization through the environment dynamics, model-based RL with differentiable simulators, and massively parallel environment simulation on GPUs and TPUs.
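
As a rough illustration of differentiating through the dynamics, the sketch below compiles a hypothetical deterministic transition and takes the gradient of a final cost with respect to a short action sequence. The dynamics and the `final_cost` helper are assumptions introduced for this example, not part of the Gymnasium API.

```python
import jax
import jax.numpy as jnp

def transition(state, action):
    # Hypothetical point-mass dynamics: state = [position, velocity], dt = 0.1.
    vel = state[1] + 0.1 * action
    pos = state[0] + 0.1 * vel
    return jnp.stack([pos, vel])

def final_cost(actions, state):
    # Apply the actions in order, then measure squared distance from the origin.
    for a in actions:  # unrolled at trace time because the length is static
        state = transition(state, a)
    return state[0] ** 2

fast_transition = jax.jit(transition)       # compiled single step
cost_grad = jax.jit(jax.grad(final_cost))   # gradient w.r.t. the action sequence

state0 = jnp.array([1.0, 0.0])
actions = jnp.zeros(5)
grads = cost_grad(actions, state0)

# One plain gradient-descent update on the actions.
better_actions = actions - 1.0 * grads
```

Earlier actions receive larger gradients here because they influence the position over more remaining steps.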

The API consists of a base FuncEnv class that defines the interface, and a FunctionalJaxEnv adapter that wraps a functional environment to provide the standard Env interface for compatibility with existing RL algorithms. Concrete implementations exist for several classic environments (CartPole, Pendulum, Blackjack, CliffWalking) as proof-of-concept functional environments. The functional API is currently experimental and subject to change, but it points toward hardware-accelerated RL training pipelines.

Usage

Use the functional environment API when you need JIT-compiled environment simulation for maximum throughput, when you want to differentiate through environment dynamics for model-based methods, or when you need to run thousands of environment instances in parallel on GPU/TPU hardware. Use the FunctionalJaxEnv wrapper to make functional environments compatible with standard RL training loops. Use the FuncEnv base class as a template when implementing new functional environments.

Theoretical Basis

The functional environment decomposes a Markov Decision Process into pure functions. Given state type 𝒮, observation type 𝒪, action type 𝒜, and random key type 𝒦:

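\[
\begin{aligned}
\mathrm{initial} &: \mathcal{K} \to \mathcal{S} \\
\mathrm{transition} &: \mathcal{S} \times \mathcal{A} \times \mathcal{K} \to \mathcal{S} \\
\mathrm{observation} &: \mathcal{S} \times \mathcal{K} \to \mathcal{O} \\
\mathrm{reward} &: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \times \mathcal{K} \to \mathbb{R} \\
\mathrm{terminal} &: \mathcal{S} \times \mathcal{K} \to \{0, 1\}
\end{aligned}
\]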

A single environment step is composed as:

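import jax

# transition, observation, reward and terminal are the pure functions defined above.
def step(state, action, rng):
    # Derive one independent subkey per stochastic function.
    rng_transition, rng_obs, rng_reward, rng_terminal = jax.random.split(rng, 4)
    next_state = transition(state, action, rng_transition)
    # Local names differ from the function names so the functions are not shadowed.
    obs = observation(next_state, rng_obs)
    rew = reward(state, action, next_state, rng_reward)
    terminated = terminal(next_state, rng_terminal)
    return next_state, obs, rew, terminated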
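
Because a step is a pure function of (state, action, key), a whole trajectory can be expressed as a scan over actions and keys. The following is a minimal sketch under assumed dynamics: the noisy point mass and the `rollout` helper are hypothetical and exist only to make the example runnable.

```python
import jax
import jax.numpy as jnp

def transition(state, action, rng):
    # Hypothetical noisy point mass; state = [position, velocity].
    vel = state[1] + 0.1 * action + 0.01 * jax.random.normal(rng, ())
    return jnp.stack([state[0] + 0.1 * vel, vel])

def step(state, action, rng):
    # Same return layout as the step function in the text, with the
    # observation equal to the state and a squared-distance penalty.
    next_state = transition(state, action, rng)
    rew = -(next_state[0] ** 2)
    terminated = jnp.abs(next_state[0]) > 2.0
    return next_state, next_state, rew, terminated

def rollout(state, actions, rng):
    # One key per time step, so randomness is never reused across steps.
    keys = jax.random.split(rng, actions.shape[0])

    def body(carry, xs):
        action, key = xs
        next_state, obs, rew, terminated = step(carry, action, key)
        return next_state, (obs, rew, terminated)

    return jax.lax.scan(body, state, (actions, keys))

final_state, (obs, rewards, dones) = jax.jit(rollout)(
    jnp.array([1.0, 0.0]), -jnp.ones(20), jax.random.PRNGKey(0)
)
```

Note that this sketch keeps stepping after `terminated` becomes true; masking or resetting finished episodes is left out for brevity.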

The JAX adapter converts between the functional and stateful interfaces:

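import jax
from gymnasium import Env

class FunctionalJaxEnv(Env):
    # Simplified sketch of the adapter; spaces, rendering and info handling are omitted.
    def __init__(self, func_env, seed=0):
        self.func_env = func_env
        self.rng = jax.random.PRNGKey(seed)  # default seed is an arbitrary choice for this sketch

    def reset(self, *, seed=None, options=None):
        # Reseed only when a seed is given; otherwise continue the existing key stream.
        if seed is not None:
            self.rng = jax.random.PRNGKey(seed)
        rng, self.rng = jax.random.split(self.rng)
        self.state = self.func_env.initial(rng)
        obs = self.func_env.observation(self.state, rng)
        return obs, {}

    def step(self, action):
        rng, self.rng = jax.random.split(self.rng)
        next_state = self.func_env.transition(self.state, action, rng)
        obs = self.func_env.observation(next_state, rng)
        reward = self.func_env.reward(self.state, action, next_state, rng)
        terminated = self.func_env.terminal(next_state, rng)
        self.state = next_state
        # Truncation (e.g. a time limit) is left to wrappers, so it is False here.
        return obs, reward, terminated, False, {}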

Because each function is pure and built from JAX-traceable operations, jax.jit(transition) can compile it to optimized XLA code, and jax.vmap(step) vectorizes a step across a batch of states, actions, and random keys for parallel simulation.
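
A minimal sketch of the batched case, assuming a hypothetical noisy point-mass `step` with the same return layout as above. Each environment instance gets its own random key, and `jax.vmap` maps over the leading axis of every argument.

```python
import jax
import jax.numpy as jnp

def step(state, action, rng):
    # Hypothetical noisy point mass; state = [position, velocity].
    vel = state[1] + 0.1 * action + 0.01 * jax.random.normal(rng, ())
    next_state = jnp.stack([state[0] + 0.1 * vel, vel])
    reward = -(next_state[0] ** 2)
    terminated = jnp.abs(next_state[0]) > 2.0
    return next_state, next_state, reward, terminated

# vmap adds a batch axis to every argument; jit compiles the batched step once.
batched_step = jax.jit(jax.vmap(step))

num_envs = 1024
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)  # one key per environment
states = jnp.zeros((num_envs, 2))
actions = jnp.ones(num_envs)

next_states, obs, rewards, terminated = batched_step(states, actions, keys)
```

The single-environment `step` is written without any batch dimension; the batching is added entirely by the transformation.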
