Implementation:Farama Foundation Gymnasium CartPoleFunctional
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Classic_Control |
| Last Updated | 2026-02-15 03:00 GMT |
Overview
A JAX-accelerated functional implementation of the CartPole balancing environment, registered as phys2d/CartPole-v0 and phys2d/CartPole-v1, with configurable physics parameters and both single and vectorized environment wrappers.
Description
The cartpole (phys2d) module implements the CartPole balancing problem as a functional environment using the FuncEnv API for JAX acceleration.
Physics: A cart moves along a frictionless track with a pole attached by an un-actuated joint. The agent applies a force of +/- force_mag (default 10.0) to the cart. The physics simulation uses Euler integration with timestep tau (default 0.02s) and the standard cartpole equations of motion (reference: Florian, 2005).
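The update above can be sketched in plain Python. This mirrors the classic Gymnasium CartPole equations (an intermediate `temp`, then the angular and linear accelerations, then an explicit Euler step). The function name `cartpole_euler_step` and the use of the `math` module instead of `jax.numpy` are illustrative choices. Computing `polemass_length` as `masspole * length` is an assumption taken from the classic environment.

```python
import math
from typing import Tuple

State = Tuple[float, float, float, float]  # (x, x_dot, theta, theta_dot)


def cartpole_euler_step(
    state: State,
    action: int,
    gravity: float = 9.8,
    masscart: float = 1.0,
    masspole: float = 0.1,
    length: float = 0.5,
    force_mag: float = 10.0,
    tau: float = 0.02,
) -> State:
    """One explicit Euler step of the classic CartPole dynamics (illustrative sketch)."""
    x, x_dot, theta, theta_dot = state
    # Action 1 pushes right (+force_mag), action 0 pushes left (-force_mag).
    force = force_mag if action == 1 else -force_mag
    total_mass = masscart + masspole
    polemass_length = masspole * length  # assumption: product, as in classic CartPole
    costheta, sintheta = math.cos(theta), math.sin(theta)

    temp = (force + polemass_length * theta_dot**2 * sintheta) / total_mass
    thetaacc = (gravity * sintheta - costheta * temp) / (
        length * (4.0 / 3.0 - masspole * costheta**2 / total_mass)
    )
    xacc = temp - polemass_length * thetaacc * costheta / total_mass

    # Explicit Euler: positions advance with the *old* velocities.
    return (
        x + tau * x_dot,
        x_dot + tau * xacc,
        theta + tau * theta_dot,
        theta_dot + tau * thetaacc,
    )


# From rest, a push to the right accelerates the cart right and tips the pole left.
print(cartpole_euler_step((0.0, 0.0, 0.0, 0.0), action=1))
```

Because positions are advanced with the old velocities, a push from rest changes only `x_dot` and `theta_dot` on the first step; `x` and `theta` begin to move on the following step.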
State: A 4-element float32 JAX array representing [x, x_dot, theta, theta_dot] (cart position, cart velocity, pole angle, pole angular velocity). Each of the four components of the initial state is sampled independently and uniformly from [-x_init, x_init] (default +/-0.05).
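A minimal sketch of the initial-state distribution, assuming Python's standard `random` module as a stand-in for the JAX PRNG key the environment actually receives (so the numbers will not match the JAX environment for the same seed). `sample_initial_state` is a hypothetical helper name.

```python
import random
from typing import List


def sample_initial_state(seed: int, x_init: float = 0.05) -> List[float]:
    """Draw [x, x_dot, theta, theta_dot], each independently from U[-x_init, x_init]."""
    rng = random.Random(seed)  # assumption: stdlib RNG stands in for jax.random
    return [rng.uniform(-x_init, x_init) for _ in range(4)]


state = sample_initial_state(seed=42)
print(state)
```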
CartPoleParams (flax dataclass) provides extensive configurable parameters: gravity (9.8), masscart (1.0), masspole (0.1), length (0.5), force_mag (10.0), tau (0.02), theta_threshold_radians (12 degrees), x_threshold (2.4), x_init (0.05), sutton_barto_reward (False), and screen dimensions.
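Because CartPoleParams is a flax dataclass, a parameter set is an immutable value. The sketch below reproduces the listed defaults with the standard library's frozen dataclass. `CartPoleParamsSketch` is a hypothetical name, the screen fields are omitted, and `total_mass` as a derived property is an illustration. Flax struct dataclasses expose a similar `.replace` method. Here `dataclasses.replace` builds a modified copy and leaves the original untouched.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CartPoleParamsSketch:
    """Hypothetical frozen mirror of the CartPoleParams defaults (screen fields omitted)."""

    gravity: float = 9.8
    masscart: float = 1.0
    masspole: float = 0.1
    length: float = 0.5
    force_mag: float = 10.0
    tau: float = 0.02
    theta_threshold_radians: float = 12 * 2 * math.pi / 360  # 12 degrees
    x_threshold: float = 2.4
    x_init: float = 0.05
    sutton_barto_reward: bool = False

    @property
    def total_mass(self) -> float:
        # Illustrative derived quantity used by the dynamics.
        return self.masscart + self.masspole


defaults = CartPoleParamsSketch()
low_gravity = replace(defaults, gravity=1.62)  # new object; defaults is unchanged
print(math.degrees(defaults.theta_threshold_radians), low_gravity.gravity)
```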
Termination: The episode ends when the cart position exceeds +/-2.4 or the pole angle exceeds +/-12 degrees.
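The termination rule reduces to four comparisons. This sketch assumes strict inequalities, so a state exactly on a threshold is not terminal. `is_terminal` is a hypothetical helper. Inside `jax.jit`, a Python `or` cannot be applied to traced values, so a JAX version would combine the comparisons with the element-wise `|` operator instead.

```python
import math

X_THRESHOLD = 2.4
THETA_THRESHOLD = 12 * 2 * math.pi / 360  # 12 degrees in radians


def is_terminal(
    x: float,
    theta: float,
    x_threshold: float = X_THRESHOLD,
    theta_threshold: float = THETA_THRESHOLD,
) -> bool:
    """True once the cart leaves [-x_threshold, x_threshold] or the pole tilts past theta_threshold."""
    # Assumption: strict inequalities ("exceeds" the threshold).
    return (
        x < -x_threshold
        or x > x_threshold
        or theta < -theta_threshold
        or theta > theta_threshold
    )


print(is_terminal(2.5, 0.0), is_terminal(0.0, math.radians(5)))
```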
Reward: Default reward is +1.0 per step (keeping the pole balanced). With sutton_barto_reward=True, the reward is -1.0 on termination and 0.0 otherwise.
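The two reward modes lead to very different episode returns. A sketch that assumes the per-step `terminated` flag is already known; `step_reward` and `episode_return` are hypothetical helpers, not library functions.

```python
def step_reward(terminated: bool, sutton_barto_reward: bool = False) -> float:
    """Per-step reward under the two modes described above."""
    if sutton_barto_reward:
        return -1.0 if terminated else 0.0
    return 1.0


def episode_return(
    num_steps: int, ends_terminated: bool, sutton_barto_reward: bool = False
) -> float:
    """Sum of rewards over an episode (num_steps >= 1) whose last step is terminal iff ends_terminated."""
    flags = [False] * (num_steps - 1) + [ends_terminated]
    return sum(step_reward(t, sutton_barto_reward) for t in flags)


print(episode_return(37, ends_terminated=True))
print(episode_return(37, ends_terminated=True, sutton_barto_reward=True))
```

With the default reward, the return equals the episode length. In Sutton-Barto mode, it is -1.0 for any episode that ends in termination and 0.0 for one that is only truncated.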
Action Space: Discrete(2) where 0=push left, 1=push right.
The module provides two wrapper classes around CartPoleFunctional:
- CartPoleJaxEnv -- Single environment; applies jax.jit.
- CartPoleJaxVectorEnv -- Vectorized environment; applies jax.jit and jax.vmap, and supports max_episode_steps for truncation.
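In the vectorized wrapper, one call advances every sub-environment and truncation is flagged from max_episode_steps. A minimal NumPy sketch of that per-environment step bookkeeping, in which element-wise array operations play the role that `jax.vmap` plays in the real wrapper. `advance_counters` is a hypothetical helper. Resetting a finished sub-environment's counter to 0 is an assumption about autoreset, not the library's exact logic.

```python
import numpy as np


def advance_counters(steps: np.ndarray, terminated: np.ndarray, max_episode_steps: int):
    """Increment per-environment step counts and flag truncation (hypothetical helper)."""
    steps = steps + 1
    truncated = steps >= max_episode_steps
    # Assumption: a finished sub-environment (terminated or truncated) restarts at 0.
    next_steps = np.where(terminated | truncated, 0, steps)
    return next_steps, truncated


steps = np.array([0, 198, 199])
terminated = np.array([False, True, False])
next_steps, truncated = advance_counters(steps, terminated, max_episode_steps=200)
print(next_steps, truncated)
```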
Rendering uses PyGame with gfxdraw for anti-aliased polygon drawing of the cart and pole.
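Before drawing, the renderer maps world coordinates to pixels. This sketch follows the classic CartPole renderer's conventions and assumes that the phys2d renderer does the same. Under those conventions, the visible world spans [-x_threshold, x_threshold], `length` is half the pole, and the pole tip is offset in a y-up frame before any vertical flip of the surface. `cartpole_pixel_geometry` is a hypothetical helper, and the 600-pixel default width is also an assumption.

```python
import math
from typing import Tuple


def cartpole_pixel_geometry(
    x: float,
    theta: float,
    screen_width: int = 600,
    x_threshold: float = 2.4,
    length: float = 0.5,
) -> Tuple[float, Tuple[float, float]]:
    """Return the cart's horizontal pixel position and the pole-tip offset from the axle.

    Assumes classic CartPole rendering conventions (y-up offsets, `length` = half pole).
    """
    scale = screen_width / (2 * x_threshold)  # pixels per world unit
    cart_x = x * scale + screen_width / 2.0
    pole_px = scale * (2 * length)  # full pole length in pixels
    tip_offset = (pole_px * math.sin(theta), pole_px * math.cos(theta))
    return cart_x, tip_offset


print(cartpole_pixel_geometry(0.0, 0.0))
```

A positive `theta` moves the tip to the right of the axle, which matches the sign convention of the physics step.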
Usage
Use this environment for JAX-accelerated CartPole simulations, particularly useful for parallelized training with CartPoleJaxVectorEnv. Create via gymnasium.make("phys2d/CartPole-v0") (200 step limit) or gymnasium.make("phys2d/CartPole-v1") (500 step limit).
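With the default +1.0 reward, the registered step limits also cap the best possible return: 200 for phys2d/CartPole-v0 and 500 for phys2d/CartPole-v1. gym.make enforces the limit through Gymnasium's TimeLimit wrapper. The sketch below uses a hypothetical `StepLimit` class and dummy dynamics that never terminate to show how the limit turns into `truncated=True`. Neither name is part of Gymnasium.

```python
from typing import Any, Callable, Dict, Tuple

StepFn = Callable[[int], Tuple[Any, float, bool, Dict]]


class StepLimit:
    """Hypothetical minimal stand-in for a TimeLimit-style wrapper."""

    def __init__(self, step_fn: StepFn, max_episode_steps: int):
        self.step_fn = step_fn
        self.max_episode_steps = max_episode_steps
        self.elapsed = 0

    def reset(self) -> None:
        self.elapsed = 0

    def step(self, action: int):
        obs, reward, terminated, info = self.step_fn(action)
        self.elapsed += 1
        # Truncation comes from the step count, not from the dynamics.
        truncated = self.elapsed >= self.max_episode_steps
        return obs, reward, terminated, truncated, info


def never_falls(action: int):
    """Dummy dynamics: never terminate and pay +1 per step, like the default reward."""
    return None, 1.0, False, {}


def best_case_return(max_episode_steps: int) -> float:
    env = StepLimit(never_falls, max_episode_steps)
    env.reset()
    total, done = 0.0, False
    while not done:
        _, reward, terminated, truncated, _ = env.step(0)
        total += reward
        done = terminated or truncated
    return total


print(best_case_return(200), best_case_return(500))  # v0-style and v1-style limits
```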
Code Reference
Source Location
- Repository: Farama_Foundation_Gymnasium
- File: gymnasium/envs/phys2d/cartpole.py
Signature
```python
class CartPoleFunctional(
    FuncEnv[StateType, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
):
    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def initial(self, rng, params=CartPoleParams) -> StateType: ...
    def transition(self, state, action, rng=None, params=CartPoleParams) -> StateType: ...
    def observation(self, state, rng, params=CartPoleParams) -> jax.Array: ...
    def reward(self, state, action, next_state, rng, params=CartPoleParams) -> jax.Array: ...
    def terminal(self, state, rng, params=CartPoleParams) -> jax.Array: ...


class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
    def __init__(self, render_mode: str | None = None, **kwargs): ...


class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
    def __init__(
        self,
        num_envs: int,
        render_mode: str | None = None,
        max_episode_steps: int = 200,
        **kwargs,
    ): ...
```
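The five CartPoleFunctional methods are pure functions: the state is passed in explicitly and a new state is returned. A minimal sketch of how a stateful wrapper might compose them into one step. It uses a hypothetical `ToyFuncEnv` with an integer state (not CartPole) that only borrows the method names. The call order (reward computed from `state`, `action` and `next_state`, termination checked on `next_state`) is an assumption about the wrappers, not a copy of FunctionalJaxEnv.

```python
from typing import Any, NamedTuple


class StepResult(NamedTuple):
    state: Any
    observation: Any
    reward: float
    terminated: bool


class ToyFuncEnv:
    """Hypothetical stand-in with the same method names; state is an integer position."""

    def initial(self, rng, params=None):
        return 0

    def transition(self, state, action, rng=None, params=None):
        return state + (1 if action == 1 else -1)

    def observation(self, state, rng, params=None):
        return float(state)

    def reward(self, state, action, next_state, rng, params=None):
        return 1.0

    def terminal(self, state, rng, params=None):
        return abs(state) >= 3


def functional_step(env, state, action, rng=None, params=None) -> StepResult:
    """Pure step: nothing is mutated; the caller threads the returned state."""
    next_state = env.transition(state, action, rng, params)
    return StepResult(
        state=next_state,
        observation=env.observation(next_state, rng, params),
        reward=env.reward(state, action, next_state, rng, params),
        terminated=env.terminal(next_state, rng, params),
    )


env = ToyFuncEnv()
state = env.initial(rng=None)
for t in range(10):
    result = functional_step(env, state, action=1)
    state = result.state  # the caller, not the env, carries the state forward
    if result.terminated:
        break
print(t + 1, state)
```

Keeping the state outside the environment is what lets transformations such as `jax.jit` and `jax.vmap` be applied to these methods directly.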
Import
```python
import gymnasium as gym

env = gym.make("phys2d/CartPole-v1")

# Vectorized
from gymnasium.envs.phys2d.cartpole import CartPoleJaxVectorEnv

vec_env = CartPoleJaxVectorEnv(num_envs=16, max_episode_steps=500)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| action | int (0 or 1) | Yes | 0=push left, 1=push right |
| render_mode | str or None | No | "rgb_array" for pixel rendering |
| num_envs | int | Yes (vector only) | Number of parallel environments |
| max_episode_steps | int | No | Maximum steps before truncation (vector only, default 200) |
Outputs
| Name | Type | Description |
|---|---|---|
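| observation | jax.Array (float32; shape (4,), or (num_envs, 4) for the vector env) | [x, x_dot, theta, theta_dot] |
| reward | float | +1.0 per step by default; in Sutton-Barto mode, -1.0 on termination and 0.0 otherwise |
| terminated | bool | True when the cart position exceeds x_threshold (2.4) or the pole angle exceeds theta_threshold_radians (12 degrees) in either direction |
| truncated | bool | Always False from CartPoleJaxEnv itself (gym.make adds the 200/500-step limit through its TimeLimit wrapper); for CartPoleJaxVectorEnv, True once max_episode_steps is reached |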
| info | dict | Empty dictionary |
Usage Examples
```python
import gymnasium as gym

# Single JAX CartPole (500-step limit applied by gym.make)
env = gym.make("phys2d/CartPole-v1")
obs, info = env.reset(seed=42)
print(f"Initial state: {obs}")

total_reward = 0.0
done = False
while not done:
    # Simple angle-based policy: push in the direction the pole is leaning
    action = 1 if obs[2] > 0 else 0
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    done = terminated or truncated
print(f"Total reward: {total_reward}")
env.close()

# Vectorized for fast parallel simulation
from gymnasium.envs.phys2d.cartpole import CartPoleJaxVectorEnv

vec_env = CartPoleJaxVectorEnv(num_envs=32, max_episode_steps=500)
obs, info = vec_env.reset(seed=42)
print(f"Batch observation shape: {obs.shape}")  # (32, 4)

# One batched step with a random action for each sub-environment
actions = vec_env.action_space.sample()
obs, rewards, terminations, truncations, infos = vec_env.step(actions)
vec_env.close()
```