Implementation:Hpcaitech ColossalAI Trainer Callback Base
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement Learning, Callbacks, Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Abstract base callback class for the ColossalChat trainer, defining lifecycle hook interfaces for the on-policy reinforcement learning training loop.
Description
Callback is an abstract base class (it subclasses ABC) that defines the interface for callbacks used by OLTrainer, the online learning trainer. It provides no-op default implementations for ten lifecycle hooks, one start/end pair for each phase: fit, episode, experience making, learning epoch, and learning batch. Each hook takes at most one argument: an episode or epoch index (int), or an Experience object.
Unlike the Ray callback base classes, this Callback is used by the standard (non-Ray) trainer. It includes on_make_experience_start/end and on_learn_epoch_start/end hooks instead of the update-oriented hooks found in the Ray version.
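The no-op default pattern can be sketched with a self-contained stand-in that mirrors the ten hook signatures listed under Signature. This is a hedged illustration, not the ColossalChat source: Experience is aliased to typing.Any so the snippet runs without coati installed, and EpisodeCounter is a hypothetical subclass that overrides a single hook.

```python
from abc import ABC
from typing import Any, List

Experience = Any  # placeholder for coati.experience_maker.Experience (assumption)


class Callback(ABC):
    """Stand-in mirroring the documented hook signatures; every hook is a no-op."""

    def on_fit_start(self) -> None: ...
    def on_fit_end(self) -> None: ...
    def on_episode_start(self, episode: int) -> None: ...
    def on_episode_end(self, episode: int) -> None: ...
    def on_make_experience_start(self) -> None: ...
    def on_make_experience_end(self, experience: Experience) -> None: ...
    def on_learn_epoch_start(self, epoch: int) -> None: ...
    def on_learn_epoch_end(self, epoch: int) -> None: ...
    def on_learn_batch_start(self) -> None: ...
    def on_learn_batch_end(self, experience: Experience) -> None: ...


class EpisodeCounter(Callback):
    """Hypothetical subclass that overrides only on_episode_end."""

    def __init__(self) -> None:
        self.finished_episodes: List[int] = []

    def on_episode_end(self, episode: int) -> None:
        self.finished_episodes.append(episode)


counter = EpisodeCounter()
# Hooks that are not overridden fall through to the inherited no-ops.
counter.on_fit_start()
for episode in range(3):
    counter.on_episode_start(episode)
    counter.on_episode_end(episode)
counter.on_fit_end()
print(counter.finished_episodes)  # [0, 1, 2]
```

Because the base class declares no abstract methods in this sketch, a subclass can stay as small as the one hook it cares about.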
Usage
Subclass Callback to implement custom logging, checkpointing, or performance monitoring during on-policy RLHF training. Override only the hooks relevant to your use case. Pass instances to the OLTrainer constructor via the callbacks parameter.
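A performance-monitoring callback and the way a trainer fans hooks out to a list of callbacks can be sketched as follows. Everything here is hypothetical: the trimmed Callback stand-in declares only the hooks this example fires, and fire, fake_fit, and TimingCallback are illustrative names. The hook order in fake_fit is a simplified assumption, not OLTrainer's actual fit loop.

```python
import time
from abc import ABC
from typing import Any, Dict, List

Experience = Any  # placeholder for coati.experience_maker.Experience (assumption)


class Callback(ABC):
    """Trimmed stand-in for coati's Callback: only the hooks this sketch fires."""

    def on_fit_start(self) -> None: ...
    def on_fit_end(self) -> None: ...
    def on_make_experience_start(self) -> None: ...
    def on_make_experience_end(self, experience: Experience) -> None: ...
    def on_learn_batch_start(self) -> None: ...
    def on_learn_batch_end(self, experience: Experience) -> None: ...


class TimingCallback(Callback):
    """Hypothetical performance monitor: wall-clock seconds spent per phase."""

    def __init__(self) -> None:
        self.timings: Dict[str, List[float]] = {"make_experience": [], "learn_batch": []}
        self._t0 = 0.0

    def on_make_experience_start(self) -> None:
        self._t0 = time.perf_counter()

    def on_make_experience_end(self, experience: Experience) -> None:
        self.timings["make_experience"].append(time.perf_counter() - self._t0)

    def on_learn_batch_start(self) -> None:
        self._t0 = time.perf_counter()

    def on_learn_batch_end(self, experience: Experience) -> None:
        self.timings["learn_batch"].append(time.perf_counter() - self._t0)


def fire(callbacks: List[Callback], hook: str, *args: Any) -> None:
    """Call the named hook on every registered callback, in registration order."""
    for cb in callbacks:
        getattr(cb, hook)(*args)


def fake_fit(callbacks: List[Callback], num_episodes: int, num_batches: int) -> None:
    """Hypothetical driver; the real trainer decides its own hook order."""
    fire(callbacks, "on_fit_start")
    for _ in range(num_episodes):
        fire(callbacks, "on_make_experience_start")
        experience = {"sequences": [[1, 2, 3]]}  # placeholder rollout, not a real Experience
        fire(callbacks, "on_make_experience_end", experience)
        for _ in range(num_batches):
            fire(callbacks, "on_learn_batch_start")
            fire(callbacks, "on_learn_batch_end", experience)
    fire(callbacks, "on_fit_end")


timer = TimingCallback()
fake_fit([timer], num_episodes=2, num_batches=3)
print({name: len(values) for name, values in timer.timings.items()})
# {'make_experience': 2, 'learn_batch': 6}
```

In real code you would subclass coati.trainer.callbacks.base.Callback instead of the stand-in and pass the instance through the trainer's callbacks parameter; the trainer, not a driver like fake_fit, decides when each hook fires.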
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/trainer/callbacks/base.py
- Lines: 1-39
Signature
```python
class Callback(ABC):
    def on_fit_start(self) -> None: ...
    def on_fit_end(self) -> None: ...
    def on_episode_start(self, episode: int) -> None: ...
    def on_episode_end(self, episode: int) -> None: ...
    def on_make_experience_start(self) -> None: ...
    def on_make_experience_end(self, experience: Experience) -> None: ...
    def on_learn_epoch_start(self, epoch: int) -> None: ...
    def on_learn_epoch_end(self, epoch: int) -> None: ...
    def on_learn_batch_start(self) -> None: ...
    def on_learn_batch_end(self, experience: Experience) -> None: ...
```
Import
```python
from coati.trainer.callbacks.base import Callback
```
I/O Contract
Inputs
| Name | Type | Passed to | Description |
|---|---|---|---|
| episode | int | on_episode_start, on_episode_end | Index of the current episode |
| experience | Experience | on_make_experience_end, on_learn_batch_end | Experience object for the current batch |
| epoch | int | on_learn_epoch_start, on_learn_epoch_end | Index of the current learning epoch |
Outputs
| Name | Type | Description |
|---|---|---|
| return | None | All callback methods return None |
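How a callback can consume the arguments in the Inputs table is sketched below. FakeExperience is a hypothetical stand-in that models only the sequences field (the real Experience holds tensors), the trimmed Callback stand-in declares only the argument-taking hooks used here, and SamplesPerEpisode is an illustrative name.

```python
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeExperience:
    """Hypothetical stand-in for coati's Experience; only `sequences` is modeled."""
    sequences: List[List[int]] = field(default_factory=list)


class Callback(ABC):
    """Trimmed stand-in for coati's Callback with a few argument-taking hooks."""

    def on_episode_start(self, episode: int) -> None: ...
    def on_episode_end(self, episode: int) -> None: ...
    def on_make_experience_end(self, experience: FakeExperience) -> None: ...
    def on_learn_epoch_end(self, epoch: int) -> None: ...


class SamplesPerEpisode(Callback):
    """Counts generated samples per episode and records finished learning epochs."""

    def __init__(self) -> None:
        self.samples: Dict[int, int] = {}
        self.epochs_seen: List[int] = []
        self._current = -1

    def on_episode_start(self, episode: int) -> None:
        self._current = episode
        self.samples[episode] = 0

    def on_make_experience_end(self, experience: FakeExperience) -> None:
        # len() gives the batch dimension for both lists and torch tensors.
        self.samples[self._current] += len(experience.sequences)

    def on_learn_epoch_end(self, epoch: int) -> None:
        self.epochs_seen.append(epoch)


cb = SamplesPerEpisode()
cb.on_episode_start(0)
result = cb.on_make_experience_end(FakeExperience(sequences=[[1, 2], [3, 4]]))
cb.on_make_experience_end(FakeExperience(sequences=[[5, 6]]))
cb.on_learn_epoch_end(0)
cb.on_episode_end(0)
print(cb.samples, cb.epochs_seen, result)  # {0: 3} [0] None
```

Consistent with the Outputs table, every hook returns None, so a callback communicates only through its own state (here, samples and epochs_seen) or side effects such as logging.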
Usage Examples
```python
from coati.experience_maker import Experience
from coati.trainer.callbacks.base import Callback


class PrintCallback(Callback):
    """Logs progress to stdout; hooks that are not overridden stay no-ops."""

    def on_episode_start(self, episode: int) -> None:
        print(f"Starting episode {episode}")

    def on_make_experience_end(self, experience: Experience) -> None:
        # experience.sequences has shape (batch_size, seq_len)
        batch_size = experience.sequences.shape[0]
        print(f"Generated {batch_size} experience samples")

    def on_fit_end(self) -> None:
        print("Training complete")
```