Implementation:Hpcaitech ColossalAI Train PPO Script
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement Learning, RLHF, PPO, Distributed Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
train_ppo.py is the main training script for Proximal Policy Optimization (PPO) based RLHF, orchestrating the setup of actor, critic, reward, and reference models with ColossalAI's distributed training plugins.
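The policy update that gives PPO its name is usually written as a clipped surrogate objective. The sketch below computes that loss in plain Python over a flat list of tokens. It is the generic textbook formulation, not ColossalChat's own loss code, and the function name and the `clip_eps=0.2` default are assumptions for illustration.

```python
import math


def ppo_clipped_policy_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Mean clipped PPO surrogate loss over a flat list of tokens (generic sketch).

    log_probs:     log pi_theta(a_t | s_t) under the current actor
    old_log_probs: log pi_old(a_t | s_t) recorded during experience collection
    advantages:    advantage estimates A_t (typically derived from the critic)
    clip_eps:      assumed clipping range, not taken from train_ppo.py
    """
    losses = []
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        ratio = math.exp(lp - old_lp)  # importance ratio r_t = pi_theta / pi_old
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # PPO maximizes min(r_t * A_t, clip(r_t) * A_t); negate it to get a loss.
        losses.append(-min(ratio * adv, clipped * adv))
    return sum(losses) / len(losses)
```

Clipping caps how far one update can move the policy: with a positive advantage, a probability ratio of 2 only earns the credit of a ratio of 1.2 when `clip_eps=0.2`.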
Description
This script implements the end-to-end PPO training workflow for reinforcement learning from human feedback. It initializes four models (actor, critic, reward model, and reference model), configures separate Booster instances for the actor and the critic, sets up their optimizers with cosine-annealing learning-rate schedules that include a warmup phase, loads the prompt datasets and any optional pre-training datasets, and creates a PPOTrainer instance. The script supports both neural reward models and rule-based reward functions (RLVR, reinforcement learning with verifiable rewards) for tasks such as GSM8K math problem solving. It also configures the conversation template, including response format tags for structured generation, and supports LoRA adaptation for memory-efficient training.
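The warmup-then-cosine learning-rate schedule described above can be sketched as a pure function of the step index. ColossalAI ships its own scheduler for this, so treat the helper name, its parameters (`warmup_steps`, `total_steps`, `eta_min`), and the exact linear-warmup formula as assumptions rather than the library's behavior.

```python
import math


def cosine_annealing_warmup_lr(step, base_lr, warmup_steps, total_steps, eta_min=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay to eta_min (sketch)."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp from base_lr / warmup_steps up to base_lr at the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    # Half-cosine from base_lr (progress 0) down to eta_min (progress 1).
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))
```

The actor and critic would each get their own schedule in this picture, since the table lists separate `--lr` and `--critic_lr` values.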
Usage
Use this script for PPO-based RLHF training to align a language model with human preferences or rule-based rewards. It supports the Gemini plugin (gemini, gemini_auto), the ZeRO-2 plugin (zero2, zero2_cpu), and the 3D hybrid parallelism plugin (3d). Launch it via torchrun with the desired number of processes.
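One way to picture how the `--plugin` value relates to ColossalAI's plugin classes is the hypothetical helper below, which returns a class name and a few constructor arguments instead of building the plugin. The mapping is an assumption based on the plugin names in the Inputs table; the real script constructs `GeminiPlugin`, `LowLevelZeroPlugin`, or `HybridParallelPlugin` objects directly and passes more arguments (precision, gradient clipping, and so on) than shown here.

```python
def describe_plugin(plugin, tp=1, pp=1, sp=1):
    """Hypothetical helper: map a --plugin value to (class name, kwargs).

    Assumed mapping for illustration only; it does not import or build
    any ColossalAI objects, and omits most real constructor arguments.
    """
    if plugin == "gemini":
        return "GeminiPlugin", {"placement_policy": "static"}
    if plugin == "gemini_auto":
        return "GeminiPlugin", {"placement_policy": "auto"}
    if plugin == "zero2":
        return "LowLevelZeroPlugin", {"stage": 2}
    if plugin == "zero2_cpu":
        # Same ZeRO stage, with optimizer states offloaded to CPU memory.
        return "LowLevelZeroPlugin", {"stage": 2, "cpu_offload": True}
    if plugin == "3d":
        # Only the 3D hybrid plugin consumes the --tp/--pp/--sp sizes.
        return "HybridParallelPlugin", {"tp_size": tp, "pp_size": pp, "sp_size": sp}
    raise ValueError(f"Unsupported plugin: {plugin}")
```

In this sketch, `--tp`, `--pp`, and `--sp` only take effect with `--plugin 3d`.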
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/examples/training_scripts/train_ppo.py
- Lines: 1-559
Signature
def train(args) -> None
Import
# This is a standalone training script, typically run directly:
# torchrun --nproc_per_node=<N> train_ppo.py --pretrain <model_path> --prompt_dataset <data_path>
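The torchrun command above passes flags that end up in the `args` namespace given to `train(args)`. A partial argparse reconstruction of that interface, using the names and defaults from the Inputs table in the I/O Contract section, might look like the following. It is a hypothetical sketch rather than the script's actual parser: many flags are omitted, `--save_path` has an assumed default of `None`, and the `required=True` settings simply mirror the table's Required column.

```python
import argparse


def build_parser():
    """Partial, hypothetical reconstruction of the train_ppo.py command line."""
    p = argparse.ArgumentParser(description="PPO training (sketch)")
    # Inputs marked as required in the table.
    p.add_argument("--pretrain", type=str, required=True)
    p.add_argument("--rm_pretrain", type=str, required=True)
    p.add_argument("--prompt_dataset", type=str, nargs="+", required=True)
    p.add_argument("--conversation_template_config", type=str, required=True)
    # Parallelism plugin and its sizes.
    p.add_argument("--plugin", type=str, default="gemini",
                   choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"])
    p.add_argument("--tp", type=int, default=1)
    p.add_argument("--pp", type=int, default=1)
    p.add_argument("--sp", type=int, default=1)
    # Reward source: neural reward model by default, rule-based functions with the flag.
    p.add_argument("--no_neural_reward_model", action="store_true")
    p.add_argument("--reward_functions", type=str, nargs="+", default=None)
    # PPO schedule and optimization hyperparameters (defaults from the table).
    p.add_argument("--num_episodes", type=int, default=1)
    p.add_argument("--num_collect_steps", type=int, default=2)
    p.add_argument("--num_update_steps", type=int, default=5)
    p.add_argument("--train_batch_size", type=int, default=16)
    p.add_argument("--experience_batch_size", type=int, default=16)
    p.add_argument("--lr", type=float, default=9e-6)
    p.add_argument("--critic_lr", type=float, default=9e-6)
    p.add_argument("--kl_coef", type=float, default=0.1)
    p.add_argument("--ptx_coef", type=float, default=0.0)
    p.add_argument("--max_length", type=int, default=2048)
    p.add_argument("--max_seq_len", type=int, default=256)
    # Output directory; the default here is an assumption.
    p.add_argument("--save_path", type=str, default=None)
    return p
```

Note that `nargs="+"` lets `--prompt_dataset` accept several paths in one flag, which matches the `str (nargs=+)` type in the table.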
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| --pretrain | str | Yes | Path to the pretrained actor model |
| --rm_pretrain | str | Yes | Path to the pretrained reward/critic base model |
| --prompt_dataset | str (nargs=+) | Yes | Paths to prompt dataset(s) for experience collection |
| --conversation_template_config | str | Yes | Path to conversation template configuration JSON |
| --plugin | str | No | Plugin: gemini, gemini_auto, zero2, zero2_cpu, 3d (default: gemini) |
| --ptx_dataset | str (nargs=+) | No | Paths to pre-training dataset(s) for language modeling regularization |
| --no_neural_reward_model | flag | No | Use rule-based reward functions instead of neural reward model |
| --reward_functions | str (nargs=+) | No | Names of rule-based reward functions to use with --no_neural_reward_model: gsm8k_reward_fn, math_competition_reward_fn |
| --num_episodes | int | No | Number of PPO training episodes (default: 1) |
| --num_collect_steps | int | No | Steps for experience collection per episode (default: 2) |
| --num_update_steps | int | No | Steps for model update per episode (default: 5) |
| --train_batch_size | int | No | Training batch size (default: 16) |
| --experience_batch_size | int | No | Experience collection batch size (default: 16) |
| --lr | float | No | Actor learning rate (default: 9e-6) |
| --critic_lr | float | No | Critic learning rate (default: 9e-6) |
| --kl_coef | float | No | KL divergence penalty coefficient (default: 0.1) |
| --ptx_coef | float | No | Pre-training loss coefficient (default: 0.0) |
| --max_length | int | No | Maximum total sequence length (default: 2048) |
| --max_seq_len | int | No | Maximum new tokens to generate (default: 256) |
| --checkpoint_path | str | No | Actor checkpoint path for resumption |
| --critic_checkpoint_path | str | No | Critic checkpoint path for resumption |
| --rm_checkpoint_path | str | No | Reward model checkpoint path |
| --lora_config | str | No | LoRA configuration file path |
| --tp | int | No | Tensor parallelism size (default: 1) |
| --pp | int | No | Pipeline parallelism size (default: 1) |
| --sp | int | No | Sequence parallelism size (default: 1) |
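`--kl_coef` weights a penalty that keeps the actor close to the frozen reference model. A common RLHF way to apply it, assumed here rather than taken from the script, is to shape per-token rewards: each generated token receives `-kl_coef * (log pi_actor - log pi_ref)`, and the scalar score from the reward model (or a rule-based reward function) is added on the final token. The helper name is hypothetical.

```python
def kl_penalized_rewards(reward, actor_log_probs, ref_log_probs, kl_coef=0.1):
    """Hypothetical per-token reward shaping for one generated response.

    reward:          scalar score for the whole response (reward model or rule-based)
    actor_log_probs: log-probabilities of the sampled tokens under the actor
    ref_log_probs:   log-probabilities of the same tokens under the reference model
    """
    if not actor_log_probs:
        raise ValueError("response must contain at least one token")
    rewards = []
    for lp_actor, lp_ref in zip(actor_log_probs, ref_log_probs):
        # Simple per-token KL estimate: log pi_actor - log pi_ref.
        rewards.append(-kl_coef * (lp_actor - lp_ref))
    # The sequence-level score is credited to the last generated token.
    rewards[-1] += reward
    return rewards
```

A larger `--kl_coef` makes drifting away from the reference model more expensive, trading reward maximization for stability.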
Outputs
| Name | Type | Description |
|---|---|---|
| actor_checkpoint | directory | Actor model checkpoint saved to --save_path/actor/modeling |
| critic_checkpoint | directory | Critic model checkpoint saved to --save_path/critic/modeling |
Usage Examples
# PPO training with neural reward model:
# torchrun --nproc_per_node=4 train_ppo.py \
# --pretrain meta-llama/Llama-2-7b-hf \
# --rm_pretrain meta-llama/Llama-2-7b-hf \
# --prompt_dataset ./prompts \
# --conversation_template_config ./template.json \
# --plugin zero2 \
# --rm_checkpoint_path ./rm_checkpoint \
# --num_episodes 10 \
# --save_path ./ppo_checkpoint
# PPO training with rule-based rewards (RLVR) for math:
# torchrun --nproc_per_node=4 train_ppo.py \
# --pretrain meta-llama/Llama-2-7b-hf \
# --rm_pretrain meta-llama/Llama-2-7b-hf \
# --prompt_dataset ./math_prompts \
# --conversation_template_config ./template.json \
# --no_neural_reward_model \
# --reward_functions gsm8k_reward_fn \
# --plugin zero2
Key Features
- Four-Model Architecture - Manages actor, critic, reward model, and reference model with separate Booster instances
- RLVR Support - Supports rule-based reward functions via RLVRRewardModel for verifiable tasks (e.g., math)
- Response Format Tags - Configurable think/answer tags for structured generation (e.g., <think>...</think><answer>...</answer>)
- Conversation Templates - Loads conversation template config with customizable stop tokens
- Custom Policy for Critic - Uses get_autopolicy for the critic model when using 3D parallelism
- Separate Checkpoints - Independent checkpoint loading and saving for actor, critic, and reward models
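To make the RLVR and response-format features concrete, here is a hypothetical rule-based reward in the spirit of `gsm8k_reward_fn`. It is not the repository's implementation: the function name, the regex-based parsing, and the 0.0 / 0.5 / 1.0 score levels are all assumptions. It rewards responses that follow the `<think>...</think><answer>...</answer>` format and gives full credit only when the last number inside `<answer>` matches the gold answer.

```python
import re


def hypothetical_gsm8k_reward(response: str, gold_answer: str) -> float:
    """Hypothetical rule-based reward (not the actual gsm8k_reward_fn).

    Score levels are assumptions: 0.0 for a malformed response, 0.5 for the
    correct format with a wrong answer, 1.0 for the correct format and answer.
    """
    match = re.search(r"<think>.*?</think>\s*<answer>(.*?)</answer>", response, re.DOTALL)
    if match is None:
        return 0.0  # format tags missing or out of order
    # Take the last number in the answer span, ignoring thousands separators.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", match.group(1).replace(",", ""))
    if numbers and float(numbers[-1]) == float(gold_answer.replace(",", "")):
        return 1.0
    return 0.5
```

Giving partial credit for correct formatting is one common shaping choice; a stricter variant would return 0.0 for any wrong answer.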