Implementation:Farama Foundation Gymnasium CliffWalkingFunctional
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Tabular_Environments |
| Last Updated | 2026-02-15 03:00 GMT |
Overview
A JAX-accelerated functional implementation of the Cliff Walking gridworld environment, registered as tabular/CliffWalking-v0, where an agent navigates a 4x12 grid while avoiding a cliff.
Description
The cliffwalking (tabular) module implements the Cliff Walking problem as a functional environment using the FuncEnv API for JAX acceleration.
Environment Layout: A 4x12 grid where the player starts at position [3, 0] (bottom-left) and must reach the goal at [3, 11] (bottom-right). A cliff runs along positions [3, 1] through [3, 10]. Stepping onto the cliff returns the player to the start position and yields a reward of -100.
State Representation (EnvState NamedTuple): Contains player_position (2-element JAX array holding row and column), last_action (integer), and fallen (boolean flag indicating that the last move ended on the cliff).
Action Space: Box(low=0, high=3, dtype=int32) representing four directions: 0=up, 1=right, 2=down, 3=left.
Observation: A single int32 value computed as row * 12 + col, returned in a shape (1,) array. Range is 0 to 47.
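The row-major encoding above can be illustrated with a minimal pure-Python sketch. The helper names `encode` and `decode` are hypothetical and are not part of the Gymnasium module; they only mirror the `row * 12 + col` formula.

```python
N_ROWS, N_COLS = 4, 12  # grid size from the layout description


def encode(row: int, col: int) -> int:
    """Flatten a (row, col) grid position into the scalar observation index."""
    return row * N_COLS + col


def decode(index: int) -> tuple[int, int]:
    """Recover (row, col) from a flattened observation index."""
    return divmod(index, N_COLS)


start = encode(3, 0)                          # bottom-left corner
goal = encode(3, 11)                          # bottom-right corner
cliff = [encode(3, c) for c in range(1, 11)]  # cells between start and goal

print(start, goal, cliff[0], cliff[-1])  # 36 47 37 46
print(decode(goal))                      # (3, 11)
```

Under this encoding the cliff occupies the contiguous indices 37 through 46, which sit directly between the start (36) and the goal (47).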
Transition Logic: The transition function computes new positions using JAX arithmetic with action-conditional offsets, clips to grid boundaries using jnp.maximum/jnp.minimum, and handles cliff detection by checking if the position is in row 3 with column between 1 and 10. If the player falls off the cliff, position is reset to [3, 0] using element-wise arithmetic to maintain JIT compatibility.
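The transition steps described above can be sketched in JAX as follows. This is an illustrative reconstruction from the description, not the module's source: the function name `step_position` is hypothetical, and it returns only the new position and the fallen flag rather than a full EnvState.

```python
import jax
import jax.numpy as jnp


def step_position(position: jax.Array, action: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Hypothetical helper: return (new_position, fallen) for one move."""
    # Action-conditional offsets: each comparison acts as a 0/1 multiplier.
    row = position[0] + 1 * (action == 2) - 1 * (action == 0)
    col = position[1] + 1 * (action == 1) - 1 * (action == 3)

    # Clip to the 4x12 grid (rows 0..3, columns 0..11).
    row = jnp.maximum(jnp.minimum(row, 3), 0)
    col = jnp.maximum(jnp.minimum(col, 11), 0)

    # Cliff cells: row 3, columns 1 through 10 inclusive.
    fallen = (row == 3) & (col >= 1) & (col <= 10)
    f = fallen.astype(jnp.int32)

    # Reset to [3, 0] with arithmetic instead of a Python `if`, because
    # a Python branch cannot depend on a traced value under jax.jit.
    row = row * (1 - f) + 3 * f
    col = col * (1 - f)
    return jnp.stack([row, col]), fallen


step_jit = jax.jit(step_position)
start = jnp.array([3, 0])
pos, fallen = step_jit(start, jnp.array(1))  # move right from the start
print(pos, fallen)
```

Moving right from the start lands on the cliff cell [3, 1], so this sketch reports `fallen` as true and returns the position [3, 0]. Moving up instead yields [2, 0] with `fallen` false.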
Rewards: -1 per step, -100 for stepping on the cliff (computed as -1 + (-99 * fallen)).
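As a worked example of the reward formula, the sketch below computes per-step rewards and the undiscounted return of the shortest route that avoids the cliff. The helper `step_reward` is a hypothetical stand-in written in plain Python, not the module's JAX reward function.

```python
def step_reward(fallen: bool) -> int:
    """Reward for one step, following the formula -1 + (-99 * fallen)."""
    return -1 + (-99 * int(fallen))


# Shortest safe route: up once, right 11 times, down once (13 steps, no falls).
safe_return = sum(step_reward(False) for _ in range(13))

# The same route preceded by one mistaken step onto the cliff.
return_with_one_fall = step_reward(True) + safe_return

print(step_reward(False), step_reward(True))  # -1 -100
print(safe_return, return_with_one_fall)      # -13 -113
```

Because a fall costs as much as 100 ordinary steps, a single mistake outweighs the entire cost of the safe route.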
CliffWalkingJaxEnv wraps the functional environment with jax.jit and provides the standard Gymnasium interface, inheriting from FunctionalJaxEnv and EzPickle.
The rendering system uses a detailed RenderStateType NamedTuple containing PyGame surfaces, elf sprites for each direction, mountain background tiles, cliff images, and start/goal markers.
Usage
Use this environment for JAX-accelerated cliff walking simulations. Create it with gymnasium.make("tabular/CliffWalking-v0"). The environment is adapted from Example 6.6 in Sutton and Barto's "Reinforcement Learning: An Introduction."
Code Reference
Source Location
- Repository: Farama_Foundation_Gymnasium
- File: gymnasium/envs/tabular/cliffwalking.py
Signature
class CliffWalkingFunctional(
    FuncEnv[EnvState, jax.Array, int, float, bool, RenderStateType, None]
):
    action_space = spaces.Box(low=0, high=3, dtype=np.int32)
    observation_space = spaces.Box(low=0, high=47, shape=(1,), dtype=np.int32)
    def initial(self, rng, params=None) -> EnvState
    def transition(self, state, action, key, params=None) -> EnvState
    def observation(self, state, params=None) -> jax.Array
    def reward(self, state, action, next_state, params=None) -> jax.Array
    def terminal(self, state, params=None) -> jax.Array

class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
    def __init__(self, render_mode: str | None = None, **kwargs)
Import
import gymnasium as gym
env = gym.make("tabular/CliffWalking-v0")
# Or directly
from gymnasium.envs.tabular.cliffwalking import CliffWalkingFunctional, CliffWalkingJaxEnv
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| action | int (0-3) | Yes | 0=up, 1=right, 2=down, 3=left |
| render_mode | str or None | No | "rgb_array" for pixel rendering |
Outputs
| Name | Type | Description |
|---|---|---|
| observation | jax.Array (shape (1,), int32) | Grid position as row * 12 + col (0-47) |
| reward | float | -1 per step, -100 for cliff |
| terminated | bool | True when reaching position [3, 11] |
| truncated | bool | Always False |
| info | dict | Empty dictionary |
Usage Examples
import gymnasium as gym
env = gym.make("tabular/CliffWalking-v0")
obs, info = env.reset(seed=42)
print(f"Starting position: {obs}") # [36] = row 3, col 0
# Step right from the start, which lands on the cliff cell [3, 1]
obs, reward, terminated, truncated, info = env.step(1)
print(f"Position: {obs}, Reward: {reward}")  # Fell off the cliff: back at the start, reward = -100
env.close()