

Implementation:Farama Foundation Gymnasium PendulumFunctional

From Leeroopedia
Knowledge Sources
Domains Reinforcement_Learning, Classic_Control
Last Updated 2026-02-15 03:00 GMT

Overview

A JAX-accelerated functional implementation of the Pendulum swing-up environment, registered as phys2d/Pendulum-v0, with configurable physics parameters and both single and vectorized environment wrappers.

Description

The pendulum (phys2d) module implements the classic inverted pendulum (swing-up) problem as a functional environment using the FuncEnv API for JAX acceleration.

Physics: The pendulum starts from a random state, and the goal is to swing it up and keep it balanced upright (theta = 0). The agent applies a continuous torque u, clipped to [-max_torque, max_torque] (default +/-2.0). Each step uses semi-implicit Euler integration with timestep dt (default 0.05 s). The angular acceleration thetaacc = (3g / (2l)) * sin(theta) + 3 / (m * l^2) * u first updates the angular velocity, theta_dot' = theta_dot + thetaacc * dt, which is clipped to [-max_speed, max_speed] (default +/-8.0). The angle is then advanced with the new velocity: theta' = theta + theta_dot' * dt.
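The update rule above can be sketched in plain NumPy as follows. This is an illustrative mirror of the described dynamics, not the Gymnasium source: the function name `pendulum_step` is hypothetical, the constants are the defaults quoted above (the real environment reads them from PendulumParams and uses jax.numpy), and the state is assumed to be `[theta, theta_dot]`.

```python
import numpy as np

# Default constants quoted in the description (assumed; the real env reads PendulumParams).
MAX_TORQUE = 2.0
MAX_SPEED = 8.0
DT = 0.05
G, M, L = 10.0, 1.0, 1.0


def pendulum_step(state: np.ndarray, u: float) -> np.ndarray:
    """One semi-implicit Euler step for state = [theta, theta_dot] (hypothetical helper)."""
    theta, theta_dot = state
    u = np.clip(u, -MAX_TORQUE, MAX_TORQUE)  # enforce the torque limit
    theta_acc = 3 * G / (2 * L) * np.sin(theta) + 3.0 / (M * L**2) * u
    # Update the velocity first and clip it to the speed limit...
    new_theta_dot = np.clip(theta_dot + theta_acc * DT, -MAX_SPEED, MAX_SPEED)
    # ...then advance the angle with the *updated* velocity.
    new_theta = theta + new_theta_dot * DT
    return np.array([new_theta, new_theta_dot])


# Upright and at rest with zero torque is an (unstable) equilibrium.
print(pendulum_step(np.array([0.0, 0.0]), 0.0))  # [0. 0.]
# A requested torque of 5.0 is clipped to 2.0: theta_dot' = 3 * 2 * 0.05, theta' = theta_dot' * 0.05.
print(pendulum_step(np.array([0.0, 0.0]), 5.0))
```

Updating the velocity before the position (rather than using the old velocity for both) is what makes this a semi-implicit, or symplectic, Euler step.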

State: A 2-element JAX array [theta, theta_dot]. Initial state is sampled uniformly from [-pi, pi] for theta and [-1, 1] for theta_dot.
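A minimal sketch of the initial-state sampling, assuming the default bounds high_x = pi and high_y = 1.0. The JAX environment draws from `jax.random.uniform` with a PRNG key; this version swaps in NumPy's `Generator.uniform` so it runs without JAX, which means the same seed will not reproduce the environment's actual draws. The helper name `sample_initial_state` is hypothetical.

```python
import numpy as np

HIGH = np.array([np.pi, 1.0])  # assumed [high_x, high_y] defaults


def sample_initial_state(seed: int) -> np.ndarray:
    """Draw [theta, theta_dot] uniformly from [-high_x, high_x] x [-high_y, high_y]."""
    rng = np.random.default_rng(seed)
    # Array-valued low/high broadcast to one draw per state component.
    return rng.uniform(low=-HIGH, high=HIGH)


state = sample_initial_state(0)
print(state.shape)  # (2,)
```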

Observation: A 3-element float32 array [cos(theta), sin(theta), theta_dot], providing a continuous, unambiguous representation of the pendulum angle.
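The benefit of the (cos, sin) encoding is that angles differing by a full turn map to the same observation, with no jump at +/-pi, and the wrapped angle can still be recovered with `arctan2`. The sketch below illustrates this in NumPy; `observe` is a hypothetical name and is not the library's `observation` method.

```python
import numpy as np


def observe(state: np.ndarray) -> np.ndarray:
    """Map [theta, theta_dot] to [cos(theta), sin(theta), theta_dot] as float32."""
    theta, theta_dot = state
    return np.array([np.cos(theta), np.sin(theta), theta_dot], dtype=np.float32)


# theta and theta + 2*pi describe the same pose and give the same observation.
a = observe(np.array([0.5, 0.0]))
b = observe(np.array([0.5 + 2 * np.pi, 0.0]))
print(np.allclose(a, b, atol=1e-6))  # True

# The wrapped angle is recoverable from the first two components.
print(float(np.arctan2(a[1], a[0])))  # approximately 0.5
```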

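PendulumParams (a flax.struct.dataclass) holds the configurable parameters, with defaults: max_speed (8.0), dt (0.05), g (10.0), m (1.0), l (1.0), high_x (pi) and high_y (1.0), which bound the initial theta and theta_dot, and screen_dim (500), the side length of the rendered image in pixels.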
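Because flax struct dataclasses are immutable, a variant configuration is created by copying with changed fields rather than by mutating an instance (flax instances expose a `.replace(...)` method for this). The sketch below uses a standard-library frozen dataclass as a stand-in so it runs without flax; `PendulumParamsSketch` and `gravity_coefficient` are hypothetical names, and the field defaults are copied from the list above.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class PendulumParamsSketch:
    """Plain-Python stand-in for PendulumParams (the real class is a flax struct dataclass)."""
    max_speed: float = 8.0
    dt: float = 0.05
    g: float = 10.0
    m: float = 1.0
    l: float = 1.0
    high_x: float = math.pi
    high_y: float = 1.0
    screen_dim: int = 500


def gravity_coefficient(p: PendulumParamsSketch) -> float:
    """The 3g / (2l) factor multiplying sin(theta) in the angular acceleration."""
    return 3 * p.g / (2 * p.l)


default = PendulumParamsSketch()
longer = replace(default, g=9.81, l=2.0)  # new immutable instance; `default` is unchanged
print(gravity_coefficient(default))  # 15.0
print(gravity_coefficient(longer))
```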

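Reward: The reward is the negative of the cost theta_normalized^2 + 0.1 * theta_dot^2 + 0.001 * u^2, where theta_normalized is theta wrapped into [-pi, pi) and u is the clipped torque. The cost penalizes deviation from upright, angular velocity, and control effort, so the reward ranges from about -16.27 (hanging down with |theta_dot| = 8 and |u| = 2) up to 0 (upright, at rest, no torque).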
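A NumPy sketch of the reward formula, mainly to show the angle wrapping and to check the bounds stated above. The name `pendulum_reward` is hypothetical, and the sketch assumes the cost is evaluated on the state passed in together with the action taken from it.

```python
import numpy as np

MAX_TORQUE = 2.0  # assumed default


def pendulum_reward(state: np.ndarray, u: float) -> float:
    """Negative cost for state = [theta, theta_dot] and torque u (hypothetical helper)."""
    theta, theta_dot = state
    u = np.clip(u, -MAX_TORQUE, MAX_TORQUE)
    theta_normalized = ((theta + np.pi) % (2 * np.pi)) - np.pi  # wrap into [-pi, pi)
    cost = theta_normalized**2 + 0.1 * theta_dot**2 + 0.001 * u**2
    return float(-cost)


print(pendulum_reward(np.array([0.0, 0.0]), 0.0) == 0.0)  # True: upright, at rest, no torque
print(pendulum_reward(np.array([2 * np.pi, 0.0]), 0.0))  # close to 0: a full turn wraps to upright
# Worst case: hanging down (theta = pi), |theta_dot| = 8, |u| = 2.
print(round(pendulum_reward(np.array([np.pi, 8.0]), 2.0), 4))  # -16.2736
```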

Termination: The episode never terminates naturally (terminal() always returns False). Episodes end only by truncation via max_episode_steps.
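Since terminal() never fires, the only way an episode ends is a step counter reaching the limit. The sketch below isolates that control flow with no physics; `run_episode` is a hypothetical helper that mimics a TimeLimit-style check and is not the wrapper Gymnasium actually applies.

```python
def run_episode(max_episode_steps: int = 200) -> tuple[int, bool, bool]:
    """Count steps until the episode ends under a step limit (hypothetical helper)."""
    terminated = truncated = False
    steps = 0
    while not (terminated or truncated):
        steps += 1
        terminated = False  # terminal() always returns False for this environment
        truncated = steps >= max_episode_steps
    return steps, terminated, truncated


print(run_episode())    # (200, False, True)
print(run_episode(50))  # (50, False, True)
```

Because the episode is cut off rather than reaching a true terminal state, value-based learners should generally still bootstrap from the final observation when truncated is True.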

Action Space: Box(-2.0, 2.0, shape=(1,), dtype=float32) representing continuous torque.

The module provides two wrapper classes around PendulumFunctional:

  • PendulumJaxEnv -- Single environment with jax.jit.
  • PendulumJaxVectorEnv -- Vectorized environment with jax.jit and jax.vmap.
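The vectorized wrapper relies on jax.vmap, which maps the single-environment functions over a leading batch axis. To show the shapes involved without depending on JAX, the sketch below writes the batched transition by hand with NumPy broadcasting; `batched_step` and the hard-coded default constants are illustrative assumptions. With JAX, one would instead apply `jax.vmap` to a single-state step function.

```python
import numpy as np

# Assumed defaults: torque limit, speed limit, timestep, gravity, mass, length.
MAX_TORQUE, MAX_SPEED, DT, G, M, L = 2.0, 8.0, 0.05, 10.0, 1.0, 1.0


def batched_step(states: np.ndarray, actions: np.ndarray) -> np.ndarray:
    """Advance N pendulums at once.

    states: shape (N, 2), one [theta, theta_dot] row per environment.
    actions: shape (N, 1), one torque per environment (matching Box shape (1,)).
    """
    theta, theta_dot = states[:, 0], states[:, 1]
    u = np.clip(actions[:, 0], -MAX_TORQUE, MAX_TORQUE)
    theta_acc = 3 * G / (2 * L) * np.sin(theta) + 3.0 / (M * L**2) * u
    new_theta_dot = np.clip(theta_dot + theta_acc * DT, -MAX_SPEED, MAX_SPEED)
    new_theta = theta + new_theta_dot * DT
    return np.stack([new_theta, new_theta_dot], axis=1)


rng = np.random.default_rng(0)
states = rng.uniform(low=[-np.pi, -1.0], high=[np.pi, 1.0], size=(64, 2))
actions = rng.uniform(-2.0, 2.0, size=(64, 1))
print(batched_step(states, actions).shape)  # (64, 2)
```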

Rendering uses PyGame with gfxdraw, drawing the pendulum rod, pivot, and end-mass, plus a torque indicator arrow loaded from an assets PNG file.

Usage

Use this environment for continuous control RL algorithms (e.g., DDPG, SAC, PPO) with JAX acceleration. Create via gymnasium.make("phys2d/Pendulum-v0").

Code Reference

Source Location

gymnasium/envs/phys2d/pendulum.py (module gymnasium.envs.phys2d.pendulum)

Signature

class PendulumFunctional(
    FuncEnv[StateType, jax.Array, int, float, bool, RenderStateType, PendulumParams]
):
    max_torque: float = 2.0
    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
    action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)

    def initial(self, rng, params=PendulumParams) -> StateType
    def transition(self, state, action, rng=None, params=PendulumParams) -> StateType
    def observation(self, state, rng, params=PendulumParams) -> jax.Array
    def reward(self, state, action, next_state, rng, params=PendulumParams) -> float
    def terminal(self, state, rng, params=PendulumParams) -> bool

class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
    def __init__(self, render_mode: str | None = None, **kwargs)

class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
    def __init__(self, num_envs: int, render_mode: str | None = None, max_episode_steps: int = 200, **kwargs)

Import

import gymnasium as gym
env = gym.make("phys2d/Pendulum-v0")

# Vectorized
from gymnasium.envs.phys2d.pendulum import PendulumJaxVectorEnv
vec_env = PendulumJaxVectorEnv(num_envs=16, max_episode_steps=200)

I/O Contract

Inputs

Name | Type | Required | Description
action | jax.Array, shape (1,), float32 | Yes | Torque in [-2.0, 2.0] (clipped internally)
render_mode | str or None | No | "rgb_array" for pixel rendering
num_envs | int | Yes (vector env only) | Number of parallel environments
max_episode_steps | int | No (vector env only) | Maximum steps before truncation (default 200)

Outputs

Name | Type | Description
observation | jax.Array, shape (3,), float32 | [cos(theta), sin(theta), theta_dot]
reward | float | Negative cost: -(theta_normalized^2 + 0.1*theta_dot^2 + 0.001*u^2)
terminated | bool | Always False (no natural termination)
truncated | bool | True once max_episode_steps is reached
info | dict | Empty dictionary

Usage Examples

import gymnasium as gym
import jax.numpy as jnp

# Single JAX Pendulum
env = gym.make("phys2d/Pendulum-v0")
obs, info = env.reset(seed=42)
print(f"Initial obs: cos={obs[0]:.3f}, sin={obs[1]:.3f}, vel={obs[2]:.3f}")

total_reward = 0
for _ in range(200):
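    # PD-style feedback toward upright (theta = 0); torques outside [-2, 2] are clipped internally.
    # This only balances near the top; it is not a swing-up controller.
    action = jnp.array([-(10.0 * obs[1] + 1.0 * obs[2])])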
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break

print(f"Total reward: {total_reward:.2f}")
env.close()

# Vectorized for parallel training
from gymnasium.envs.phys2d.pendulum import PendulumJaxVectorEnv
vec_env = PendulumJaxVectorEnv(num_envs=64, max_episode_steps=200)
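obs, info = vec_env.reset(seed=0)
print(f"Batch observation shape: {obs.shape}")  # (64, 3)

actions = vec_env.action_space.sample()  # batched Box, shape (64, 1)
obs, rewards, terminated, truncated, info = vec_env.step(actions)
print(f"Batch reward shape: {rewards.shape}")  # (64,)
vec_env.close()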
