Implementation:Hpcaitech ColossalAI RewardModelTrainer

From Leeroopedia


Knowledge Sources
Domains: Reinforcement Learning, RLHF, Reward Model, Training
Last Updated: 2026-02-09 00:00 GMT

Overview

A supervised learning trainer for reward model training in the ColossalChat RLHF pipeline, using preference data with chosen/rejected response pairs and the LogSigLoss objective.

Description

RewardModelTrainer extends SLTrainer to train a reward model from human preference data. In each training step, the loop concatenates the chosen and rejected input sequences so that both are scored in one forward pass, then computes LogSigLoss (or a custom loss function) between the chosen and rejected reward scores. Gradients are accumulated over a configurable number of steps. Chosen rewards, rejected rewards, loss, and accuracy are tracked with AccumulativeMeanMeter and logged to TensorBoard and, optionally, Weights & Biases.
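The pairwise objective described above can be illustrated with a minimal, framework-free sketch. It assumes the loss is the mean of -log(sigmoid(chosen - rejected)) over all pairs and that the concatenated batch places chosen scores first and rejected scores second. Both are assumptions made for illustration, as are the helper names `split_rewards`, `pairwise_logsig_loss`, and `pairwise_accuracy`. The role of `beta` is not modeled here, and the actual trainer operates on tensors rather than Python lists.

```python
import math
from typing import List, Sequence, Tuple


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def split_rewards(rewards: Sequence[float], batch_size: int) -> Tuple[List[float], List[float]]:
    """Split scores from one concatenated forward pass.

    Assumption: the first `batch_size` scores belong to the chosen responses
    and the remaining scores to the rejected responses.
    """
    return list(rewards[:batch_size]), list(rewards[batch_size:])


def pairwise_logsig_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Mean of -log(sigmoid(chosen - rejected)) over all preference pairs."""
    return -sum(log_sigmoid(c - r) for c, r in zip(chosen, rejected)) / len(chosen)


def pairwise_accuracy(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Fraction of pairs where the chosen response scores strictly higher."""
    return sum(1 for c, r in zip(chosen, rejected) if c > r) / len(chosen)


# Hypothetical reward scores for a batch of two pairs, laid out as [chosen; rejected].
scores = [2.0, 0.5, 1.0, 1.5]
chosen, rejected = split_rewards(scores, batch_size=2)
loss = pairwise_logsig_loss(chosen, rejected)
accuracy = pairwise_accuracy(chosen, rejected)
```

In this toy batch the first pair is ranked correctly and the second is not, so the accuracy is 0.5 and the misranked pair dominates the loss.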

The evaluation method computes the same metrics on an optional eval dataloader and writes results to a text file. Checkpointing is handled via ColossalAI's save_checkpoint utility at configurable intervals, with distributed coordination through DistCoordinator.
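As a rough picture of the evaluation bookkeeping, the sketch below averages per-batch metrics and writes a summary to a text file. `RunningMean` is a hypothetical stand-in, not the library's AccumulativeMeanMeter, and the file name, metric names, and output format are all assumptions chosen for illustration.

```python
import os
import tempfile
from collections import defaultdict


class RunningMean:
    """Hypothetical mean meter: keeps a weighted sum and a count per metric name."""

    def __init__(self) -> None:
        self._sum = defaultdict(float)
        self._count = defaultdict(int)

    def add(self, name: str, value: float, count: int = 1) -> None:
        # `value` is a per-batch mean; weight it by the number of samples it covers.
        self._sum[name] += value * count
        self._count[name] += count

    def get(self, name: str) -> float:
        return self._sum[name] / self._count[name]


def write_eval_summary(meter: RunningMean, names, epoch: int, path: str):
    """Write one 'name: value' line per metric and return the lines written."""
    lines = [f"epoch: {epoch}"] + [f"{n}: {meter.get(n):.4f}" for n in names]
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    return lines


meter = RunningMean()
# Two hypothetical evaluation batches of four pairs each.
meter.add("loss", 0.8, count=4)
meter.add("loss", 0.4, count=4)
meter.add("accuracy", 0.5, count=4)
meter.add("accuracy", 1.0, count=4)

with tempfile.TemporaryDirectory() as tmp_dir:
    # The file name is an assumption made for this sketch.
    result_path = os.path.join(tmp_dir, "eval_result_epoch_0.txt")
    summary_lines = write_eval_summary(meter, ["loss", "accuracy"], 0, result_path)
    with open(result_path, encoding="utf-8") as f:
        written = f.read()
```

Weighting each batch mean by its sample count keeps the reported average correct when the last batch is smaller than the others.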

Usage

Use RewardModelTrainer to train a reward model on preference datasets with chosen/rejected pairs. Instantiate it with the model, Booster, optimizer, plugin, learning rate scheduler, and tokenizer, plus an optional loss function. Then call fit with the train and eval dataloaders.

Code Reference

Source Location

Signature

class RewardModelTrainer(SLTrainer):
    def __init__(
        self,
        model: Any,
        booster: Booster,
        optimizer: Optimizer,
        plugin: Plugin,
        lr_scheduler: _LRScheduler,
        tokenizer: PreTrainedTokenizerBase,
        loss_fn: Optional[Callable] = None,
        max_epochs: int = 1,
        beta: float = 0.1,
        accumulation_steps: int = 1,
        start_epoch: int = 0,
        save_interval: int = 0,
        save_dir: str = None,
        coordinator: DistCoordinator = None,
    ) -> None: ...

    def _before_fit(
        self,
        train_preference_dataloader: DataLoader = None,
        eval_preference_dataloader: DataLoader = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ): ...

    def _train(self, epoch): ...
    def _eval(self, epoch): ...

Import

from coati.trainer.rm import RewardModelTrainer

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| model | Any | Yes | The reward model to train |
| booster | Booster | Yes | ColossalAI Booster for distributed training |
| optimizer | Optimizer | Yes | The optimizer for training |
| plugin | Plugin | Yes | ColossalAI Plugin for parallelism strategy |
| lr_scheduler | _LRScheduler | Yes | Learning rate scheduler |
| tokenizer | PreTrainedTokenizerBase | Yes | Tokenizer for encoding |
| loss_fn | Callable | No | Custom loss function (default LogSigLoss with beta) |
| max_epochs | int | No | Maximum training epochs (default 1) |
| beta | float | No | Beta parameter for LogSigLoss (default 0.1) |
| accumulation_steps | int | No | Gradient accumulation steps (default 1) |
| start_epoch | int | No | Starting epoch for resuming (default 0) |
| save_interval | int | No | Checkpoint save interval in steps; 0 disables (default 0) |
| save_dir | str | No | Directory for saving checkpoints (default None) |
| coordinator | DistCoordinator | No | Distributed coordinator for logging (default None) |
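The interaction between accumulation_steps and save_interval can be pictured with a small simulation. This is a sketch under stated assumptions, not the trainer's actual loop: it assumes the optimizer is stepped once every accumulation_steps batches, that save_interval counts optimizer steps rather than batches, and that a checkpoint is written only when save_dir is set. The helpers `is_update_step`, `should_save`, and `simulate` are hypothetical.

```python
from typing import List, Optional, Tuple


def is_update_step(batch_index: int, accumulation_steps: int) -> bool:
    """True when the gradients accumulated so far should be applied."""
    return (batch_index + 1) % accumulation_steps == 0


def should_save(step: int, save_interval: int, save_dir: Optional[str]) -> bool:
    """True when a checkpoint is due; a save_interval of 0 disables saving."""
    return save_dir is not None and save_interval > 0 and step % save_interval == 0


def simulate(
    num_batches: int,
    accumulation_steps: int,
    save_interval: int,
    save_dir: Optional[str],
) -> Tuple[int, List[int]]:
    """Return the number of optimizer steps taken and the steps at which a save is due."""
    step = 0
    saved_at: List[int] = []
    for batch_index in range(num_batches):
        # A real loop would run forward/backward on the batch here,
        # letting gradients accumulate across batches.
        if is_update_step(batch_index, accumulation_steps):
            # A real loop would step the optimizer and scheduler, then zero the gradients.
            step += 1
            if should_save(step, save_interval, save_dir):
                saved_at.append(step)
    return step, saved_at


steps_taken, checkpoints = simulate(16, accumulation_steps=4, save_interval=2, save_dir="./checkpoints/rm")
```

With 16 batches and accumulation_steps=4, the simulation takes four optimizer steps and flags a checkpoint after the second and fourth.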

Outputs

| Name | Type | Description |
|------|------|-------------|
| _train return | None | Trains for one epoch, logging loss, rewards, and accuracy |
| _eval return | None | Evaluates the model and writes results to a text file |

Usage Examples

from coati.trainer.rm import RewardModelTrainer
from colossalai.booster import Booster
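from colossalai.cluster import DistCoordinator

# reward_model, booster, optimizer, plugin, lr_scheduler, tokenizer, and the
# two dataloaders are assumed to have been created beforehand.
trainer = RewardModelTrainer(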
    model=reward_model,
    booster=booster,
    optimizer=optimizer,
    plugin=plugin,
    lr_scheduler=lr_scheduler,
    tokenizer=tokenizer,
    max_epochs=3,
    accumulation_steps=4,
    save_interval=500,
    save_dir="./checkpoints/rm",
    coordinator=DistCoordinator(),
)

trainer.fit(
    train_preference_dataloader=train_dataloader,
    eval_preference_dataloader=eval_dataloader,
    log_dir="./logs",
    use_wandb=True,
)
