Implementation:Hpcaitech ColossalAI DPOTrainer
Knowledge Sources
| Field | Value |
|---|---|
| Domains | NLP, Reinforcement_Learning |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A concrete tool for executing Direct Preference Optimization (DPO) training, provided by ColossalChat.
Description
DPOTrainer implements the DPO alignment training loop. It maintains a policy model and a frozen reference model, computing log probabilities for both chosen and rejected sequences to calculate the DPO loss. It supports gradient accumulation, pipeline parallelism, periodic evaluation with reward accuracy metrics, and checkpoint saving.
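The loss computation described above can be sketched in plain Python. This is a minimal illustration of the DPO objective under the docstring's label-smoothing reading of gamma, not the ColossalAI implementation; all names are assumptions.

```python
import math

def _logsigmoid(x: float) -> float:
    return -math.log(1.0 + math.exp(-x))

def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1, gamma: float = 0.0) -> float:
    """DPO loss for one preference pair (illustrative sketch).

    Each *_logp argument is the summed log-probability of the chosen or
    rejected response under the policy or the frozen reference model;
    beta is the temperature and gamma a label-smoothing weight.
    """
    # Implicit rewards are the policy/reference log-ratios per response.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_logratio - rejected_logratio)
    # Smoothed binary-preference loss; gamma=0 recovers plain DPO.
    return -(1.0 - gamma) * _logsigmoid(logits) - gamma * _logsigmoid(-logits)
```

When the policy puts more relative mass on the chosen response than the reference model does, `logits` is positive and the loss is small, which is what the optimizer drives toward.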
Usage
Use after configuring the Booster, loading both policy and reference models, and preparing preference dataloaders.
Code Reference
Source Location
- Repository: ColossalAI
- File: applications/ColossalChat/coati/trainer/dpo.py
- Lines: 29-643
Signature
```python
class DPOTrainer(SLTrainer):
    def __init__(
        self,
        actor: Any,
        ref_model: Any,
        booster: Booster,
        actor_optim: Optimizer,
        plugin: Plugin,
        actor_lr_scheduler: _LRScheduler,
        tokenizer: PreTrainedTokenizerBase,
        max_epochs: int = 1,
        beta: float = 0.1,
        gamma: float = 0.0,
        length_normalization: bool = False,
        apply_loss_mask: bool = True,
        accumulation_steps: int = 1,
        start_epoch: int = 0,
        save_interval: int = 0,
        save_dir: str = None,
        coordinator: DistCoordinator = None,
    ) -> None:
        """
        Args:
            actor: Policy model to train
            ref_model: Frozen reference model for KL divergence
            booster: ColossalAI Booster
            actor_optim: Optimizer for policy model
            plugin: Distributed plugin
            actor_lr_scheduler: LR scheduler
            tokenizer: For decoding during eval
            beta: DPO temperature parameter (default: 0.1)
            gamma: Label smoothing parameter (default: 0.0)
            length_normalization: Normalize loss by response length
            apply_loss_mask: Mask non-response tokens
            accumulation_steps: Gradient accumulation steps
        """

    def fit(
        self,
        train_preference_dataloader: DataLoader,
        eval_preference_dataloader: Optional[DataLoader] = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ) -> None:
        """Run the full DPO training loop."""
```
Import
```python
from coati.trainer import DPOTrainer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| actor | nn.Module | Yes | Boosted policy model |
| ref_model | nn.Module | Yes | Frozen reference model |
| booster | Booster | Yes | ColossalAI Booster |
| train_preference_dataloader | DataLoader | Yes | Preference pair batches |
| beta | float | No | DPO temperature (default: 0.1) |
| gamma | float | No | Label smoothing (default: 0.0) |
| length_normalization | bool | No | Normalize by response length (default: False) |
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | nn.Module | DPO-aligned policy model |
| Logs | Dict | chosen_reward, rejected_reward, reward_accuracy, loss |
| Checkpoints | Files | Periodic model/optimizer/scheduler checkpoints |
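The logged reward_accuracy can be read as the fraction of preference pairs whose implicit chosen reward beats the rejected one. A hypothetical sketch of that metric (names are illustrative):

```python
def reward_accuracy(chosen_rewards, rejected_rewards):
    """Fraction of pairs where the implicit reward of the chosen
    response exceeds that of the rejected response."""
    hits = sum(c > r for c, r in zip(chosen_rewards, rejected_rewards))
    return hits / len(chosen_rewards)
```

An accuracy near 0.5 during evaluation suggests the policy has not yet learned to separate chosen from rejected responses.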
Usage Examples
```python
from coati.trainer import DPOTrainer

trainer = DPOTrainer(
    actor=model,
    ref_model=ref_model,
    booster=booster,
    actor_optim=optimizer,
    plugin=plugin,
    actor_lr_scheduler=lr_scheduler,
    tokenizer=tokenizer,
    max_epochs=1,
    beta=0.1,
    apply_loss_mask=True,
    accumulation_steps=4,
    save_interval=500,
    save_dir="./dpo_checkpoints",
    coordinator=coordinator,
)

trainer.fit(
    train_preference_dataloader=train_dataloader,
    eval_preference_dataloader=eval_dataloader,
    log_dir="./dpo_logs",
    use_wandb=True,
)
```
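With accumulation_steps=4 as in the example, gradients from four micro-batches are combined before each optimizer step. The usual scaling can be sketched framework-independently (an illustration, not the trainer's internals):

```python
def accumulated_gradient(micro_grads, accumulation_steps):
    """Scaling each micro-batch gradient by 1/accumulation_steps makes
    the accumulated sum equal the gradient of one large batch's mean
    loss, so the effective learning rate stays unchanged."""
    return sum(g / accumulation_steps for g in micro_grads)
```

Dividing each micro-batch loss by accumulation_steps before the backward pass is the common way to keep training behavior independent of the accumulation setting.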
Related Pages
Implements Principle
Requires Environment