Implementation:Haosulab ManiSkill PPO Training Loop

From Leeroopedia
Field Value
implementation_name Haosulab_ManiSkill_PPO_Training_Loop
overview Concrete PPO training loop implementing GAE, clipped surrogate loss, and minibatch optimization for ManiSkill environments
type Pattern Doc
domains Reinforcement_Learning, Robotics
last_updated 2026-02-15
related_pages Principle:Haosulab_ManiSkill_PPO_Policy_Optimization

Overview

Description

The PPO training loop is the core optimization routine in the ManiSkill PPO baseline. Each iteration consists of three phases: (1) rollout collection across parallel GPU environments, (2) GAE advantage computation with proper episode boundary handling, and (3) minibatch policy/value optimization with the clipped surrogate objective.

This is a Pattern Doc: it documents the training loop from the PPO example baseline, not a library API. Users are expected to adapt this code to their specific training needs.

Usage

The training loop is the main entry point of the PPO script. It runs for num_iterations iterations (computed as total_timesteps / batch_size), where each iteration collects num_steps * num_envs transitions and performs update_epochs passes of minibatch optimization.
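The arithmetic above can be sanity-checked directly with the default hyperparameters (a minimal sketch using the defaults from the Args dataclass below):

```python
# Per-iteration and total training arithmetic, using the default hyperparameters.
num_envs, num_steps = 512, 50
total_timesteps = 10_000_000
num_minibatches, update_epochs = 32, 4

batch_size = num_envs * num_steps                      # transitions collected per iteration
minibatch_size = batch_size // num_minibatches         # samples per gradient step
num_iterations = total_timesteps // batch_size         # outer loop length
grad_steps_per_iter = update_epochs * num_minibatches  # optimizer steps per iteration

print(batch_size, minibatch_size, num_iterations, grad_steps_per_iter)
# 25600 800 390 128
```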

Code Reference

Field Value
Repository https://github.com/haosulab/ManiSkill
File examples/baselines/ppo/ppo.py
Args dataclass Lines 24-113
Rollout collection Lines 307-332
GAE computation Lines 334-374
Minibatch optimization Lines 384-441
Logging Lines 445-461

Args dataclass (key training hyperparameters):

@dataclass
class Args:
    env_id: str = "PickCube-v1"
    total_timesteps: int = 10000000
    learning_rate: float = 3e-4
    num_envs: int = 512
    num_steps: int = 50            # rollout length per environment
    gamma: float = 0.8             # discount factor
    gae_lambda: float = 0.9        # GAE lambda
    num_minibatches: int = 32      # minibatches per epoch
    update_epochs: int = 4         # optimization epochs per iteration
    norm_adv: bool = True          # normalize advantages
    clip_coef: float = 0.2         # PPO clip epsilon
    clip_vloss: bool = False       # clip value loss
    ent_coef: float = 0.0          # entropy coefficient
    vf_coef: float = 0.5           # value function coefficient
    max_grad_norm: float = 0.5     # gradient clipping norm
    target_kl: float = 0.1         # KL early stopping threshold
    reward_scale: float = 1.0      # reward multiplier
    partial_reset: bool = True     # enable partial resets
    finite_horizon_gae: bool = False  # alternative GAE formulation

    # computed at runtime
    batch_size: int = 0            # num_envs * num_steps
    minibatch_size: int = 0        # batch_size / num_minibatches
    num_iterations: int = 0        # total_timesteps / batch_size

GAE Advantage Computation:

# bootstrap value according to termination and truncation
with torch.no_grad():
    next_value = agent.get_value(next_obs).reshape(1, -1)
    advantages = torch.zeros_like(rewards).to(device)
    lastgaelam = 0
    for t in reversed(range(args.num_steps)):
        if t == args.num_steps - 1:
            next_not_done = 1.0 - next_done
            nextvalues = next_value
        else:
            next_not_done = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        real_next_values = next_not_done * nextvalues + final_values[t]
        delta = rewards[t] + args.gamma * real_next_values - values[t]
        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * next_not_done * lastgaelam
    returns = advantages + values
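The recursion can be isolated and checked on hand-computable inputs. A minimal sketch, assuming no mid-rollout truncations (final_values is all zeros, so real_next_values reduces to the masked bootstrap):

```python
import torch

def compute_gae(rewards, values, dones, next_value, next_done, gamma, gae_lambda):
    # Backward GAE recursion matching the loop above, with final_values == 0.
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            next_not_done = 1.0 - next_done
            nextvalues = next_value
        else:
            next_not_done = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        delta = rewards[t] + gamma * next_not_done * nextvalues - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * next_not_done * lastgaelam
    return advantages, advantages + values  # (advantages, returns)

# Two steps, one env, all values zero, reward 1 each step, gamma=0.5, lambda=1:
# adv[1] = 1, adv[0] = 1 + gamma * lambda * adv[1] = 1.5
adv, ret = compute_gae(
    rewards=torch.ones(2, 1), values=torch.zeros(2, 1), dones=torch.zeros(2, 1),
    next_value=torch.zeros(1), next_done=torch.zeros(1),
    gamma=0.5, gae_lambda=1.0,
)
# adv == [[1.5], [1.0]]
```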

Minibatch Optimization:

# flatten the batch
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)

# Optimizing the policy and value network
agent.train()
b_inds = np.arange(args.batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
    np.random.shuffle(b_inds)
    for start in range(0, args.batch_size, args.minibatch_size):
        end = start + args.minibatch_size
        mb_inds = b_inds[start:end]

        _, newlogprob, entropy, newvalue = agent.get_action_and_value(
            b_obs[mb_inds], b_actions[mb_inds]
        )
        logratio = newlogprob - b_logprobs[mb_inds]
        ratio = logratio.exp()

        with torch.no_grad():
            old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

        if args.target_kl is not None and approx_kl > args.target_kl:
            break

        mb_advantages = b_advantages[mb_inds]
        if args.norm_adv:
            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

        # Policy loss
        pg_loss1 = -mb_advantages * ratio
        pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Value loss (unclipped path; clip_vloss enables a clipped variant)
        newvalue = newvalue.view(-1)
        v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

        entropy_loss = entropy.mean()
        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
        optimizer.step()

    if args.target_kl is not None and approx_kl > args.target_kl:
        break
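The snippet above shows only the unclipped value loss, since clip_vloss defaults to False. When enabled, the value loss is typically clipped analogously to the policy ratio; a sketch of that branch in the common CleanRL-style form (an assumption, not a verbatim excerpt from ppo.py):

```python
import torch

def clipped_value_loss(newvalue, old_values, returns, clip_coef=0.2):
    # Penalize value updates that move more than clip_coef away from the
    # rollout-time estimate, taking the elementwise max of the two losses.
    v_loss_unclipped = (newvalue - returns) ** 2
    v_clipped = old_values + torch.clamp(newvalue - old_values, -clip_coef, clip_coef)
    v_loss_clipped = (v_clipped - returns) ** 2
    return 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

# A large value jump is penalized even when it lands exactly on the return:
loss = clipped_value_loss(
    newvalue=torch.tensor([1.0]), old_values=torch.tensor([0.0]),
    returns=torch.tensor([1.0]),
)
# loss = 0.5 * max(0.0, (0.2 - 1.0) ** 2) = 0.32
```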

I/O Contract

Training loop inputs (per iteration):

Name Type Shape Description
obs (rollout buffer) torch.Tensor (num_steps, num_envs, obs_dim) Collected observations
actions (rollout buffer) torch.Tensor (num_steps, num_envs, act_dim) Collected actions
logprobs (rollout buffer) torch.Tensor (num_steps, num_envs) Old log-probabilities
rewards (rollout buffer) torch.Tensor (num_steps, num_envs) Scaled rewards
dones (rollout buffer) torch.Tensor (num_steps, num_envs) Done flags
values (rollout buffer) torch.Tensor (num_steps, num_envs) Old value estimates
final_values torch.Tensor (num_steps, num_envs) Bootstrap values at truncation

Training loop outputs (per iteration):

Name Type Description
Updated agent parameters nn.Module state_dict Policy and value network weights after optimization
pg_loss float Policy gradient loss (logged)
v_loss float Value function loss (logged)
entropy_loss float Entropy of the policy (logged)
approx_kl float Approximate KL divergence between old and new policies (logged)
clipfrac float Fraction of samples where clipping was active (logged)
explained_variance float How well the value function explains the returns (logged)
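explained_variance is a standard diagnostic for the value head; a sketch of how it is typically computed from the flattened rollout values and returns (standard formulation, assumed rather than quoted from the script):

```python
import numpy as np

def explained_variance(y_pred, y_true):
    # 1 - Var(returns - values) / Var(returns): 1.0 means the value head
    # predicts returns perfectly; <= 0 means no better than predicting the mean.
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else 1.0 - np.var(y_true - y_pred) / var_y

returns = np.array([1.0, 2.0, 3.0])
print(explained_variance(returns, returns))                     # 1.0
print(explained_variance(np.full(3, returns.mean()), returns))  # 0.0
```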

Computed runtime values:

args.batch_size = int(args.num_envs * args.num_steps)      # e.g., 512 * 50 = 25600
args.minibatch_size = int(args.batch_size // args.num_minibatches)  # e.g., 25600 / 32 = 800
args.num_iterations = args.total_timesteps // args.batch_size  # e.g., 10M / 25600 = 390

Usage Examples

Example 1: Run PPO training from command line

python examples/baselines/ppo/ppo.py \
    --env_id="PickCube-v1" \
    --num_envs=512 \
    --total_timesteps=10000000 \
    --learning_rate=3e-4 \
    --gamma=0.8 \
    --gae_lambda=0.9 \
    --update_epochs=4 \
    --num_minibatches=32

Example 2: Compute batch sizes

args = Args(num_envs=512, num_steps=50, num_minibatches=32, total_timesteps=10_000_000)
args.batch_size = args.num_envs * args.num_steps          # 25600
args.minibatch_size = args.batch_size // args.num_minibatches  # 800
args.num_iterations = args.total_timesteps // args.batch_size  # 390

Example 3: GAE computation with final value bootstrapping

# final_values handles truncated episodes correctly:
# - If environment was NOT done at step t: final_values[t] = 0
# - If environment was truncated at step t: final_values[t] = V(final_obs)
# This ensures correct value bootstrapping at episode boundaries

final_values = torch.zeros((num_steps, num_envs), device=device)
# During rollout, when episodes end:
if "final_info" in infos:
    done_mask = infos["_final_info"]
    with torch.no_grad():
        final_values[step, torch.arange(num_envs, device=device)[done_mask]] = \
            agent.get_value(infos["final_observation"][done_mask]).view(-1)

Example 4: Monitoring training progress via logged metrics

# Key metrics logged each iteration:
logger.add_scalar("losses/value_loss", v_loss.item(), global_step)
logger.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
logger.add_scalar("losses/entropy", entropy_loss.item(), global_step)
logger.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
logger.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
logger.add_scalar("losses/explained_variance", explained_var, global_step)
logger.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

Related Pages

Principle:Haosulab_ManiSkill_PPO_Policy_Optimization