Implementation:Haosulab ManiSkill PPO Training Loop
| Field | Value |
|---|---|
| implementation_name | Haosulab_ManiSkill_PPO_Training_Loop |
| overview | Concrete PPO training loop implementing GAE, clipped surrogate loss, and minibatch optimization for ManiSkill environments |
| type | Pattern Doc |
| domains | Reinforcement_Learning, Robotics |
| last_updated | 2026-02-15 |
| related_pages | Principle:Haosulab_ManiSkill_PPO_Policy_Optimization |
Overview
Description
The PPO training loop is the core optimization routine in the ManiSkill PPO baseline. Each iteration consists of three phases: (1) rollout collection across parallel GPU environments, (2) GAE advantage computation with proper episode boundary handling, and (3) minibatch policy/value optimization with the clipped surrogate objective.
This is a Pattern Doc -- it documents the training loop from the PPO example baseline, not a library API. Users are expected to adapt this code for their specific training needs.
Usage
The training loop is the main entry point of the PPO script. It runs for num_iterations iterations (computed as total_timesteps / batch_size), where each iteration collects num_steps * num_envs transitions and performs update_epochs passes of minibatch optimization.
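The per-iteration accounting described above works out as follows (a quick sketch using the baseline defaults; variable names mirror the Args fields, and the gradient-step count ignores target_kl early stopping):

```python
# Per-iteration accounting for the training loop, using the baseline defaults.
# Names mirror the Args fields; the numbers are illustrative, not prescriptive.
num_envs, num_steps = 512, 50
total_timesteps = 10_000_000
update_epochs, num_minibatches = 4, 32

batch_size = num_envs * num_steps               # transitions collected per iteration
num_iterations = total_timesteps // batch_size  # collect/update cycles overall
grad_steps_per_iter = update_epochs * num_minibatches  # optimizer.step() calls per iteration

print(batch_size, num_iterations, grad_steps_per_iter)  # 25600 390 128
```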
Code Reference
| Field | Value |
|---|---|
| Repository | https://github.com/haosulab/ManiSkill |
| File | examples/baselines/ppo/ppo.py |
| Args dataclass | Lines 24-113 |
| Rollout collection | Lines 307-332 |
| GAE computation | Lines 334-374 |
| Minibatch optimization | Lines 384-441 |
| Logging | Lines 445-461 |
Args dataclass (key training hyperparameters):
```python
from dataclasses import dataclass

@dataclass
class Args:
    env_id: str = "PickCube-v1"
    total_timesteps: int = 10000000
    learning_rate: float = 3e-4
    num_envs: int = 512
    num_steps: int = 50               # rollout length per environment
    gamma: float = 0.8                # discount factor
    gae_lambda: float = 0.9           # GAE lambda
    num_minibatches: int = 32         # minibatches per epoch
    update_epochs: int = 4            # optimization epochs per iteration
    norm_adv: bool = True             # normalize advantages
    clip_coef: float = 0.2            # PPO clip epsilon
    clip_vloss: bool = False          # clip value loss
    ent_coef: float = 0.0             # entropy coefficient
    vf_coef: float = 0.5              # value function coefficient
    max_grad_norm: float = 0.5        # gradient clipping norm
    target_kl: float = 0.1            # KL early stopping threshold
    reward_scale: float = 1.0         # reward multiplier
    partial_reset: bool = True        # enable partial resets
    finite_horizon_gae: bool = False  # alternative GAE formulation
    # computed at runtime
    batch_size: int = 0       # num_envs * num_steps
    minibatch_size: int = 0   # batch_size // num_minibatches
    num_iterations: int = 0   # total_timesteps // batch_size
```
GAE Advantage Computation:
```python
# bootstrap value according to termination and truncation
with torch.no_grad():
    next_value = agent.get_value(next_obs).reshape(1, -1)
    advantages = torch.zeros_like(rewards).to(device)
    lastgaelam = 0
    for t in reversed(range(args.num_steps)):
        if t == args.num_steps - 1:
            next_not_done = 1.0 - next_done
            nextvalues = next_value
        else:
            next_not_done = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        real_next_values = next_not_done * nextvalues + final_values[t]
        delta = rewards[t] + args.gamma * real_next_values - values[t]
        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * next_not_done * lastgaelam
    returns = advantages + values
```
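As a sanity check on the recursion above, the sketch below re-implements it for a single environment in plain Python (the final_values truncation term is dropped, since this toy rollout has no resets). With gae_lambda = 1.0 and all value estimates zero, each advantage should collapse to the discounted sum of future rewards:

```python
# Minimal single-environment GAE check mirroring the recursion above.
# Toy data: three steps, constant reward, zero value estimates, no episode ends.
gamma, lam = 0.9, 1.0
rewards = [1.0, 1.0, 1.0]
values = [0.0, 0.0, 0.0]
dones = [0.0, 0.0, 0.0]
next_value, next_done = 0.0, 0.0
num_steps = len(rewards)

advantages = [0.0] * num_steps
lastgaelam = 0.0
for t in reversed(range(num_steps)):
    if t == num_steps - 1:
        next_not_done, nextvalues = 1.0 - next_done, next_value
    else:
        next_not_done, nextvalues = 1.0 - dones[t + 1], values[t + 1]
    delta = rewards[t] + gamma * next_not_done * nextvalues - values[t]
    advantages[t] = lastgaelam = delta + gamma * lam * next_not_done * lastgaelam

# Discounted returns: 1 + 0.9 + 0.81 = 2.71, 1 + 0.9 = 1.9, 1.0
print([round(a, 4) for a in advantages])  # [2.71, 1.9, 1.0]
```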
Minibatch Optimization:
```python
# flatten the batch
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_logprobs = logprobs.reshape(-1)
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)

# Optimizing the policy and value network
agent.train()
b_inds = np.arange(args.batch_size)
clipfracs = []
for epoch in range(args.update_epochs):
    np.random.shuffle(b_inds)
    for start in range(0, args.batch_size, args.minibatch_size):
        end = start + args.minibatch_size
        mb_inds = b_inds[start:end]
        _, newlogprob, entropy, newvalue = agent.get_action_and_value(
            b_obs[mb_inds], b_actions[mb_inds]
        )
        logratio = newlogprob - b_logprobs[mb_inds]
        ratio = logratio.exp()

        with torch.no_grad():
            old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

        if args.target_kl is not None and approx_kl > args.target_kl:
            break

        mb_advantages = b_advantages[mb_inds]
        if args.norm_adv:
            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

        # Policy loss
        pg_loss1 = -mb_advantages * ratio
        pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

        # Value loss
        newvalue = newvalue.view(-1)
        v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

        entropy_loss = entropy.mean()
        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
        optimizer.step()

    if args.target_kl is not None and approx_kl > args.target_kl:
        break
```
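The behavior of the clipped surrogate is easiest to see on scalar inputs. The sketch below mirrors the pg_loss1/pg_loss2 computation above using plain Python floats; clipped_pg_loss is a hypothetical helper for illustration, not part of the baseline script:

```python
# Scalar illustration of the PPO clipped surrogate: the per-sample loss is
# max(-A * r, -A * clip(r, 1 - eps, 1 + eps)), which caps how much a
# favorable probability ratio can improve the objective.
clip_coef = 0.2  # eps, matching args.clip_coef

def clipped_pg_loss(advantage: float, ratio: float) -> float:
    clipped = max(1 - clip_coef, min(1 + clip_coef, ratio))
    return max(-advantage * ratio, -advantage * clipped)

# Positive advantage, ratio grown past 1 + eps: loss is capped at -A * 1.2,
# so pushing the ratio higher yields no further improvement.
print(clipped_pg_loss(1.0, 1.5))   # -1.2
# Negative advantage, ratio shrunk below 1 - eps: capped at -A * 0.8.
print(clipped_pg_loss(-1.0, 0.5))  # 0.8
# Ratio inside the clip range: the unclipped surrogate is used.
print(clipped_pg_loss(1.0, 1.1))   # -1.1
```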
I/O Contract
Training loop inputs (per iteration):
| Name | Type | Shape | Description |
|---|---|---|---|
| obs (rollout buffer) | torch.Tensor | (num_steps, num_envs, obs_dim) | Collected observations |
| actions (rollout buffer) | torch.Tensor | (num_steps, num_envs, act_dim) | Collected actions |
| logprobs (rollout buffer) | torch.Tensor | (num_steps, num_envs) | Old log-probabilities |
| rewards (rollout buffer) | torch.Tensor | (num_steps, num_envs) | Scaled rewards |
| dones (rollout buffer) | torch.Tensor | (num_steps, num_envs) | Done flags |
| values (rollout buffer) | torch.Tensor | (num_steps, num_envs) | Old value estimates |
| final_values | torch.Tensor | (num_steps, num_envs) | Bootstrap values at truncation |
Training loop outputs (per iteration):
| Name | Type | Description |
|---|---|---|
| Updated agent parameters | nn.Module state_dict | Policy and value network weights after optimization |
| pg_loss | float | Policy gradient loss (logged) |
| v_loss | float | Value function loss (logged) |
| entropy_loss | float | Entropy of the policy (logged) |
| approx_kl | float | Approximate KL divergence between old and new policies (logged) |
| clipfrac | float | Fraction of samples where clipping was active (logged) |
| explained_variance | float | How well the value function explains the returns (logged) |
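Note that explained_variance is not produced inside the minibatch loop itself; in CleanRL-style PPO scripts it is typically computed once per iteration from the flattened buffers as 1 - Var[returns - values] / Var[returns]. A sketch with synthetic data (the helper name is illustrative):

```python
import numpy as np

# Explained variance of the value predictions: 1.0 means the value head
# predicts returns perfectly; <= 0 means it does no better than predicting
# the mean return. NaN when the returns have zero variance.
def explained_variance(values: np.ndarray, returns: np.ndarray) -> float:
    var_y = np.var(returns)
    return float("nan") if var_y == 0 else float(1 - np.var(returns - values) / var_y)

returns = np.array([1.0, 2.0, 3.0, 4.0])
print(explained_variance(returns.copy(), returns))           # 1.0 (perfect predictions)
print(explained_variance(np.full(4, returns.mean()), returns))  # 0.0 (mean predictor)
```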
Computed runtime values:
```python
args.batch_size = int(args.num_envs * args.num_steps)               # e.g., 512 * 50 = 25600
args.minibatch_size = int(args.batch_size // args.num_minibatches)  # e.g., 25600 / 32 = 800
args.num_iterations = args.total_timesteps // args.batch_size       # e.g., 10M / 25600 = 390
```
Usage Examples
Example 1: Run PPO training from command line
```shell
python examples/baselines/ppo/ppo.py \
    --env_id="PickCube-v1" \
    --num_envs=512 \
    --total_timesteps=10000000 \
    --learning_rate=3e-4 \
    --gamma=0.8 \
    --gae_lambda=0.9 \
    --update_epochs=4 \
    --num_minibatches=32
```
Example 2: Compute batch sizes
```python
args = Args(num_envs=512, num_steps=50, num_minibatches=32, total_timesteps=10_000_000)
args.batch_size = args.num_envs * args.num_steps               # 25600
args.minibatch_size = args.batch_size // args.num_minibatches  # 800
args.num_iterations = args.total_timesteps // args.batch_size  # 390
```
Example 3: GAE computation with final value bootstrapping
```python
# final_values handles truncated episodes correctly:
# - If the environment was NOT done at step t: final_values[t] = 0
# - If the environment was truncated at step t: final_values[t] = V(final_obs)
# This ensures correct value bootstrapping at episode boundaries.
final_values = torch.zeros((num_steps, num_envs), device=device)
# During rollout, when episodes end:
if "final_info" in infos:
    done_mask = infos["_final_info"]
    with torch.no_grad():
        final_values[step, torch.arange(num_envs, device=device)[done_mask]] = \
            agent.get_value(infos["final_observation"][done_mask]).view(-1)
```
Example 4: Monitoring training progress via logged metrics
```python
# Key metrics logged each iteration:
logger.add_scalar("losses/value_loss", v_loss.item(), global_step)
logger.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
logger.add_scalar("losses/entropy", entropy_loss.item(), global_step)
logger.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
logger.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
logger.add_scalar("losses/explained_variance", explained_var, global_step)
logger.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
```
Related Pages
- Principle:Haosulab_ManiSkill_PPO_Policy_Optimization -- The principle this implementation realizes
- Implementation:Haosulab_ManiSkill_PPO_Agent_Network -- The agent network being optimized
- Implementation:Haosulab_ManiSkill_BaseEnv_Step_Reset -- The environment step/reset calls during rollout collection
- Implementation:Haosulab_ManiSkill_PPO_Eval_Loop -- Evaluation that runs periodically during training
- Environment:Haosulab_ManiSkill_GPU_CUDA_Simulation
- Heuristic:Haosulab_ManiSkill_GPU_Memory_Buffer_Tuning
- Heuristic:Haosulab_ManiSkill_Num_Envs_Backend_Selection