
Implementation:Hpcaitech ColossalAI Train PPO Script

From Leeroopedia


Knowledge Sources
Domains Reinforcement Learning, RLHF, PPO, Distributed Training
Last Updated 2026-02-09 00:00 GMT

Overview

train_ppo.py is the main training script for Proximal Policy Optimization (PPO) based RLHF, orchestrating the setup of actor, critic, reward, and reference models with ColossalAI's distributed training plugins.
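In the usual RLHF formulation of PPO, the reference model keeps the actor close to its starting point. Each generated token is penalized by the gap between the actor's and the reference model's log-probabilities, scaled by a KL coefficient (here --kl_coef, default 0.1), and the reward model's scalar score is added at the end of the response. The following is a minimal plain-Python sketch of that reward shaping, assuming this common InstructGPT-style recipe. The helper name kl_penalized_rewards is hypothetical, and ColossalAI's trainer may compute the penalty differently.

```python
from typing import List

def kl_penalized_rewards(
    reward: float,
    actor_logprobs: List[float],
    ref_logprobs: List[float],
    kl_coef: float = 0.1,
) -> List[float]:
    """Hypothetical helper: per-token rewards with a KL penalty against the
    reference model, plus the sequence-level reward on the final token."""
    if len(actor_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must have the same length")
    if not actor_logprobs:
        raise ValueError("need at least one generated token")
    # Per-token KL estimate: log pi_actor(a_t) - log pi_ref(a_t)
    rewards = [-kl_coef * (a - r) for a, r in zip(actor_logprobs, ref_logprobs)]
    # The reward model (or rule-based function) scores the whole response once
    rewards[-1] += reward
    return rewards

# Two generated tokens; the actor is more confident than the reference on the first
shaped = kl_penalized_rewards(1.0, [-0.5, -1.0], [-0.7, -1.0], kl_coef=0.1)
```

Tokens where the actor drifts toward higher probability than the reference receive a small negative reward. This discourages the policy from exploiting the reward signal by moving far from the reference distribution.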

Description

This script implements the end-to-end PPO training workflow for reinforcement learning from human feedback. It initializes four models (actor, critic, reward model, and reference model), configures separate Booster instances for the actor and the critic, sets up optimizers with cosine annealing warmup scheduling, loads the prompt dataset(s) and optional pre-training dataset(s), and creates a PPOTrainer instance. The script supports both neural reward models and rule-based reward functions (RLVR, reinforcement learning with verifiable rewards) for tasks such as GSM8K math problem solving. It also configures the conversation template, including response format tags for structured generation, and supports LoRA adaptation for memory-efficient training.
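The optimizers use a cosine annealing schedule with warmup. The following plain-Python sketch shows the shape of such a schedule, assuming a linear warmup up to the base learning rate followed by cosine decay to eta_min. ColossalAI provides its own scheduler for this (CosineAnnealingWarmupLR in colossalai.nn.lr_scheduler), and its exact boundary handling may differ from this sketch.

```python
import math

def cosine_warmup_lr(step: int, base_lr: float, total_steps: int,
                     warmup_steps: int, eta_min: float = 0.0) -> float:
    """Learning rate at `step`: linear warmup, then cosine decay to eta_min."""
    if total_steps <= warmup_steps:
        raise ValueError("total_steps must exceed warmup_steps")
    if step < warmup_steps:
        # Linear ramp from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))

# Schedule for the documented default actor learning rate (9e-6)
for s in (0, 9, 10, 55, 100):
    print(s, cosine_warmup_lr(s, 9e-6, total_steps=100, warmup_steps=10))
```

The warmup phase avoids large early updates while the critic's value estimates are still poor. The cosine tail then lowers the step size smoothly as training converges.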

Usage

Use this script for PPO-based RLHF training to align a language model with human preferences or rule-based rewards. It supports the Gemini and ZeRO-2 plugins (including their gemini_auto and zero2_cpu variants) as well as the 3D hybrid parallelism plugin. Launch it with torchrun, setting --nproc_per_node to the desired number of processes.
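During the update steps, PPO trains the actor with a clipped surrogate objective. The following pure-Python sketch shows the standard per-token form of that objective. It assumes precomputed advantages and a clip range of 0.2, and it illustrates the technique only; it does not reproduce the trainer's own loss implementation.

```python
import math
from typing import List

def ppo_clip_loss(new_logprobs: List[float], old_logprobs: List[float],
                  advantages: List[float], clip_eps: float = 0.2) -> float:
    """Clipped surrogate loss: -mean(min(r*A, clip(r, 1-eps, 1+eps)*A))."""
    if not (len(new_logprobs) == len(old_logprobs) == len(advantages) > 0):
        raise ValueError("inputs must be non-empty and of equal length")
    total = 0.0
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a) / pi_old(a)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Take the pessimistic (smaller) of the unclipped and clipped terms
        total += min(ratio * adv, clipped * adv)
    # PPO maximizes the surrogate, so the loss to minimize is its negative mean
    return -total / len(advantages)

# A token whose probability doubled: with a positive advantage, ratio 2.0 is clipped to 1.2
loss = ppo_clip_loss([math.log(2.0)], [0.0], [1.0])
```

Clipping removes the incentive to push the probability ratio far outside [1 - eps, 1 + eps] in a single update. This is why several update steps (--num_update_steps) can safely reuse one batch of collected experience.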

Code Reference

Source Location

Signature

def train(args) -> None

Import

# This is a standalone training script, typically run directly:
# torchrun --nproc_per_node=<N> train_ppo.py --pretrain <model_path> --prompt_dataset <data_path>

I/O Contract

Inputs

Name Type Required Description
--pretrain str Yes Path to the pretrained actor model
--rm_pretrain str Yes Path to the pretrained reward/critic base model
--prompt_dataset str (nargs=+) Yes Paths to prompt dataset(s) for experience collection
--conversation_template_config str Yes Path to conversation template configuration JSON
--plugin str No Plugin: gemini, gemini_auto, zero2, zero2_cpu, 3d (default: gemini)
--ptx_dataset str (nargs=+) No Paths to pre-training dataset(s) for language modeling regularization
--no_neural_reward_model flag No Use rule-based reward functions instead of neural reward model
--reward_functions str (nargs=+) No Names of reward functions: gsm8k_reward_fn, math_competition_reward_fn
--num_episodes int No Number of PPO training episodes (default: 1)
--num_collect_steps int No Steps for experience collection per episode (default: 2)
--num_update_steps int No Steps for model update per episode (default: 5)
--train_batch_size int No Training batch size (default: 16)
--experience_batch_size int No Experience collection batch size (default: 16)
--lr float No Actor learning rate (default: 9e-6)
--critic_lr float No Critic learning rate (default: 9e-6)
--kl_coef float No KL divergence penalty coefficient (default: 0.1)
--ptx_coef float No Pre-training loss coefficient (default: 0.0)
--max_length int No Maximum total sequence length (default: 2048)
--max_seq_len int No Maximum new tokens to generate (default: 256)
--checkpoint_path str No Actor checkpoint path for resumption
--critic_checkpoint_path str No Critic checkpoint path for resumption
--rm_checkpoint_path str No Reward model checkpoint path
--lora_config str No LoRA configuration file path
--tp int No Tensor parallelism size (default: 1)
--pp int No Pipeline parallelism size (default: 1)
--sp int No Sequence parallelism size (default: 1)
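The following argparse sketch mirrors a subset of the flags above, using only the types, defaults, and plugin choices listed in the table. It is a partial, hypothetical reconstruction, not the script's actual parser. Omitted flags (such as the checkpoint, LoRA, and pre-training dataset paths) would be added the same way.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Hypothetical subset mirroring the Inputs table; not the script's real parser.
    p = argparse.ArgumentParser(description="PPO training (sketch)")
    # Required inputs
    p.add_argument("--pretrain", type=str, required=True)
    p.add_argument("--rm_pretrain", type=str, required=True)
    p.add_argument("--prompt_dataset", type=str, nargs="+", required=True)
    p.add_argument("--conversation_template_config", type=str, required=True)
    # Parallelism plugin and reward source
    p.add_argument("--plugin", type=str, default="gemini",
                   choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"])
    p.add_argument("--no_neural_reward_model", action="store_true")
    p.add_argument("--reward_functions", type=str, nargs="+")
    # PPO loop sizes and optimization hyperparameters (defaults from the table)
    p.add_argument("--num_episodes", type=int, default=1)
    p.add_argument("--num_collect_steps", type=int, default=2)
    p.add_argument("--num_update_steps", type=int, default=5)
    p.add_argument("--train_batch_size", type=int, default=16)
    p.add_argument("--experience_batch_size", type=int, default=16)
    p.add_argument("--lr", type=float, default=9e-6)
    p.add_argument("--critic_lr", type=float, default=9e-6)
    p.add_argument("--kl_coef", type=float, default=0.1)
    p.add_argument("--ptx_coef", type=float, default=0.0)
    p.add_argument("--max_length", type=int, default=2048)
    p.add_argument("--max_seq_len", type=int, default=256)
    # Parallelism degrees used by the 3D plugin
    p.add_argument("--tp", type=int, default=1)
    p.add_argument("--pp", type=int, default=1)
    p.add_argument("--sp", type=int, default=1)
    return p

# Parse an example command line with two prompt datasets
args = build_parser().parse_args([
    "--pretrain", "model", "--rm_pretrain", "rm",
    "--prompt_dataset", "a", "b",
    "--conversation_template_config", "template.json",
    "--plugin", "zero2",
])
```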

Outputs

Name Type Description
actor_checkpoint directory Actor model checkpoint saved to --save_path/actor/modeling
critic_checkpoint directory Critic model checkpoint saved to --save_path/critic/modeling

Usage Examples

# PPO training with neural reward model:
# torchrun --nproc_per_node=4 train_ppo.py \
#     --pretrain meta-llama/Llama-2-7b \
#     --rm_pretrain meta-llama/Llama-2-7b \
#     --prompt_dataset ./prompts \
#     --conversation_template_config ./template.json \
#     --plugin zero2 \
#     --rm_checkpoint_path ./rm_checkpoint \
#     --num_episodes 10 \
#     --save_path ./ppo_checkpoint

# PPO training with rule-based rewards (RLVR) for math:
# torchrun --nproc_per_node=4 train_ppo.py \
#     --pretrain meta-llama/Llama-2-7b \
#     --rm_pretrain meta-llama/Llama-2-7b \
#     --prompt_dataset ./math_prompts \
#     --conversation_template_config ./template.json \
#     --no_neural_reward_model \
#     --reward_functions gsm8k_reward_fn \
#     --plugin zero2
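With --no_neural_reward_model, the scalar rewards come from rule-based functions such as gsm8k_reward_fn instead of a learned model. The following sketch illustrates this idea for GSM8K-style answers. It extracts the text inside <answer>...</answer> (the tag format described under Key Features) and compares the last number found there with the reference answer. The function names, signatures, and 0/1 scoring are assumptions for illustration; the real gsm8k_reward_fn may parse and score responses differently.

```python
import re
from typing import Optional

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

def extract_final_number(response: str) -> Optional[float]:
    """Return the last number inside the first <answer>...</answer> span, if any."""
    match = ANSWER_RE.search(response)
    if match is None:
        return None  # response did not follow the required tag format
    numbers = NUMBER_RE.findall(match.group(1).replace(",", ""))
    return float(numbers[-1]) if numbers else None

def toy_math_reward(response: str, gt_answer: str) -> float:
    """Hypothetical verifiable reward: 1.0 for a correct final answer, else 0.0."""
    predicted = extract_final_number(response)
    if predicted is None:
        return 0.0
    expected = float(gt_answer.replace(",", ""))
    return 1.0 if abs(predicted - expected) < 1e-6 else 0.0

score = toy_math_reward("<think>6 * 7</think><answer>The answer is 42</answer>", "42")
```

Because the reward depends only on checkable output, an untagged response scores zero even when it contains the right number. This pushes the policy toward the structured format as well as the correct answer.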

Key Features

  • Four-Model Architecture - Manages the actor, critic, reward model, and reference model, with separate Booster instances for the actor and the critic
  • RLVR Support - Supports rule-based reward functions via RLVRRewardModel for verifiable tasks (e.g., math)
  • Response Format Tags - Configurable think/answer tags for structured generation (e.g., <think>...</think><answer>...</answer>)
  • Conversation Templates - Loads conversation template config with customizable stop tokens
  • Custom Policy for Critic - Uses get_autopolicy for the critic model when using 3D parallelism
  • Separate Checkpoints - Independent checkpoint loading and saving for actor, critic, and reward models
