Implementation:Hpcaitech ColossalAI Checkpoint IO
| Knowledge Sources | |
|---|---|
| Domains | Distributed Training, Checkpointing, RLHF |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Checkpoint save/load utilities for ColossalChat that integrate with the ColossalAI Booster for distributed model persistence.
Description
This module provides functions for saving and loading training checkpoints in the ColossalChat pipeline. save_checkpoint writes the model (sharded), the optimizer state (sharded), the learning rate scheduler, and running-state metadata (epoch, step, and sample start index) into an organized directory structure. load_checkpoint restores these components from a checkpoint directory and returns the saved epoch, step, and sample start index. Both functions use the ColossalAI Booster API for distributed-aware saving and loading, and the DistCoordinator ensures that metadata files are written only by the master process. The helper functions load_json and save_json handle JSON file I/O.
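The JSON helpers and the master-only metadata write can be sketched as follows. This is a minimal illustration under stated assumptions, not the module's actual code: the bodies of load_json and save_json are one plausible implementation of the documented signatures, write_running_states and its is_master flag are hypothetical stand-ins for the DistCoordinator check, the file name running_states.json is assumed, and sample_start_index = step * batch_size is an assumption based on the batch_size parameter description.

```python
import json
import os
import tempfile
from typing import Any, Dict, Union


def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """Read a JSON file and return its contents as a dict."""
    with open(file_path, mode="r", encoding="utf-8") as fp:
        return json.load(fp)


def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """Write a dict to a JSON file."""
    with open(file_path, mode="w", encoding="utf-8") as fp:
        json.dump(data, fp, ensure_ascii=False, indent=4)


def write_running_states(
    save_dir: str, epoch: int, step: int, batch_size: int, is_master: bool
) -> Dict[str, int]:
    """Hypothetical helper: build the metadata on every rank, write it on one."""
    running_states = {
        "epoch": epoch,
        "step": step,
        # Assumption: the resume index is derived from step and batch size.
        "sample_start_index": step * batch_size,
    }
    if is_master:  # stand-in for a master-process check on the coordinator
        save_json(running_states, os.path.join(save_dir, "running_states.json"))
    return running_states


# Round-trip the metadata through a temporary directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    write_running_states(tmp_dir, epoch=1, step=500, batch_size=32, is_master=True)
    restored = load_json(os.path.join(tmp_dir, "running_states.json"))
```

Guarding only the metadata write keeps several processes from writing the same small file, while the sharded model and optimizer saves are left to the Booster on every rank.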
Usage
Use these functions in ColossalChat SFT or RLHF training scripts to periodically save checkpoints and resume training from the last saved state.
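Resuming from the last saved state requires locating the most recent checkpoint directory first. The sketch below shows one way to do that. It assumes checkpoints live in subdirectories named epoch-E_step-S, a pattern inferred from the load path in the usage example; find_latest_checkpoint is a hypothetical helper and not part of ckpt_io.py.

```python
import os
import re
import tempfile
from typing import Optional

# Assumed directory naming, e.g. "epoch-1_step-500".
CKPT_PATTERN = re.compile(r"^epoch-(\d+)_step-(\d+)$")


def find_latest_checkpoint(base_dir: str) -> Optional[str]:
    """Hypothetical helper: return the newest checkpoint directory, or None."""
    candidates = []
    for name in os.listdir(base_dir):
        match = CKPT_PATTERN.match(name)
        if match and os.path.isdir(os.path.join(base_dir, name)):
            epoch, step = int(match.group(1)), int(match.group(2))
            candidates.append(((epoch, step), name))
    if not candidates:
        return None
    # Compare numerically so that step-1000 sorts after step-500.
    _, latest = max(candidates)
    return os.path.join(base_dir, latest)


# Demonstrate on a throwaway directory tree.
with tempfile.TemporaryDirectory() as base:
    for name in ("epoch-0_step-500", "epoch-1_step-500", "epoch-1_step-1000", "logs"):
        os.makedirs(os.path.join(base, name))
    latest_name = os.path.basename(find_latest_checkpoint(base))
```

The returned path could then be passed as load_dir to load_checkpoint. When no checkpoint exists the helper returns None, so a training script can fall back to starting from scratch.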
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/utils/ckpt_io.py
- Lines: 1-96
Signature
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:

def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:

def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:

def load_checkpoint(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
Import
from coati.utils.ckpt_io import save_checkpoint, load_checkpoint, load_json, save_json
I/O Contract
Inputs (save_checkpoint)
| Name | Type | Required | Description |
|---|---|---|---|
| save_dir | Union[str, os.PathLike] | Yes | Base directory for saving the checkpoint |
| booster | Booster | Yes | ColossalAI Booster instance for distributed save |
| model | torch.nn.Module | Yes | The model to save |
| optimizer | Optimizer | Yes | The optimizer state to save |
| lr_scheduler | _LRScheduler | Yes | The learning rate scheduler to save |
| epoch | int | Yes | Current epoch number |
| step | int | Yes | Current step number |
| batch_size | int | Yes | Batch size used to compute sample start index |
| coordinator | DistCoordinator | Yes | Distributed coordinator for master-only operations |
Outputs (load_checkpoint)
| Name | Type | Description |
|---|---|---|
| epoch | int | The epoch number from the checkpoint |
| step | int | The step number from the checkpoint |
| sample_start_index | int | The sample start index for resuming data loading |
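The returned sample_start_index lets a resumed run skip the samples it already consumed. The sketch below illustrates the idea in plain Python. It assumes the index equals step * batch_size, and resume_batches is a hypothetical helper; in a real training script the sampler or dataloader would typically perform the skipping.

```python
from itertools import islice
from typing import Iterable, Iterator, List


def resume_batches(
    samples: Iterable[int], sample_start_index: int, batch_size: int
) -> Iterator[List[int]]:
    """Hypothetical helper: skip consumed samples, then yield batches."""
    remaining = islice(samples, sample_start_index, None)
    while True:
        batch = list(islice(remaining, batch_size))
        if not batch:
            return
        yield batch


# Suppose 2 steps of batch size 3 were completed before the checkpoint.
step, batch_size = 2, 3
sample_start_index = step * batch_size  # assumed relationship
batches = list(resume_batches(range(10), sample_start_index, batch_size))
```

Here the first six samples are skipped and iteration continues from the seventh, so no sample from the interrupted epoch is processed twice.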
Usage Examples
from coati.utils.ckpt_io import save_checkpoint, load_checkpoint
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator

# booster (Booster), coordinator (DistCoordinator), model, optimizer, and
# lr_scheduler are assumed to have been created earlier in the training script.

# Save a checkpoint
save_checkpoint(
    save_dir="./checkpoints",
    booster=booster,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    epoch=1,
    step=500,
    batch_size=32,
    coordinator=coordinator,
)

# Load a checkpoint; returns the saved epoch, step, and sample start index
epoch, step, sample_start = load_checkpoint(
    load_dir="./checkpoints/epoch-1_step-500",
    booster=booster,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)