Implementation:Hpcaitech ColossalAI DetachedPPOTrainer
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement Learning, Distributed Training, RLHF, PPO |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A Ray remote actor implementing the PPO (Proximal Policy Optimization) algorithm as a detached RLHF trainer for ColossalChat, with support for multiple parallelism strategies and LoRA weight updates.
Description
DetachedPPOTrainer extends DetachedTrainer and is decorated as a @ray.remote actor with concurrency groups for buffer operations, model I/O, and compute. It builds the actor and critic models from the model_fn factory, sets up PolicyLoss and ValueLoss with the eps_clip and value_clip coefficients, and selects optimizers based on the training strategy (HybridAdam for Gemini/LowLevelZero strategies, standard Adam otherwise).
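The roles of eps_clip and value_clip can be illustrated with a minimal, dependency-free sketch of the standard PPO clipped objectives. This is not the coati PolicyLoss or ValueLoss code: the function names, the per-sample list inputs, and the 0.5 scaling of the value loss are assumptions chosen for illustration.

```python
import math
from typing import Sequence


def clipped_policy_loss(
    log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    eps_clip: float = 0.2,
) -> float:
    """Mean of -min(r * A, clip(r, 1 - eps, 1 + eps) * A) over samples."""
    total = 0.0
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        # Probability ratio between the current and the old policy.
        ratio = math.exp(lp - old_lp)
        clipped_ratio = min(max(ratio, 1.0 - eps_clip), 1.0 + eps_clip)
        # Take the pessimistic (smaller) surrogate, negated to form a loss.
        total += -min(ratio * adv, clipped_ratio * adv)
    return total / len(log_probs)


def clipped_value_loss(
    values: Sequence[float],
    old_values: Sequence[float],
    returns: Sequence[float],
    value_clip: float = 0.4,
) -> float:
    """Half the mean of max((v - R)^2, (v_clipped - R)^2) over samples."""
    total = 0.0
    for v, old_v, ret in zip(values, old_values, returns):
        # Limit how far the new value estimate may move from the old one.
        delta = min(max(v - old_v, -value_clip), value_clip)
        v_clipped = old_v + delta
        total += max((v - ret) ** 2, (v_clipped - ret) ** 2)
    return 0.5 * total / len(values)
```

With a positive advantage, a ratio above 1 + eps_clip contributes no more than (1 + eps_clip) times the advantage, which removes the incentive to push the policy further in a single update.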
The training_step method performs a full PPO update: it computes action log probabilities, calculates the actor loss with the clipped surrogate objective, calculates the critic loss with value clipping, and runs the backward passes and optimizer steps through the strategy. The _update_remote_makers method synchronizes model weights to the remote experience makers in chunks, optionally transferring only the LoRA weights for efficiency.
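The chunked, optionally LoRA-only weight transfer can be pictured with the following sketch. It is a simplified stand-in, not the trainer's actual code: the helper names, the "lora_" key convention, and the use of plain lists in place of tensors are all assumptions made to keep the example self-contained.

```python
from typing import Dict, Iterator, List

StateDict = Dict[str, List[float]]


def select_lora_entries(state_dict: StateDict) -> StateDict:
    """Keep only entries whose key contains 'lora_' (hypothetical naming convention)."""
    return {name: values for name, values in state_dict.items() if "lora_" in name}


def shard_state_dict(state_dict: StateDict, max_values_per_chunk: int) -> Iterator[StateDict]:
    """Yield consecutive chunks holding at most max_values_per_chunk values each.

    A single entry larger than the limit is still sent, alone in its own chunk.
    """
    chunk: StateDict = {}
    size = 0
    for name, values in state_dict.items():
        # Start a new chunk when adding this entry would exceed the limit.
        if chunk and size + len(values) > max_values_per_chunk:
            yield chunk
            chunk, size = {}, 0
        chunk[name] = values
        size += len(values)
    if chunk:
        yield chunk


# Toy state dict standing in for model weights.
state = {
    "layer.weight": [0.0] * 4,
    "layer.lora_A": [0.0] * 2,
    "layer.lora_B": [0.0] * 2,
    "head.weight": [0.0] * 3,
}
full_chunks = list(shard_state_dict(state, max_values_per_chunk=4))
lora_chunks = list(shard_state_dict(select_lora_entries(state), max_values_per_chunk=4))
```

In a real deployment each chunk would be sent to every experience maker in turn, which keeps peak memory on both sides bounded by the chunk size. Filtering to the LoRA entries first reduces the number of values transferred.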
Usage
Deploy as a named Ray actor in a distributed RLHF setup. Provide factory functions for the strategy and models, along with the names of ExperienceMakerHolder actors. The trainer receives experience data via its buffer methods and runs the PPO training loop when fit is called.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/ray/detached_trainer_ppo.py
- Lines: 1-191
Signature
@ray.remote(
    concurrency_groups={
        "buffer_length": 1, "buffer_append": 1,
        "buffer_sample": 1, "model_io": 1, "compute": 1,
    }
)
class DetachedPPOTrainer(DetachedTrainer):
    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        model_fn: Callable[[], Tuple[Actor, Critic]],
        env_info: Dict[str, str] = None,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        eps_clip: float = 0.2,
        value_clip: float = 0.4,
        dataloader_pin_memory: bool = True,
        callbacks: List[TrainerCallback] = [],
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
    ) -> None: ...
    def training_step(self, experience: Experience) -> Dict[str, float]: ...
    def _update_remote_makers(self, fully_update: bool = False, **config): ...
    def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: ...
    def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None: ...
Import
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| experience_maker_holder_name_list | List[str] | Yes | Names of remote ExperienceMakerHolder Ray actors |
| strategy_fn | Callable[[], Strategy] | Yes | Factory function returning the training strategy |
| model_fn | Callable[[], Tuple[Actor, Critic]] | Yes | Factory function returning actor and critic models |
| env_info | Dict[str, str] | No | Environment variables for distributed setup (default None) |
| train_batch_size | int | No | Batch size for training (default 8) |
| buffer_limit | int | No | Maximum replay buffer size (default 0) |
| eps_clip | float | No | Policy loss clip coefficient (default 0.2) |
| value_clip | float | No | Value loss clip coefficient (default 0.4) |
| dataloader_pin_memory | bool | No | Pin memory for data loader (default True) |
| callbacks | List[TrainerCallback] | No | Callback instances (default []) |
| eval_performance | bool | No | Enable performance evaluation callback (default False) |
| debug | bool | No | Enable debug logging (default False) |
| update_lora_weights | bool | No | Transfer only LoRA weights to makers (default False) |
Outputs
| Name | Type | Description |
|---|---|---|
| training_step return | Dict[str, float] | Dictionary with "actor_loss" and "critic_loss" values |
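Because training_step returns one dictionary per update, a caller that collects these dictionaries may want to reduce them to per-key averages for logging. The helper below is a hypothetical convenience and not part of the trainer's API; it only assumes the "actor_loss" and "critic_loss" keys listed above.

```python
from collections import defaultdict
from typing import Dict, Iterable


def average_metrics(step_metrics: Iterable[Dict[str, float]]) -> Dict[str, float]:
    """Average each key across a sequence of per-step metric dictionaries."""
    sums: Dict[str, float] = defaultdict(float)
    count = 0
    for metrics in step_metrics:
        for key, value in metrics.items():
            sums[key] += value
        count += 1
    if count == 0:
        return {}
    return {key: total / count for key, total in sums.items()}


# Two example step results with the keys training_step is documented to return.
history = [
    {"actor_loss": 1.0, "critic_loss": 2.0},
    {"actor_loss": 3.0, "critic_loss": 4.0},
]
summary = average_metrics(history)
```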
Usage Examples
import ray
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer

# my_strategy, my_actor, and my_critic are placeholders for user-defined
# factories; the maker names must match existing ExperienceMakerHolder actors.
trainer = DetachedPPOTrainer.options(
    name="trainer_0", num_gpus=1
).remote(
    experience_maker_holder_name_list=["maker_0", "maker_1"],
    strategy_fn=lambda: my_strategy(),
    model_fn=lambda: (my_actor(), my_critic()),
    train_batch_size=8,
    eps_clip=0.2,
    value_clip=0.4,
    eval_performance=True,
)

# Sync initial model weights to experience makers
ray.get(trainer.sync_models_to_remote_makers.remote())

# Start training
ray.get(trainer.fit.remote(total_steps=1000, update_steps=100, train_epochs=2))

# Save the actor model
ray.get(trainer.strategy_save_actor.remote("./actor_checkpoint", only_rank0=True))