Implementation:Farama Foundation Gymnasium PendulumFunctional
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Classic_Control |
| Last Updated | 2026-02-15 03:00 GMT |
Overview
A JAX-accelerated functional implementation of the Pendulum swing-up environment, registered as phys2d/Pendulum-v0, with configurable physics parameters and both single and vectorized environment wrappers.
Description
The pendulum (phys2d) module implements the classic inverted pendulum (swing-up) problem as a functional environment using the FuncEnv API for JAX acceleration.
Physics: A pendulum starts from a random position, and the goal is to swing it up and keep it balanced upright. The agent applies a continuous torque u between -max_torque and +max_torque (default +/-2.0). The dynamics use Euler integration with timestep dt (default 0.05 s): theta_acc = (3g / (2l)) * sin(theta) + (3 / (m * l^2)) * u. Angular velocity is clipped to [-max_speed, max_speed] (default +/-8.0).
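The update rule above can be sketched in plain Python as a scalar reference. This is not the module's JAX code: the function name pendulum_step is hypothetical, and the order of operations (clip the torque, update and clip the velocity, then advance the angle with the updated velocity) is an assumption that the description does not spell out.

```python
import math

def pendulum_step(theta, theta_dot, u, dt=0.05, g=10.0, m=1.0, l=1.0,
                  max_speed=8.0, max_torque=2.0):
    """One Euler step of the pendulum dynamics (scalar reference sketch)."""
    # Clip the requested torque to the allowed range.
    u = max(-max_torque, min(max_torque, u))
    # Angular acceleration: gravity term plus applied torque term.
    theta_acc = 3.0 * g / (2.0 * l) * math.sin(theta) + 3.0 / (m * l ** 2) * u
    # Integrate the velocity, then clip it to [-max_speed, max_speed].
    new_theta_dot = theta_dot + theta_acc * dt
    new_theta_dot = max(-max_speed, min(max_speed, new_theta_dot))
    # Assumption: the angle is advanced using the already-updated velocity.
    new_theta = theta + new_theta_dot * dt
    return new_theta, new_theta_dot

# A torque of 5.0 is clipped to 2.0 before it enters the dynamics.
theta, theta_dot = pendulum_step(0.0, 0.0, 5.0)
```

Starting at rest in the upright position, the gravity term vanishes, so only the clipped torque contributes to the first step.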
State: A 2-element JAX array [theta, theta_dot]. Initial state is sampled uniformly from [-pi, pi] for theta and [-1, 1] for theta_dot.
Observation: A 3-element float32 array [cos(theta), sin(theta), theta_dot], providing a continuous, unambiguous representation of the pendulum angle.
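To illustrate why the cos/sin encoding is unambiguous, the following plain-Python sketch (the helper name observe is hypothetical) shows that two angle values describing the same physical position, but differing by a full turn, map to the same observation.

```python
import math

def observe(theta, theta_dot):
    """Map a [theta, theta_dot] state to [cos(theta), sin(theta), theta_dot]."""
    return (math.cos(theta), math.sin(theta), theta_dot)

# Two representations of the same physical angle, 2*pi apart.
a = observe(math.pi - 0.1, 0.5)
b = observe(-math.pi - 0.1, 0.5)

# The observations agree up to floating-point rounding.
same = all(abs(x - y) < 1e-9 for x, y in zip(a, b))
```

A raw angle would jump discontinuously when wrapping around, whereas the cos/sin pair varies smoothly.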
PendulumParams (flax dataclass) provides configurable parameters: max_speed (8.0), dt (0.05), g (10.0), m (1.0), l (1.0), high_x (pi), high_y (1.0), and screen_dim (500).
Reward: The cost function is -(theta_normalized^2 + 0.1 * theta_dot^2 + 0.001 * u^2), where theta is normalized to [-pi, pi]. This penalizes deviation from upright, angular velocity, and control effort.
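As a rough scalar sketch of this cost, assuming the usual wrap-around formula for normalizing an angle to [-pi, pi] (the helper names are hypothetical, and whether the module evaluates the cost on the raw or the clipped torque is not specified here):

```python
import math

def angle_normalize(theta):
    """Wrap an angle into [-pi, pi) (assumed normalization formula)."""
    return ((theta + math.pi) % (2.0 * math.pi)) - math.pi

def reward(theta, theta_dot, u):
    """Negative quadratic cost on angle error, angular velocity, and torque."""
    cost = angle_normalize(theta) ** 2 + 0.1 * theta_dot ** 2 + 0.001 * u ** 2
    return -cost

upright = reward(0.0, 0.0, 0.0)       # best case: no cost at all
hanging = reward(math.pi, 0.0, 0.0)   # worst angle error: -(pi^2)
```

The reward is never positive, and its maximum of zero is reached only when the pendulum is upright, at rest, and no torque is applied.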
Termination: The episode never terminates naturally (terminal() always returns False). Episodes end only by truncation via max_episode_steps.
Action Space: Box(-2.0, 2.0, shape=(1,), dtype=float32) representing continuous torque.
The module provides two wrapper classes around PendulumFunctional:
- PendulumJaxEnv -- Single environment with jax.jit.
- PendulumJaxVectorEnv -- Vectorized environment with jax.jit and jax.vmap.
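The general pattern of combining jax.jit and jax.vmap can be sketched as follows, using initial-state sampling as the example. This is an illustration of the technique, not the wrapper's actual code: the standalone function initial and its keyword arguments are stand-ins modeled on the state description and the high_x / high_y parameters.

```python
import jax
import jax.numpy as jnp

def initial(rng, high_x=jnp.pi, high_y=1.0):
    """Sample one [theta, theta_dot] state uniformly from [-high, high]."""
    high = jnp.array([high_x, high_y])
    return jax.random.uniform(rng, shape=high.shape, minval=-high, maxval=high)

# vmap maps the single-environment function over a batch of PRNG keys;
# jit compiles the batched function so repeated calls reuse the compiled code.
batched_initial = jax.jit(jax.vmap(initial))

keys = jax.random.split(jax.random.PRNGKey(0), 16)
states = batched_initial(keys)  # one state per key, shape (16, 2)
```

Writing the environment logic as pure functions of (state, rng) is what makes this kind of batching possible without changing the per-environment code.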
Rendering uses PyGame with gfxdraw, drawing the pendulum rod, pivot, and end-mass, plus a torque indicator arrow loaded from an assets PNG file.
Usage
Use this environment for continuous control RL algorithms (e.g., DDPG, SAC, PPO) with JAX acceleration. Create via gymnasium.make("phys2d/Pendulum-v0").
Code Reference
Source Location
- Repository: Farama_Foundation_Gymnasium
- File: gymnasium/envs/phys2d/pendulum.py
Signature
class PendulumFunctional(
    FuncEnv[StateType, jax.Array, int, float, bool, RenderStateType, PendulumParams]
):
    max_torque: float = 2.0
    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
    action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)

    def initial(self, rng, params=PendulumParams) -> StateType
    def transition(self, state, action, rng=None, params=PendulumParams) -> StateType
    def observation(self, state, rng, params=PendulumParams) -> jax.Array
    def reward(self, state, action, next_state, rng, params=PendulumParams) -> float
    def terminal(self, state, rng, params=PendulumParams) -> bool

class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
    def __init__(self, render_mode: str | None = None, **kwargs)

class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
    def __init__(self, num_envs: int, render_mode: str | None = None, max_episode_steps: int = 200, **kwargs)
Import
import gymnasium as gym
env = gym.make("phys2d/Pendulum-v0")
# Vectorized
from gymnasium.envs.phys2d.pendulum import PendulumJaxVectorEnv
vec_env = PendulumJaxVectorEnv(num_envs=16, max_episode_steps=200)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| action | jax.Array (shape (1,), float32) | Yes | Torque in range [-2.0, 2.0] (clipped internally) |
| render_mode | str or None | No | "rgb_array" for pixel rendering |
| num_envs | int | Yes (vector only) | Number of parallel environments |
| max_episode_steps | int | No | Maximum steps before truncation (vector only, default 200) |
Outputs
| Name | Type | Description |
|---|---|---|
| observation | jax.Array (shape (3,), float32) | [cos(theta), sin(theta), theta_dot] |
| reward | float | Negative cost: -(theta^2 + 0.1*theta_dot^2 + 0.001*u^2), with theta normalized to [-pi, pi] |
| terminated | bool | Always False (no natural termination) |
| truncated | bool | Based on max_episode_steps |
| info | dict | Empty dictionary |
Usage Examples
import gymnasium as gym
import jax.numpy as jnp
# Single JAX Pendulum
env = gym.make("phys2d/Pendulum-v0")
obs, info = env.reset(seed=42)
print(f"Initial obs: cos={obs[0]:.3f}, sin={obs[1]:.3f}, vel={obs[2]:.3f}")
total_reward = 0
for _ in range(200):
    action = jnp.array([2.0 * obs[1]])  # Proportional control on sin(theta)
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break
print(f"Total reward: {total_reward:.2f}")
env.close()
# Vectorized for parallel training
from gymnasium.envs.phys2d.pendulum import PendulumJaxVectorEnv
vec_env = PendulumJaxVectorEnv(num_envs=64, max_episode_steps=200)
obs, info = vec_env.reset(seed=0)
print(f"Batch observation shape: {obs.shape}") # (64, 3)