Implementation:Hpcaitech ColossalAI RewardModelTrainer
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement Learning, RLHF, Reward Model, Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A supervised learning trainer for reward model training in the ColossalChat RLHF pipeline, using preference data with chosen/rejected response pairs and the LogSigLoss objective.
Description
RewardModelTrainer extends SLTrainer to train a reward model from human preference data. The training loop concatenates the chosen and rejected input sequences so that both are scored in a single forward pass, computes LogSigLoss (or a custom loss function) between the chosen and rejected reward scores, and applies gradient accumulation over a configurable number of steps. It tracks chosen rewards, rejected rewards, loss, and accuracy with AccumulativeMeanMeter and logs the metrics to TensorBoard and, optionally, Weights & Biases.
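The pairwise objective described above can be sketched in plain Python. This is a minimal illustration, not the trainer's code: it assumes the loss has the common form of the mean of -log(sigmoid(r_chosen - r_rejected)), that accuracy is the fraction of pairs where the chosen reward is higher, and that the concatenated batch lists all chosen items first. The helper names (`log_sigmoid`, `pairwise_log_sig_loss`, `pairwise_accuracy`, `split_rewards`) are hypothetical, and the real trainer operates on tensors rather than lists.

```python
import math


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)) for either sign of x.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_log_sig_loss(chosen, rejected):
    # Assumed form: mean over pairs of -log(sigmoid(r_chosen - r_rejected)).
    assert len(chosen) == len(rejected) and len(chosen) > 0
    total = sum(log_sigmoid(c - r) for c, r in zip(chosen, rejected))
    return -total / len(chosen)


def pairwise_accuracy(chosen, rejected):
    # Fraction of pairs in which the chosen response scores higher.
    return sum(c > r for c, r in zip(chosen, rejected)) / len(chosen)


def split_rewards(rewards):
    # Assumed layout of one concatenated forward pass:
    # [chosen_0, ..., chosen_{B-1}, rejected_0, ..., rejected_{B-1}].
    half = len(rewards) // 2
    return rewards[:half], rewards[half:]


# Toy rewards for a batch of two preference pairs.
rewards = [2.0, 0.5, 1.0, 1.5]
chosen, rejected = split_rewards(rewards)
loss = pairwise_log_sig_loss(chosen, rejected)
acc = pairwise_accuracy(chosen, rejected)
```

In this toy batch the first pair is ranked correctly and the second is not, so the accuracy is 0.5 and the mis-ranked pair dominates the loss.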
The evaluation method computes the same metrics on an optional eval dataloader and writes results to a text file. Checkpointing is handled via ColossalAI's save_checkpoint utility at configurable intervals, with distributed coordination through DistCoordinator.
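How gradient accumulation and interval-based checkpointing interact can be illustrated with a small scheduling sketch. It assumes the optimizer steps once every `accumulation_steps` batches and that `save_interval` is counted in optimizer updates, with 0 disabling checkpoints; the source does not state which step counter the trainer uses, so treat that as an assumption. The function `plan_steps` is hypothetical and performs no training.

```python
def plan_steps(num_batches: int, accumulation_steps: int, save_interval: int):
    """Return the batch indices that trigger an optimizer update and the
    update counts at which a checkpoint would be written (assumed policy)."""
    update_batches = []
    checkpoint_updates = []
    for batch_idx in range(num_batches):
        # Gradients accumulate until a full group of batches has been seen.
        if (batch_idx + 1) % accumulation_steps == 0:
            update_batches.append(batch_idx)
            n_updates = len(update_batches)
            # Assumption: save_interval counts optimizer updates; 0 disables saving.
            if save_interval > 0 and n_updates % save_interval == 0:
                checkpoint_updates.append(n_updates)
    return update_batches, checkpoint_updates


updates, checkpoints = plan_steps(num_batches=6, accumulation_steps=2, save_interval=2)
```

With six batches, an accumulation step count of 2, and a save interval of 2, the optimizer updates after the batches at indices 1, 3, and 5, and a single checkpoint falls on the second update.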
Usage
Use RewardModelTrainer to train a reward model on preference datasets with chosen/rejected pairs. Instantiate it with the model, Booster, optimizer, plugin, learning rate scheduler, tokenizer, and an optional loss function, then call fit with the train and eval dataloaders.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/trainer/rm.py
- Lines: 1-247
Signature
class RewardModelTrainer(SLTrainer):
    def __init__(
        self,
        model: Any,
        booster: Booster,
        optimizer: Optimizer,
        plugin: Plugin,
        lr_scheduler: _LRScheduler,
        tokenizer: PreTrainedTokenizerBase,
        loss_fn: Optional[Callable] = None,
        max_epochs: int = 1,
        beta: float = 0.1,
        accumulation_steps: int = 1,
        start_epoch: int = 0,
        save_interval: int = 0,
        save_dir: str = None,
        coordinator: DistCoordinator = None,
    ) -> None: ...

    def _before_fit(
        self,
        train_preference_dataloader: DataLoader = None,
        eval_preference_dataloader: DataLoader = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ): ...

    def _train(self, epoch): ...

    def _eval(self, epoch): ...
Import
from coati.trainer.rm import RewardModelTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | Any | Yes | The reward model to train |
| booster | Booster | Yes | ColossalAI Booster for distributed training |
| optimizer | Optimizer | Yes | The optimizer for training |
| plugin | Plugin | Yes | ColossalAI Plugin for parallelism strategy |
| lr_scheduler | _LRScheduler | Yes | Learning rate scheduler |
| tokenizer | PreTrainedTokenizerBase | Yes | Tokenizer for encoding |
| loss_fn | Callable | No | Custom loss function (default LogSigLoss with beta) |
| max_epochs | int | No | Maximum training epochs (default 1) |
| beta | float | No | Beta parameter for LogSigLoss (default 0.1) |
| accumulation_steps | int | No | Gradient accumulation steps (default 1) |
| start_epoch | int | No | Starting epoch for resuming (default 0) |
| save_interval | int | No | Checkpoint save interval in steps; 0 disables (default 0) |
| save_dir | str | No | Directory for saving checkpoints (default None) |
| coordinator | DistCoordinator | No | Distributed coordinator for logging (default None) |
Outputs
| Name | Type | Description |
|---|---|---|
| _train return | None | Trains for one epoch, logging loss, rewards, and accuracy |
| _eval return | None | Evaluates the model and writes results to a text file |
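The metrics named above (loss, rewards, accuracy) are averaged over batches before being logged. A running-mean tracker of that kind can be sketched as follows. The class `RunningMean` and its methods are hypothetical stand-ins written for illustration; they are not the interface of AccumulativeMeanMeter, which this page does not document.

```python
class RunningMean:
    """Hypothetical running-mean tracker for a scalar metric."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def add(self, value: float, n: int = 1) -> None:
        # Weight the value by the number of samples it summarizes.
        self.total += value * n
        self.count += n

    def get(self) -> float:
        # Return 0.0 before any value has been recorded.
        return self.total / self.count if self.count else 0.0

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0


meter = RunningMean()
meter.add(1.0)        # e.g. mean loss of one batch
meter.add(3.0)        # mean loss of the next batch
mean_loss = meter.get()
meter.reset()         # start a fresh window after logging
```

Resetting after each logging step keeps every logged value an average over the most recent window rather than over the whole epoch.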
Usage Examples
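from coati.trainer.rm import RewardModelTrainer
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator

# reward_model, booster, optimizer, plugin, lr_scheduler, tokenizer, and the
# two dataloaders are assumed to have been created earlier (not shown here).
# DistCoordinator() assumes the distributed environment is already initialized.
trainer = RewardModelTrainer(
    model=reward_model,
    booster=booster,
    optimizer=optimizer,
    plugin=plugin,
    lr_scheduler=lr_scheduler,
    tokenizer=tokenizer,
    max_epochs=3,
    accumulation_steps=4,
    save_interval=500,
    save_dir="./checkpoints/rm",
    coordinator=DistCoordinator(),
)

# Train on the preference pairs and evaluate on the optional eval set.
trainer.fit(
    train_preference_dataloader=train_dataloader,
    eval_preference_dataloader=eval_dataloader,
    log_dir="./logs",
    use_wandb=True,
)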