Implementation:Farama Foundation Gymnasium CartPoleFunctional

From Leeroopedia
Knowledge Sources
Domains Reinforcement_Learning, Classic_Control
Last Updated 2026-02-15 03:00 GMT

Overview

A JAX-accelerated functional implementation of the CartPole balancing environment, registered as phys2d/CartPole-v0 and phys2d/CartPole-v1, with configurable physics parameters and both single and vectorized environment wrappers.

Description

The cartpole (phys2d) module implements the CartPole balancing problem as a functional environment using the FuncEnv API for JAX acceleration.

Physics: A cart moves along a frictionless track with a pole attached by an un-actuated joint. The agent applies a force of +/- force_mag (default 10.0) to the cart. The physics simulation uses Euler integration with timestep tau (default 0.02s) and the standard cartpole equations of motion (reference: Florian, 2005).
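As a rough illustration of the integration scheme described above, the following pure-Python sketch performs one explicit Euler step. It assumes the widely used frictionless form of the cartpole equations and the default parameter values listed on this page; the function name euler_step and the module-level constants are hypothetical, and the module itself works on JAX arrays rather than Python tuples, so its exact expressions may differ.

```python
import math

# Default values quoted on this page; the constant names are illustrative.
GRAVITY = 9.8
MASSCART = 1.0
MASSPOLE = 0.1
LENGTH = 0.5
FORCE_MAG = 10.0
TAU = 0.02


def euler_step(state, action):
    """Advance (x, x_dot, theta, theta_dot) by one explicit Euler step."""
    x, x_dot, theta, theta_dot = state
    # Action 1 pushes right (+force_mag), action 0 pushes left (-force_mag).
    force = FORCE_MAG if action == 1 else -FORCE_MAG
    total_mass = MASSCART + MASSPOLE
    polemass_length = MASSPOLE * LENGTH
    cos_t, sin_t = math.cos(theta), math.sin(theta)

    # Frictionless cartpole dynamics (assumed form, see lead-in).
    temp = (force + polemass_length * theta_dot**2 * sin_t) / total_mass
    theta_acc = (GRAVITY * sin_t - cos_t * temp) / (
        LENGTH * (4.0 / 3.0 - MASSPOLE * cos_t**2 / total_mass)
    )
    x_acc = temp - polemass_length * theta_acc * cos_t / total_mass

    # Explicit Euler: positions use the old velocities, velocities use the
    # accelerations computed from the old state.
    return (
        x + TAU * x_dot,
        x_dot + TAU * x_acc,
        theta + TAU * theta_dot,
        theta_dot + TAU * theta_acc,
    )
```

Starting from rest at the origin, a push to the right accelerates the cart in the positive direction and gives the pole a negative angular velocity, since the pole lags behind the cart.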

State: A 4-element float32 JAX array representing [x, x_dot, theta, theta_dot] (cart position, cart velocity, pole angle, pole angular velocity). Initial state is sampled uniformly from [-x_init, x_init] (default +/-0.05).
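The initial-state sampling can be sketched as follows. This is a stand-in only: it uses Python's random module and a plain list, whereas the module draws from a JAX random key and returns a float32 JAX array. The name sample_initial_state is hypothetical.

```python
import random

X_INIT = 0.05  # default x_init quoted on this page


def sample_initial_state(seed):
    """Draw [x, x_dot, theta, theta_dot] uniformly from [-X_INIT, X_INIT]."""
    rng = random.Random(seed)  # seeded generator, so the draw is reproducible
    return [rng.uniform(-X_INIT, X_INIT) for _ in range(4)]
```

All four components share the same bound, so every episode starts close to the upright, centred, motionless configuration.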

CartPoleParams (flax dataclass) provides extensive configurable parameters: gravity (9.8), masscart (1.0), masspole (0.1), length (0.5), force_mag (10.0), tau (0.02), theta_threshold_radians (12 degrees), x_threshold (2.4), x_init (0.05), sutton_barto_reward (False), and screen dimensions.
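To show how such an immutable parameter container is typically used, here is a minimal sketch built on the standard-library dataclasses module rather than flax. The class name CartPoleParamsSketch is hypothetical, the field names and defaults are the ones listed above, and the screen dimensions are left out because this page does not state their values.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CartPoleParamsSketch:
    """Illustrative stand-in for CartPoleParams (not the real class)."""

    gravity: float = 9.8
    masscart: float = 1.0
    masspole: float = 0.1
    length: float = 0.5
    force_mag: float = 10.0
    tau: float = 0.02
    theta_threshold_radians: float = 12 * 2 * math.pi / 360  # 12 degrees
    x_threshold: float = 2.4
    x_init: float = 0.05
    sutton_barto_reward: bool = False


# Frozen instances are never mutated; a variant is derived with replace().
default_params = CartPoleParamsSketch()
heavy_pole = replace(default_params, masspole=0.5)
```

Keeping the parameters immutable means a modified configuration is always a new object, which fits the functional style of passing params explicitly to each environment function.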

Termination: The episode ends when the cart position leaves the range +/-x_threshold (default 2.4) or the pole angle leaves the range +/-theta_threshold_radians (default 12 degrees).
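The termination rule amounts to two threshold comparisons. The sketch below assumes the default thresholds and a strict "exceeds" comparison; is_terminal is a hypothetical name, not the module's terminal method.

```python
import math

X_THRESHOLD = 2.4
THETA_THRESHOLD_RADIANS = math.radians(12)  # about 0.2094 rad


def is_terminal(state):
    """Return True once the cart or the pole is outside its allowed range."""
    x, _, theta, _ = state
    return abs(x) > X_THRESHOLD or abs(theta) > THETA_THRESHOLD_RADIANS
```

Note that only position and angle are checked; the two velocity components never end an episode on their own.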

Reward: Default reward is +1.0 per step (keeping the pole balanced). With sutton_barto_reward=True, the reward is -1.0 on termination and 0.0 otherwise.
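The two reward schemes can be written as a single function of the termination flag. This is a sketch of the rule as stated above, with a hypothetical function name and signature; the module's reward method takes the state, action, next state, and params instead.

```python
def step_reward(terminated, sutton_barto_reward=False):
    """Reward for one step under the default or the Sutton-Barto scheme."""
    if sutton_barto_reward:
        # Sparse signal: penalise only the step on which the episode fails.
        return -1.0 if terminated else 0.0
    # Default: every step earns +1.0.
    return 1.0
```

Under the default scheme the episode return equals the number of steps taken, while in Sutton-Barto mode the return is -1.0 for an episode that ends in failure and 0.0 otherwise.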

Action Space: Discrete(2) where 0=push left, 1=push right.

The module provides two wrapper classes:

  • CartPoleJaxEnv -- Single environment, applies jax.jit.
  • CartPoleJaxVectorEnv -- Vectorized environment, applies jax.jit and jax.vmap, supports max_episode_steps for truncation.

Rendering uses PyGame with gfxdraw for anti-aliased polygon drawing of the cart and pole.

Usage

Use this environment for JAX-accelerated CartPole simulations, particularly useful for parallelized training with CartPoleJaxVectorEnv. Create via gymnasium.make("phys2d/CartPole-v0") (200 step limit) or gymnasium.make("phys2d/CartPole-v1") (500 step limit).

Code Reference

Source Location

Signature

class CartPoleFunctional(
    FuncEnv[StateType, jax.Array, int, float, bool, RenderStateType, CartPoleParams]
):
    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def initial(self, rng, params=CartPoleParams) -> StateType
    def transition(self, state, action, rng=None, params=CartPoleParams) -> StateType
    def observation(self, state, rng, params=CartPoleParams) -> jax.Array
    def reward(self, state, action, next_state, rng, params=CartPoleParams) -> jax.Array
    def terminal(self, state, rng, params=CartPoleParams) -> jax.Array

class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
    def __init__(self, render_mode: str | None = None, **kwargs)

class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
    def __init__(self, num_envs: int, render_mode: str | None = None, max_episode_steps: int = 200, **kwargs)

Import

import gymnasium as gym
env = gym.make("phys2d/CartPole-v1")

# Vectorized
from gymnasium.envs.phys2d.cartpole import CartPoleJaxVectorEnv
vec_env = CartPoleJaxVectorEnv(num_envs=16, max_episode_steps=500)

I/O Contract

Inputs

Name | Type | Required | Description
action | int (0 or 1) | Yes | 0=push left, 1=push right
render_mode | str or None | No | "rgb_array" for pixel rendering
num_envs | int | Yes (vector only) | Number of parallel environments
max_episode_steps | int | No | Maximum steps before truncation (vector only, default 200)

Outputs

Name | Type | Description
observation | jax.Array (shape (4,), float32) | [x, x_dot, theta, theta_dot]
reward | float | +1.0 per step (or -1.0 on termination in Sutton-Barto mode)
terminated | bool | True when cart or pole exceeds threshold
truncated | bool | False (single) or based on max_episode_steps (vector)
info | dict | Empty dictionary

Usage Examples

import gymnasium as gym

# Single JAX CartPole
env = gym.make("phys2d/CartPole-v1")
obs, info = env.reset(seed=42)
print(f"Initial state: {obs}")

total_reward = 0
done = False
while not done:
    action = 1 if obs[2] > 0 else 0  # Simple angle-based policy
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    done = terminated or truncated

print(f"Total reward: {total_reward}")
env.close()

# Vectorized for fast parallel simulation
from gymnasium.envs.phys2d.cartpole import CartPoleJaxVectorEnv
vec_env = CartPoleJaxVectorEnv(num_envs=32, max_episode_steps=500)
obs, info = vec_env.reset(seed=42)
print(f"Batch observation shape: {obs.shape}")  # (32, 4)
