Implementation:Hpcaitech ColossalAI PPOTrainer

From Leeroopedia


Knowledge Sources
Domains Reinforcement Learning, RLHF, Distributed Training
Last Updated 2026-02-09 00:00 GMT

Overview

PPOTrainer is the core trainer class implementing the Proximal Policy Optimization (PPO) algorithm for Reinforcement Learning from Human Feedback (RLHF) in the ColossalChat framework.

Description

The PPOTrainer class extends OLTrainer (Online Trainer) to orchestrate the full PPO training loop, coordinating an actor model, critic model, reward model, and initial reference model. It manages experience collection via NaiveExperienceMaker, computes policy loss with clipped surrogate objectives, value loss for the critic, and optional pre-training (ptx) loss for language modeling regularization. The trainer supports distributed training through ColossalAI's Booster abstraction with gradient accumulation, model offloading during inference, checkpoint saving, and logging via TensorBoard and Weights & Biases.
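The two clipped losses mentioned above can be illustrated on a single token with plain Python. This is a minimal sketch of the textbook PPO formulas, not the ColossalChat loss modules themselves: the function names are hypothetical, and the real trainer operates on batched tensors with an action mask. The eps_clip and value_clip arguments mirror the constructor parameters of the same name.

```python
import math


def clipped_policy_loss(log_prob, old_log_prob, advantage, eps_clip=0.2):
    """Per-token PPO clipped surrogate loss (hypothetical helper).

    The objective is min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A);
    the loss is its negation so that it can be minimized.
    """
    ratio = math.exp(log_prob - old_log_prob)
    unclipped = ratio * advantage
    clipped = max(1.0 - eps_clip, min(ratio, 1.0 + eps_clip)) * advantage
    return -min(unclipped, clipped)


def clipped_value_loss(value, old_value, target_return, value_clip=0.2):
    """Per-token clipped value loss (hypothetical helper).

    The new value estimate may move at most value_clip away from the
    old one; the larger (more pessimistic) squared error is kept.
    """
    delta = max(-value_clip, min(value - old_value, value_clip))
    clipped_value = old_value + delta
    return 0.5 * max(
        (value - target_return) ** 2,
        (clipped_value - target_return) ** 2,
    )
```

With a positive advantage, a probability ratio above 1 + eps_clip earns no extra credit, which is what keeps each policy update close to the policy that collected the experience.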

Usage

Use PPOTrainer when performing RLHF fine-tuning of a language model using the PPO algorithm. It is typically instantiated and invoked from the PPO training script (train_ppo.py) after setting up the actor, critic, reward, and reference models with their respective boosters and optimizers.

Code Reference

Source Location

Signature

class PPOTrainer(OLTrainer):
    def __init__(
        self,
        actor_booster: Booster,
        critic_booster: Booster,
        actor: PreTrainedModel,
        critic: Critic,
        reward_model: Union[RewardModel, RLVRRewardModel],
        initial_model: PreTrainedModel,
        actor_optim: Optimizer,
        critic_optim: Optimizer,
        actor_lr_scheduler: _LRScheduler,
        critic_lr_scheduler: _LRScheduler,
        tokenizer: PreTrainedTokenizerBase,
        kl_coef: float = 0.1,
        ptx_coef: float = 0.9,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        vf_coef: float = 1.0,
        value_clip: float = 0.2,
        sample_buffer: bool = False,
        dataloader_pin_memory: bool = True,
        offload_inference_models: bool = True,
        apply_loss_mask: bool = True,
        accumulation_steps: int = 1,
        save_interval: int = 0,
        save_dir: str = None,
        use_tp: bool = False,
        coordinator: DistCoordinator = None,
        callbacks: List[Callback] = [],
        **generate_kwargs,
    ) -> None

Import

from coati.trainer.ppo import PPOTrainer

I/O Contract

Inputs

Name | Type | Required | Description
actor_booster | Booster | Yes | Booster instance wrapping the actor model for distributed training
critic_booster | Booster | Yes | Booster instance wrapping the critic model for distributed training
actor | PreTrainedModel | Yes | The actor (policy) model to be trained
critic | Critic | Yes | The critic (value) model to be trained
reward_model | Union[RewardModel, RLVRRewardModel] | Yes | Reward model for scoring generated sequences
initial_model | PreTrainedModel | Yes | Reference model for computing KL divergence
actor_optim | Optimizer | Yes | Optimizer for the actor model
critic_optim | Optimizer | Yes | Optimizer for the critic model
actor_lr_scheduler | _LRScheduler | Yes | Learning rate scheduler for the actor
critic_lr_scheduler | _LRScheduler | Yes | Learning rate scheduler for the critic
tokenizer | PreTrainedTokenizerBase | Yes | Tokenizer for encoding and decoding sequences
kl_coef | float | No | KL divergence coefficient (default: 0.1)
ptx_coef | float | No | Pre-training loss coefficient (default: 0.9)
train_batch_size | int | No | Training batch size (default: 8)
eps_clip | float | No | PPO clipping epsilon for policy loss (default: 0.2)
vf_coef | float | No | Value function loss coefficient (default: 1.0)
value_clip | float | No | Value function clipping range (default: 0.2)
offload_inference_models | bool | No | Whether to offload inference models to CPU during training (default: True)
accumulation_steps | int | No | Number of gradient accumulation steps (default: 1)
save_interval | int | No | Steps between checkpoint saves (default: 0)
save_dir | str | No | Directory for saving checkpoints
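To make the interaction between train_batch_size and accumulation_steps concrete, the sketch below computes the number of samples that contribute to one optimizer update and shows which micro-steps would trigger that update. It assumes the conventional gradient accumulation scheme; the helper names and the world_size argument are hypothetical, and the exact step-boundary logic inside the trainer is not confirmed by this page.

```python
def effective_batch_size(train_batch_size, accumulation_steps, world_size=1):
    """Samples contributing to one optimizer update (hypothetical helper).

    world_size is an assumed number of data-parallel processes.
    """
    return train_batch_size * accumulation_steps * world_size


def is_optimizer_step(micro_step, accumulation_steps):
    """True when gradients accumulated so far would be applied.

    micro_step is a zero-based counter of backward passes.
    """
    return (micro_step + 1) % accumulation_steps == 0


# Values taken from the usage example on this page: 16 samples x 8 steps.
per_process = effective_batch_size(16, 8)
update_points = [s for s in range(16) if is_optimizer_step(s, 8)]
```

Under these assumptions, per_process is 128 and the optimizer would step after micro-steps 7 and 15.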

Outputs

Name | Type | Description
None | None | The trainer modifies models in-place and saves checkpoints to disk

Usage Examples

from coati.trainer import PPOTrainer

trainer = PPOTrainer(
    actor_booster=actor_booster,
    critic_booster=critic_booster,
    actor=actor,
    critic=critic,
    reward_model=reward_model,
    initial_model=ref_model,
    actor_optim=actor_optim,
    critic_optim=critic_optim,
    actor_lr_scheduler=actor_lr_scheduler,
    critic_lr_scheduler=critic_lr_scheduler,
    tokenizer=tokenizer,
    kl_coef=0.1,
    ptx_coef=0.0,
    train_batch_size=16,
    accumulation_steps=8,
    save_dir="./checkpoints",
    coordinator=coordinator,
)

trainer.fit(
    num_episodes=10,
    num_collect_steps=2,
    num_update_steps=5,
    prompt_dataloader=train_prompt_dataloader,
    pretrain_dataloader=train_pretrain_dataloader,
    log_dir="./logs",
    use_wandb=False,
)
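The three counts passed to fit control how collection and update phases alternate. The sketch below lists the order of phases that such a loop would produce; it is an assumption about how an online trainer of this kind is typically structured (collect, then update, then clear the buffer, once per episode), not a copy of the OLTrainer.fit source. The function name and event tuples are hypothetical, and checkpointing is assumed to be keyed on the episode count.

```python
def fit_schedule(num_episodes, num_collect_steps, num_update_steps, save_interval=0):
    """Return the assumed order of phases as a list of event tuples."""
    events = []
    for episode in range(num_episodes):
        # Collection phase: generate experience and fill the buffer.
        for step in range(num_collect_steps):
            events.append(("collect", episode, step))
        # Update phase: train actor and critic on buffered experience.
        for step in range(num_update_steps):
            events.append(("update", episode, step))
        # Experience from the old policy is discarded before the next episode.
        events.append(("clear_buffer", episode))
        # A save_interval of 0 disables periodic checkpoints in this sketch.
        if save_interval > 0 and (episode + 1) % save_interval == 0:
            events.append(("save", episode + 1))
    return events


# With the arguments from the example above (10 episodes, 2 collect, 5 update):
schedule = fit_schedule(10, 2, 5)
num_collects = sum(1 for e in schedule if e[0] == "collect")
num_updates = sum(1 for e in schedule if e[0] == "update")
```

Under this assumed structure, the example call would run 20 collection steps and 50 update steps in total.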

Key Methods

_before_fit

Initializes the prompt and pretrain dataloaders, and sets up the TensorBoard writer and optional Weights & Biases logging.

_make_experience

Generates experience tuples from the prompt dataloader using NaiveExperienceMaker. When inference-model offloading is enabled, the inference-only models are moved back onto the GPU before generation.
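The role of kl_coef during experience collection can be sketched as a KL-penalized reward. The version below uses a common RLHF formulation: a per-token penalty proportional to the log-probability gap between the actor and the reference model, with the reward model's sequence score added on the final token. The function name is hypothetical, and NaiveExperienceMaker may aggregate the penalty differently, so this should be read as an illustration of the idea only.

```python
def kl_penalized_rewards(score, log_probs, ref_log_probs, kl_coef=0.1):
    """Per-token rewards with a KL penalty (hypothetical helper).

    score         -- scalar score from the reward model for the whole sequence
    log_probs     -- actor log-probabilities of the generated tokens (non-empty)
    ref_log_probs -- reference-model log-probabilities of the same tokens
    """
    # Simple per-token KL estimate: log pi(a|s) - log pi_ref(a|s).
    kl = [lp - ref for lp, ref in zip(log_probs, ref_log_probs)]
    rewards = [-kl_coef * k for k in kl]
    # The sequence-level score is credited to the last generated token.
    rewards[-1] += score
    return rewards
```

A larger kl_coef pulls the actor back toward the reference model more strongly, at the cost of slower reward improvement.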

_training_step

Performs one PPO training step: computes actor policy loss with clipped surrogate objective, optional ptx language modeling loss, and critic value loss. Logs metrics via AccumulativeMeanMeter.
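How the three loss terms might be weighted is sketched below. The weighting shown, with the policy loss scaled by 1 - ptx_coef and the language modeling loss by ptx_coef, is an assumption made for illustration and is not confirmed by this page; the function name is hypothetical. The critic loss is simply scaled by vf_coef.

```python
def combine_losses(policy_loss, value_loss, ptx_loss=0.0, ptx_coef=0.9, vf_coef=1.0):
    """Weight the PPO loss terms (hypothetical helper, assumed weighting).

    Returns (actor_loss, critic_loss); the two are backpropagated
    through separate models and optimizers.
    """
    # Assumed convex mix of the RL objective and the pre-training objective.
    actor_loss = (1.0 - ptx_coef) * policy_loss + ptx_coef * ptx_loss
    critic_loss = vf_coef * value_loss
    return actor_loss, critic_loss
```

Under this assumption, setting ptx_coef to 0.0, as in the usage example, leaves the actor trained on the policy loss alone.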

_learn

Executes the learning phase over buffered experiences, supporting both sampled and sequential iteration with distributed sampler coordination.
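The difference between sampled and sequential iteration can be sketched with a plain list standing in for the experience buffer. This is a single-process illustration with hypothetical names; in a distributed run the sequential path would additionally shard the buffer across processes through a distributed sampler, which is omitted here.

```python
import random


def learn_batches(buffer, batch_size, sample_buffer=False, rng=None):
    """Yield the batches consumed by one learning phase (hypothetical helper).

    sample_buffer=True  -- draw a single random batch from the buffer
    sample_buffer=False -- walk the whole buffer once, in order
    """
    if sample_buffer:
        rng = rng or random.Random(0)
        yield rng.sample(buffer, min(batch_size, len(buffer)))
    else:
        for start in range(0, len(buffer), batch_size):
            yield buffer[start:start + batch_size]


# Sequential iteration over five stored experiences, two per batch.
sequential = list(learn_batches(list(range(5)), 2))
```

Sequential iteration uses every stored experience exactly once per pass, while sampling trades full coverage for a fixed amount of work per update step.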

_save_checkpoint

Saves the actor and critic model checkpoints together with the current optimizer and learning rate scheduler states.
