
Implementation:Hpcaitech ColossalAI DetachedPPOTrainer



Knowledge Sources
Domains: Reinforcement Learning, Distributed Training, RLHF, PPO
Last Updated: 2026-02-09 00:00 GMT

Overview

A Ray remote actor implementing the PPO (Proximal Policy Optimization) algorithm as a detached RLHF trainer for ColossalChat, with support for multiple parallelism strategies and LoRA weight updates.

Description

DetachedPPOTrainer extends DetachedTrainer and is decorated as a @ray.remote actor with concurrency groups for buffer operations, model I/O, and compute. It initializes the actor and critic models via the model_fn factory, configures PolicyLoss and ValueLoss with the eps_clip and value_clip coefficients, and selects optimizers based on the training strategy (HybridAdam for Gemini/LowLevelZero strategies, standard Adam otherwise).
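The strategy-dependent optimizer choice described above can be sketched as a small dispatch function. This is only an illustration: the strategy classes below are empty stand-ins defined for this example (not imports from coati), and the function returns the optimizer's name rather than constructing a real optimizer.

```python
# Stand-in strategy classes, defined here only for illustration (assumption).
class Strategy:
    pass


class GeminiStrategy(Strategy):
    pass


class LowLevelZeroStrategy(Strategy):
    pass


class DDPStrategy(Strategy):
    pass


def pick_optimizer_name(strategy: Strategy) -> str:
    """Return which optimizer family the trainer would use for a strategy.

    Mirrors the rule stated in the description: HybridAdam for
    Gemini/LowLevelZero strategies, standard Adam for everything else.
    """
    if isinstance(strategy, (GeminiStrategy, LowLevelZeroStrategy)):
        return "HybridAdam"
    return "Adam"


if __name__ == "__main__":
    for s in (GeminiStrategy(), LowLevelZeroStrategy(), DDPStrategy()):
        print(type(s).__name__, "->", pick_optimizer_name(s))
```

In the trainer itself, the same decision would be applied to both the actor and the critic optimizer.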

The training_step method performs a full PPO update: it computes action log probabilities, calculates the actor loss via the clipped surrogate objective, computes the critic loss with value clipping, and runs the backward passes and optimizer steps through the strategy. The _update_remote_makers method synchronizes model weights to the remote experience makers in chunks, optionally transferring only the LoRA weights for efficiency.
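To make the two loss terms concrete, here is a minimal pure-Python sketch of a clipped surrogate policy loss and a clipped value loss. It follows a common PPO formulation and is not copied from coati's PolicyLoss or ValueLoss classes, which may differ in details such as masking or reduction. The function names and the use of plain lists instead of tensors are assumptions made for this example.

```python
import math
from typing import Sequence


def clamp(x: float, lo: float, hi: float) -> float:
    """Restrict x to the closed interval [lo, hi]."""
    return max(lo, min(hi, x))


def policy_loss(
    log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    eps_clip: float = 0.2,
) -> float:
    """Clipped surrogate objective, averaged over samples and negated."""
    terms = []
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        ratio = math.exp(lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        unclipped = ratio * adv
        clipped = clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * adv
        terms.append(min(unclipped, clipped))  # pessimistic bound
    return -sum(terms) / len(terms)


def value_loss(
    values: Sequence[float],
    old_values: Sequence[float],
    returns: Sequence[float],
    value_clip: float = 0.4,
) -> float:
    """Squared error where the new value may move at most value_clip
    away from the old value; the larger of the two errors is kept."""
    terms = []
    for v, old_v, ret in zip(values, old_values, returns):
        v_clipped = old_v + clamp(v - old_v, -value_clip, value_clip)
        terms.append(max((v_clipped - ret) ** 2, (v - ret) ** 2))
    return 0.5 * sum(terms) / len(terms)


if __name__ == "__main__":
    print(policy_loss([math.log(2.0)], [0.0], [1.0]))
    print(value_loss([1.0], [0.0], [1.0]))
```

The default clip values match the eps_clip and value_clip defaults in the constructor signature. With a positive advantage, a probability ratio above 1 + eps_clip earns no extra credit, which is what keeps each policy update close to the policy that generated the experience.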

Usage

Deploy as a named Ray actor in a distributed RLHF setup. Provide factory functions for the strategy and models, along with the names of ExperienceMakerHolder actors. The trainer receives experience data via its buffer methods and runs the PPO training loop when fit is called.

Code Reference

Source Location

Signature

@ray.remote(
    concurrency_groups={
        "buffer_length": 1, "buffer_append": 1,
        "buffer_sample": 1, "model_io": 1, "compute": 1,
    }
)
class DetachedPPOTrainer(DetachedTrainer):
    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        model_fn: Callable[[], Tuple[Actor, Critic]],
        env_info: Dict[str, str] = None,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        eps_clip: float = 0.2,
        value_clip: float = 0.4,
        dataloader_pin_memory: bool = True,
        callbacks: List[TrainerCallback] = [],
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
    ) -> None: ...

    def training_step(self, experience: Experience) -> Dict[str, float]: ...
    def _update_remote_makers(self, fully_update: bool = False, **config): ...
    def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: ...
    def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None: ...

Import

from coati.ray.detached_trainer_ppo import DetachedPPOTrainer

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| experience_maker_holder_name_list | List[str] | Yes | Names of remote ExperienceMakerHolder Ray actors |
| strategy_fn | Callable[[], Strategy] | Yes | Factory function returning the training strategy |
| model_fn | Callable[[], Tuple[Actor, Critic]] | Yes | Factory function returning actor and critic models |
| env_info | Dict[str, str] | No | Environment variables for distributed setup (default None) |
| train_batch_size | int | No | Batch size for training (default 8) |
| buffer_limit | int | No | Maximum replay buffer size (default 0) |
| eps_clip | float | No | Policy loss clip coefficient (default 0.2) |
| value_clip | float | No | Value loss clip coefficient (default 0.4) |
| dataloader_pin_memory | bool | No | Pin memory for data loader (default True) |
| callbacks | List[TrainerCallback] | No | Callback instances (default []) |
| eval_performance | bool | No | Enable performance evaluation callback (default False) |
| debug | bool | No | Enable debug logging (default False) |
| update_lora_weights | bool | No | Transfer only LoRA weights to makers (default False) |
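The chunked weight transfer and the update_lora_weights option can be illustrated with a small generator that splits a state dict into bounded chunks and optionally keeps only LoRA parameters. This is a sketch under stated assumptions, not the coati implementation: the function name, the element-count budget, and the rule that LoRA parameters are those whose name contains "lora_" are all hypothetical choices made for this example, and plain lists stand in for tensors.

```python
from typing import Dict, Iterator, List


def iter_weight_chunks(
    state_dict: Dict[str, List[float]],
    chunk_numel: int,
    lora_only: bool = False,
) -> Iterator[Dict[str, List[float]]]:
    """Yield sub-dicts of state_dict holding roughly chunk_numel elements each.

    A parameter larger than the budget is sent in a chunk of its own.
    With lora_only=True, parameters without "lora_" in their name are
    skipped (naming convention assumed for illustration).
    """
    chunk: Dict[str, List[float]] = {}
    size = 0
    for name, values in state_dict.items():
        if lora_only and "lora_" not in name:
            continue
        # Flush the current chunk before it would exceed the budget.
        if chunk and size + len(values) > chunk_numel:
            yield chunk
            chunk, size = {}, 0
        chunk[name] = values
        size += len(values)
    if chunk:
        yield chunk


if __name__ == "__main__":
    weights = {
        "layer.weight": [0.0] * 4,
        "layer.lora_A": [0.0] * 2,
        "layer.lora_B": [0.0] * 2,
        "head.weight": [0.0] * 3,
    }
    for part in iter_weight_chunks(weights, chunk_numel=4, lora_only=True):
        print(sorted(part))
```

Sending only the LoRA parameters shrinks each synchronization to the small adapter matrices, at the cost of requiring the experience makers to already hold matching base weights.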

Outputs

| Name | Type | Description |
|---|---|---|
| training_step return | Dict[str, float] | Dictionary with "actor_loss" and "critic_loss" values |

Usage Examples

import ray
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer

# my_strategy, my_actor, and my_critic are placeholders for user-defined factories.
# The maker names must match the names of the ExperienceMakerHolder actors.
trainer = DetachedPPOTrainer.options(
    name="trainer_0", num_gpus=1
).remote(
    experience_maker_holder_name_list=["maker_0", "maker_1"],
    strategy_fn=lambda: my_strategy(),
    model_fn=lambda: (my_actor(), my_critic()),
    train_batch_size=8,
    eps_clip=0.2,
    value_clip=0.4,
    eval_performance=True,
)

# Sync initial model weights to experience makers
ray.get(trainer.sync_models_to_remote_makers.remote())

# Start training
ray.get(trainer.fit.remote(total_steps=1000, update_steps=100, train_epochs=2))

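# Save the actor and critic checkpoints
ray.get(trainer.strategy_save_actor.remote("./actor_checkpoint", only_rank0=True))
ray.get(trainer.strategy_save_critic.remote("./critic_checkpoint", only_rank0=True))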
