Implementation:Hpcaitech ColossalAI SFTTrainer
| Field | Value |
|---|---|
| Knowledge Sources | |
| Domains | NLP, Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A concrete tool for executing supervised fine-tuning (SFT) of language models, provided by ColossalChat.
Description
SFTTrainer extends ColossalChat's SLTrainer base class to implement the SFT training loop. It handles gradient accumulation, loss masking, pipeline-parallel execution, periodic evaluation, checkpoint saving, and logging to TensorBoard/WandB.
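The loss masking mentioned above is implemented in ColossalChat's data pipeline rather than in the trainer itself. As a rough conceptual sketch (the mask_prompt_tokens helper and prompt_len argument are hypothetical, not part of the library), masking non-response tokens means setting prompt positions to the cross-entropy ignore index so that only response tokens contribute to the loss:

import torch

IGNORE_INDEX = -100  # positions with this label are skipped by cross-entropy

def mask_prompt_tokens(input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
    # Labels start as a copy of the input ids (standard causal-LM setup).
    labels = input_ids.clone()
    # Mask everything belonging to the prompt so only the response is supervised.
    labels[:prompt_len] = IGNORE_INDEX
    return labels

# Example: a 6-token sequence whose first 4 tokens are the prompt.
ids = torch.tensor([101, 2054, 2003, 1996, 7592, 102])
print(mask_prompt_tokens(ids, prompt_len=4))  # tensor([-100, -100, -100, -100, 7592, 102])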
Usage
Use this after configuring the Booster, optimizer, and dataloaders. Create an SFTTrainer and call fit() to start training.
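A hedged sketch of that setup follows; the plugin, model path, optimizer, scheduler, and dataset names below are illustrative choices rather than the only supported ones, and the ColossalChat training scripts document the full recipe:

import colossalai
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam

# Initialize the distributed environment (argument names vary across ColossalAI versions).
colossalai.launch_from_torch()
coordinator = DistCoordinator()

# Pick a distributed plugin and wrap it in a Booster.
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # illustrative path
optimizer = HybridAdam(model.parameters(), lr=2e-5)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=1_000)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)  # train_dataset built elsewhere

# Boost the components so they run under the chosen plugin.
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(
    model, optimizer, dataloader=train_dataloader, lr_scheduler=lr_scheduler
)

With these pieces in place, the SFTTrainer construction and fit() call shown under Usage Examples below apply unchanged.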
Code Reference
Source Location
- Repository: ColossalAI
- File: applications/ColossalChat/coati/trainer/sft.py
- Lines: 25-249
Signature
class SFTTrainer(SLTrainer):
    def __init__(
        self,
        model,
        booster: Booster,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        max_epochs: int = 2,
        plugin: Plugin = None,
        accumulation_steps: int = 8,
        apply_loss_mask: bool = True,
        start_epoch: int = 0,
        save_interval: int = None,
        save_dir: str = None,
        coordinator: Optional[DistCoordinator] = None,
    ) -> None:
        """
        Args:
            model: Model to train (boosted)
            booster: ColossalAI Booster instance
            optim: Optimizer
            lr_scheduler: Learning rate scheduler
            max_epochs: Number of training epochs
            plugin: Distributed plugin
            accumulation_steps: Gradient accumulation steps
            apply_loss_mask: Mask non-response tokens in loss
            start_epoch: Epoch index to start (or resume) from
            save_interval: Save checkpoint every N steps
            save_dir: Checkpoint output directory
            coordinator: Distributed coordinator
        """

    def fit(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ) -> None:
        """Run the full training loop."""
Import
from coati.trainer import SFTTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | Boosted model from Booster.boost() |
| booster | Booster | Yes | ColossalAI Booster instance |
| optim | Optimizer | Yes | Wrapped optimizer |
| lr_scheduler | _LRScheduler | Yes | Learning rate scheduler |
| train_dataloader | DataLoader | Yes | Training data with input_ids, attention_mask, labels (see the sketch after this table) |
| eval_dataloader | DataLoader | No | Evaluation data |
| accumulation_steps | int | No | Gradient accumulation steps (default: 8) |
| apply_loss_mask | bool | No | Mask prompt tokens in loss (default: True) |
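As a rough illustration of the batch layout expected from train_dataloader, the following uses a generic PyTorch dataset, not ColossalChat's own SFT dataset or collator classes:

import torch
from torch.utils.data import Dataset, DataLoader

class ToySFTDataset(Dataset):
    """Yields dicts with the keys SFTTrainer consumes from each batch."""

    def __init__(self, tokenized_examples):
        # tokenized_examples: list of dicts with "input_ids" and "labels" lists
        self.examples = tokenized_examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        input_ids = torch.tensor(ex["input_ids"], dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": torch.tensor(ex["labels"], dtype=torch.long),  # -100 on prompt tokens
        }

# Fixed-length toy examples so the default collate_fn can stack them into batches.
examples = [{"input_ids": [1, 2, 3, 4], "labels": [-100, -100, 3, 4]} for _ in range(8)]
train_dataloader = DataLoader(ToySFTDataset(examples), batch_size=4, shuffle=True)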
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | nn.Module | Model with updated weights (in-memory) |
| Checkpoints | Files | Periodic model/optimizer/scheduler checkpoints |
| Logs | TensorBoard/WandB | Training loss, learning rate, evaluation metrics |
Usage Examples
Complete SFT Training
from coati.trainer import SFTTrainer

# model, booster, optimizer, lr_scheduler, plugin, dataloaders, and coordinator
# are prepared beforehand (see Usage above).
trainer = SFTTrainer(
    model=model,
    booster=booster,
    optim=optimizer,
    lr_scheduler=lr_scheduler,
    max_epochs=3,
    plugin=plugin,
    accumulation_steps=4,
    apply_loss_mask=True,
    save_interval=1000,
    save_dir="./checkpoints",
    coordinator=coordinator,
)

trainer.fit(
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    log_dir="./logs",
    use_wandb=True,
)
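fit() leaves the updated weights in memory; persisting the final model usually goes through the Booster checkpoint API. A minimal sketch, assuming the same booster, model, and optimizer objects as above (exact arguments depend on the plugin and ColossalAI version):

# Save the trained model weights; sharded saving is common for large models.
booster.save_model(model, "./checkpoints/final", shard=True)

# Optionally save the optimizer state so training can be resumed later.
booster.save_optimizer(optimizer, "./checkpoints/final_optim", shard=True)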
Related Pages
Implements Principle
Requires Environment