Implementation:Hpcaitech ColossalAI Ray Callback Base
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement Learning, Distributed Training, Callbacks |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Abstract base callback classes for the ColossalChat Ray-based distributed RLHF training pipeline, defining lifecycle hook interfaces for both trainer and experience maker components.
Description
This module provides two abstract base classes: TrainerCallback and MakerCallback. TrainerCallback defines hooks for the trainer lifecycle including fit start/end, episode start/end, epoch start/end, batch start/end, and update start/end. MakerCallback defines hooks for the experience maker lifecycle including loop start/end, make experience start/end, send start/end, and batch start/end.
Both classes inherit from Python's ABC (Abstract Base Class), but no hook is declared abstract: every method has a no-op default implementation, so subclasses override only the hooks they need.
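A minimal, self-contained sketch of this "override only what you need" pattern follows. It does not import coati: TrainerCallbackSketch is a hypothetical stand-in that copies four of the documented TrainerCallback hook signatures so the example runs with only the standard library. In real code, subclass TrainerCallback from coati.ray.callbacks.base instead.

```python
from abc import ABC


class TrainerCallbackSketch(ABC):
    """Hypothetical stand-in mirroring a subset of TrainerCallback: every hook is a no-op."""

    def on_fit_start(self) -> None:
        pass

    def on_fit_end(self) -> None:
        pass

    def on_epoch_start(self, epoch: int) -> None:
        pass

    def on_epoch_end(self, epoch: int) -> None:
        pass


class EpochCounter(TrainerCallbackSketch):
    """Overrides a single hook; every other hook keeps the inherited no-op."""

    def __init__(self) -> None:
        self.epochs_seen = 0

    def on_epoch_end(self, epoch: int) -> None:
        self.epochs_seen += 1


cb = EpochCounter()
cb.on_fit_start()              # inherited no-op, returns None
for epoch in range(3):
    cb.on_epoch_start(epoch)   # inherited no-op
    cb.on_epoch_end(epoch)     # overridden: increments the counter
cb.on_fit_end()                # inherited no-op
print(cb.epochs_seen)  # 3
```

Because the base hooks do nothing, calling a hook that the subclass did not override is always safe and simply returns None.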
Usage
Use TrainerCallback when implementing custom callbacks that respond to training lifecycle events in a Ray-based detached trainer. Use MakerCallback when implementing custom callbacks for the experience maker holder. Subclass the appropriate class and override the desired hook methods.
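As an illustration of subclassing MakerCallback, the sketch below times each make-experience step with time.perf_counter. MakerCallbackSketch and MakeExperienceTimer are hypothetical names: the stand-in base copies only the two make-experience hook signatures so the code runs without coati, and the loop at the bottom manually plays the role of the experience maker, whose real call sites are not shown here.

```python
import time
from abc import ABC
from typing import Any, List


class MakerCallbackSketch(ABC):
    """Hypothetical stand-in for the make-experience hooks of MakerCallback (no-op defaults)."""

    def on_make_experience_start(self) -> None:
        pass

    def on_make_experience_end(self, experience: Any) -> None:
        pass


class MakeExperienceTimer(MakerCallbackSketch):
    """Records wall-clock seconds spent between the start and end hooks."""

    def __init__(self) -> None:
        self.durations: List[float] = []
        self._start = 0.0

    def on_make_experience_start(self) -> None:
        self._start = time.perf_counter()

    def on_make_experience_end(self, experience: Any) -> None:
        self.durations.append(time.perf_counter() - self._start)


timer = MakeExperienceTimer()
for _ in range(2):
    timer.on_make_experience_start()
    time.sleep(0.01)  # placeholder for real experience generation
    timer.on_make_experience_end(experience=None)  # None stands in for an Experience
print(len(timer.durations))  # 2
```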
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/ray/callbacks/base.py
- Lines: 1-65
Signature
```python
from abc import ABC

from coati.experience_maker import Experience


class TrainerCallback(ABC):
    def on_fit_start(self) -> None: ...
    def on_fit_end(self) -> None: ...
    def on_episode_start(self, episode: int) -> None: ...
    def on_episode_end(self, episode: int) -> None: ...
    def on_epoch_start(self, epoch: int) -> None: ...
    def on_epoch_end(self, epoch: int) -> None: ...
    def on_batch_start(self) -> None: ...
    def on_batch_end(self, metrics: dict, experience: Experience) -> None: ...
    def on_update_start(self) -> None: ...
    def on_update_end(self) -> None: ...


class MakerCallback(ABC):
    def on_loop_start(self) -> None: ...
    def on_loop_end(self) -> None: ...
    def on_make_experience_start(self) -> None: ...
    def on_make_experience_end(self, experience: Experience) -> None: ...
    def on_send_start(self) -> None: ...
    def on_send_end(self) -> None: ...
    def on_batch_start(self) -> None: ...
    def on_batch_end(self) -> None: ...
```
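The signatures only define the hooks; some component has to call them. The sketch below assumes the common pattern of a driver that holds a list of callbacks and fires each hook on every callback in turn. TrainerCallbackSketch (a stand-in with a subset of the TrainerCallback hooks), Recorder, ToyTrainer, and the hook order inside fit() are all hypothetical and illustrative; the actual detached trainer may fire hooks in a different order or with different arguments.

```python
from abc import ABC
from typing import Any, List


class TrainerCallbackSketch(ABC):
    """Hypothetical stand-in: a subset of TrainerCallback hooks, all no-ops."""

    def on_fit_start(self) -> None:
        pass

    def on_fit_end(self) -> None:
        pass

    def on_episode_start(self, episode: int) -> None:
        pass

    def on_episode_end(self, episode: int) -> None:
        pass

    def on_batch_start(self) -> None:
        pass

    def on_batch_end(self, metrics: dict, experience: Any) -> None:
        pass


class Recorder(TrainerCallbackSketch):
    """Appends a label each time one of its overridden hooks fires."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_fit_start(self) -> None:
        self.events.append("fit_start")

    def on_episode_start(self, episode: int) -> None:
        self.events.append(f"episode_start:{episode}")

    def on_batch_end(self, metrics: dict, experience: Any) -> None:
        self.events.append("batch_end")

    def on_fit_end(self) -> None:
        self.events.append("fit_end")


class ToyTrainer:
    """Hypothetical driver; the hook order here is illustrative only."""

    def __init__(self, callbacks: List[TrainerCallbackSketch]) -> None:
        self.callbacks = callbacks

    def _fire(self, hook: str, *args: Any) -> None:
        # Call the named hook on every registered callback, in registration order.
        for cb in self.callbacks:
            getattr(cb, hook)(*args)

    def fit(self, num_episodes: int, batches_per_episode: int) -> None:
        self._fire("on_fit_start")
        for episode in range(num_episodes):
            self._fire("on_episode_start", episode)
            for _ in range(batches_per_episode):
                self._fire("on_batch_start")
                metrics = {"loss": 0.0}  # placeholder metrics
                self._fire("on_batch_end", metrics, None)  # None stands in for an Experience
            self._fire("on_episode_end", episode)
        self._fire("on_fit_end")


recorder = Recorder()
# The stand-in base declares no abstract methods, so it can be registered as-is.
ToyTrainer([recorder, TrainerCallbackSketch()]).fit(num_episodes=2, batches_per_episode=1)
print(recorder.events)
```

Dispatching through getattr keeps the driver short; it relies on every callback exposing every hook, which the no-op defaults guarantee.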
Import
```python
from coati.ray.callbacks.base import TrainerCallback, MakerCallback
```
I/O Contract
Inputs (TrainerCallback)
| Name | Type | Required | Description |
|---|---|---|---|
| episode | int | Yes | Episode index passed to on_episode_start/end |
| epoch | int | Yes | Epoch index passed to on_epoch_start/end |
| metrics | dict | Yes | Training metrics dict passed to on_batch_end |
| experience | Experience | Yes | Experience object passed to on_batch_end |
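The sketch below shows how the episode and metrics inputs can work together: it averages every entry of the metrics dict over the batches of each episode. TrainerCallbackSketch (a stand-in with three TrainerCallback signatures), EpisodeMetricAverager, and the metric keys actor_loss and critic_loss are hypothetical; the experience argument is ignored, and the hooks are invoked by hand in place of a trainer.

```python
from abc import ABC
from collections import defaultdict
from typing import Any, Dict


class TrainerCallbackSketch(ABC):
    """Hypothetical stand-in for three TrainerCallback hooks (no-op defaults)."""

    def on_episode_start(self, episode: int) -> None:
        pass

    def on_episode_end(self, episode: int) -> None:
        pass

    def on_batch_end(self, metrics: dict, experience: Any) -> None:
        pass


class EpisodeMetricAverager(TrainerCallbackSketch):
    """Averages each metric over the batches of one episode (assumes every batch reports the same keys)."""

    def __init__(self) -> None:
        self.history: Dict[int, Dict[str, float]] = {}
        self._sums: Dict[str, float] = defaultdict(float)
        self._count = 0

    def on_episode_start(self, episode: int) -> None:
        # Reset the accumulators at the start of every episode.
        self._sums = defaultdict(float)
        self._count = 0

    def on_batch_end(self, metrics: dict, experience: Any) -> None:
        for key, value in metrics.items():
            self._sums[key] += float(value)
        self._count += 1

    def on_episode_end(self, episode: int) -> None:
        if self._count:
            self.history[episode] = {k: v / self._count for k, v in self._sums.items()}


avg = EpisodeMetricAverager()
avg.on_episode_start(0)
avg.on_batch_end({"actor_loss": 1.0, "critic_loss": 0.5}, experience=None)
avg.on_batch_end({"actor_loss": 3.0, "critic_loss": 1.5}, experience=None)
avg.on_episode_end(0)
print(avg.history)  # {0: {'actor_loss': 2.0, 'critic_loss': 1.0}}
```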
Inputs (MakerCallback)
| Name | Type | Required | Description |
|---|---|---|---|
| experience | Experience | Yes | Experience object passed to on_make_experience_end |
Outputs
| Name | Type | Description |
|---|---|---|
| return | None | All callback methods return None |
Usage Examples
```python
from coati.ray.callbacks.base import TrainerCallback
from coati.experience_maker import Experience

class LoggingCallback(TrainerCallback):
    # Called after each training batch with that batch's metrics and experience.
    def on_batch_end(self, metrics: dict, experience: Experience) -> None:
        print(f"Batch metrics: {metrics}")

    # Called once at the end of training.
    def on_fit_end(self) -> None:
        print("Training complete.")
```