Implementation: Hpcaitech ColossalAI Save Checkpoint SFT
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Infrastructure |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool for saving distributed training checkpoints in ColossalChat's SFT/DPO training pipelines.
Description
The save_checkpoint() function coordinates checkpoint saving across all distributed ranks. It creates a directory structure with sharded model weights, optimizer states, LR scheduler state, and training metadata. The function uses ColossalAI's Booster to handle the complexity of saving model and optimizer states that may be partitioned across GPUs.
Usage
Called every save_interval steps during training and once more at the end of training. The coordinator ensures that only rank 0 creates directories and writes metadata, while all ranks participate in saving their shards of the model and optimizer state.
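The rank-0 bookkeeping half of this flow can be sketched with the standard library alone. This is a minimal stand-in, not the real implementation: the epoch-{N}_step-{M} directory naming and the running_states.json file follow the output contract documented below, but the helper name and the sample_start_index = step * batch_size formula are assumptions.

```python
import json
import os


def write_running_states(save_dir: str, epoch: int, step: int, batch_size: int) -> str:
    """Hypothetical stand-in for the metadata half of save_checkpoint().

    In the real pipeline, sharded model/optimizer state is saved by the
    Booster; only the directory layout and JSON metadata are shown here.
    """
    # Directory name follows the documented epoch-{N}_step-{M} scheme.
    ckpt_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
    os.makedirs(os.path.join(ckpt_dir, "modeling"), exist_ok=True)
    os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)

    states = {
        "epoch": epoch,
        "step": step,
        # Assumed formula: lets a resumed run skip already-consumed samples.
        "sample_start_index": step * batch_size,
    }
    with open(os.path.join(ckpt_dir, "running_states.json"), "w") as f:
        json.dump(states, f)
    return ckpt_dir
```

In a multi-GPU run, this block would execute only on rank 0 (guarded by the coordinator), since every rank sees the same shared filesystem path.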
Code Reference
Source Location
- Repository: ColossalAI
- File: applications/ColossalChat/coati/utils/ckpt_io.py
- Lines: 36-69
Signature
def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:
    """
    Save checkpoint including model, optimizer, lr_scheduler, and running states.

    Args:
        save_dir: Root checkpoint directory
        booster: ColossalAI Booster for distributed saving
        model: Model to save
        optimizer: Optimizer state to save
        lr_scheduler: LR scheduler state to save
        epoch: Current epoch number
        step: Current training step
        batch_size: Batch size (for computing sample_start_index)
        coordinator: Distributed coordinator for rank-aware operations
    """
Import
from coati.utils import save_checkpoint
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| save_dir | str \| os.PathLike | Yes | Root directory for checkpoint output |
| booster | Booster | Yes | ColossalAI Booster for distributed model saving |
| model | nn.Module | Yes | Model to checkpoint |
| optimizer | Optimizer | Yes | Optimizer state to checkpoint |
| lr_scheduler | _LRScheduler | Yes | LR scheduler state to checkpoint |
| epoch | int | Yes | Current epoch number |
| step | int | Yes | Current training step |
| batch_size | int | Yes | Batch size for sample index computation |
| coordinator | DistCoordinator | Yes | Rank-aware coordinator |
Outputs
| Name | Type | Description |
|---|---|---|
| Checkpoint directory | Files | epoch-{N}_step-{M}/ with modeling/, optimizer/, lr_scheduler, running_states.json |
Usage Examples
Save During Training
from coati.utils import save_checkpoint

# Called inside the training loop at save_interval
if step % save_interval == 0:
    save_checkpoint(
        save_dir="./checkpoints",
        booster=booster,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epoch=epoch,
        step=step,
        batch_size=batch_size,
        coordinator=coordinator,
    )
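When resuming, a run first has to locate the most recent checkpoint among the epoch-{N}_step-{M} directories that save_checkpoint produces. The helper below is a hypothetical sketch using only the standard library; the directory naming comes from the output contract above, while the function name and selection logic are assumptions. In the real pipeline it would be paired with the corresponding Booster load calls to restore the sharded model and optimizer state.

```python
import os
import re
from typing import Optional


def latest_checkpoint(save_dir: str) -> Optional[str]:
    """Hypothetical helper: pick the newest epoch-{N}_step-{M} directory.

    Orders candidates by (epoch, step) so that a later epoch always wins,
    and within an epoch the highest step wins.
    """
    pattern = re.compile(r"epoch-(\d+)_step-(\d+)$")
    best_name, best_key = None, (-1, -1)
    for name in os.listdir(save_dir):
        m = pattern.match(name)
        if m is None:
            continue  # ignore unrelated files/directories
        key = (int(m.group(1)), int(m.group(2)))
        if key > best_key:
            best_name, best_key = name, key
    return os.path.join(save_dir, best_name) if best_name else None
```

After locating the directory, the run would read running_states.json to recover epoch, step, and sample_start_index, then hand the modeling/ and optimizer/ subdirectories to the Booster's load routines.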