Implementation:Hpcaitech ColossalAI SFTTrainer

From Leeroopedia


Knowledge Sources

  • Domains: NLP, Training
  • Last Updated: 2026-02-09 00:00 GMT

Overview

A concrete tool for running supervised fine-tuning (SFT) of language models, provided by ColossalChat.

Description

SFTTrainer extends ColossalChat's SLTrainer base class to implement the SFT training loop. It handles gradient accumulation, loss masking, pipeline-parallel execution, periodic evaluation, checkpoint saving, and logging to TensorBoard/WandB.
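
Internally, one optimizer step is taken every accumulation_steps micro-batches, with each micro-batch loss scaled by 1/accumulation_steps. A minimal sketch of that pattern (illustrative, not the trainer's actual code; model, booster, optimizer, lr_scheduler, train_dataloader, and accumulation_steps are assumed to be set up already):

for step, batch in enumerate(train_dataloader, start=1):
    outputs = model(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
    )
    loss = outputs.loss / accumulation_steps  # scale so accumulated grads average out
    booster.backward(loss, optimizer)         # backward pass goes through the Booster
    if step % accumulation_steps == 0:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()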

Usage

Use this after configuring the Booster, optimizer, and dataloaders. Create an SFTTrainer and call fit() to start training.
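
A typical setup preceding trainer construction looks roughly like the sketch below. This is a hedged example: the plugin choice (LowLevelZeroPlugin), the model name, the optimizer settings, and train_dataset are placeholders, not values prescribed by ColossalChat.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

colossalai.launch_from_torch()  # expects torchrun-provided env vars
coordinator = DistCoordinator()

plugin = LowLevelZeroPlugin(stage=2)  # placeholder plugin choice
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model
optimizer = AdamW(model.parameters(), lr=2e-5)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=1000)
train_dataloader = DataLoader(train_dataset, batch_size=4)  # train_dataset assumed to exist

# The Booster wraps model/optimizer/dataloader/scheduler for distributed execution;
# pass the returned objects to SFTTrainer.
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(
    model=model,
    optimizer=optimizer,
    dataloader=train_dataloader,
    lr_scheduler=lr_scheduler,
)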

Code Reference

Source Location

  • Repository: ColossalAI
  • File: applications/ColossalChat/coati/trainer/sft.py
  • Lines: 25-249

Signature

class SFTTrainer(SLTrainer):
    def __init__(
        self,
        model,
        booster: Booster,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        max_epochs: int = 2,
        plugin: Plugin = None,
        accumulation_steps: int = 8,
        apply_loss_mask: bool = True,
        start_epoch: int = 0,
        save_interval: int = None,
        save_dir: str = None,
        coordinator: Optional[DistCoordinator] = None,
    ) -> None:
        """
        Args:
            model: Model to train (boosted)
            booster: ColossalAI Booster instance
            optim: Optimizer
            lr_scheduler: Learning rate scheduler
            max_epochs: Number of training epochs
            plugin: Distributed plugin
            accumulation_steps: Gradient accumulation steps
            apply_loss_mask: Mask non-response tokens in loss
            start_epoch: Epoch index to start from (for resuming)
            save_interval: Save checkpoint every N steps
            save_dir: Checkpoint output directory
            coordinator: Distributed coordinator
        """

    def fit(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ) -> None:
        """Run the full training loop."""

Import

from coati.trainer import SFTTrainer

I/O Contract

Inputs

Name                Type          Required  Description
model               nn.Module     Yes       Boosted model from Booster.boost()
booster             Booster       Yes       ColossalAI Booster instance
optim               Optimizer     Yes       Wrapped optimizer
lr_scheduler        _LRScheduler  Yes       Learning rate scheduler
train_dataloader    DataLoader    Yes       Training data with input_ids, attention_mask, labels
eval_dataloader     DataLoader    No        Evaluation data
accumulation_steps  int           No        Gradient accumulation steps (default: 8)
apply_loss_mask     bool          No        Mask prompt tokens in loss (default: True)
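
When apply_loss_mask is True, prompt (non-response) tokens are excluded from the loss. The standard way to express this for causal LMs is to set those label positions to the ignore index -100, which PyTorch's cross-entropy loss skips; a conceptual sketch (the exact masking logic lives in coati's dataset code and may differ in detail):

IGNORE_INDEX = -100  # positions with this label are skipped by CrossEntropyLoss

prompt_ids = [15496, 11, 995]    # hypothetical tokenized prompt
response_ids = [318, 257, 1332]  # hypothetical tokenized response

input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
# Loss is computed only over the three response positions.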

Outputs

Name           Type               Description
Trained model  nn.Module          Model with updated weights (in-memory)
Checkpoints    Files              Periodic model/optimizer/scheduler checkpoints
Logs           TensorBoard/WandB  Training loss, learning rate, evaluation metrics
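
Checkpoints written under save_dir can be restored through the standard Booster checkpoint API; a minimal sketch, assuming the subdirectory names below match what your run actually wrote:

# Paths are placeholders for the actual checkpoint layout under save_dir.
booster.load_model(model, "./checkpoints/model")
booster.load_optimizer(optimizer, "./checkpoints/optimizer")
booster.load_lr_scheduler(lr_scheduler, "./checkpoints/lr_scheduler")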

Usage Examples

Complete SFT Training

from coati.trainer import SFTTrainer

trainer = SFTTrainer(
    model=model,
    booster=booster,
    optim=optimizer,
    lr_scheduler=lr_scheduler,
    max_epochs=3,
    plugin=plugin,
    accumulation_steps=4,
    apply_loss_mask=True,
    save_interval=1000,
    save_dir="./checkpoints",
    coordinator=coordinator,
)

trainer.fit(
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    log_dir="./logs",
    use_wandb=True,
)
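
Resuming an Interrupted Run

After restoring model and optimizer state through the Booster (as sketched under Outputs), pass start_epoch so the epoch loop skips work already done. A hedged sketch; the epoch count here is arbitrary:

# Resume training at epoch 2 of 3 after reloading checkpointed state.
trainer = SFTTrainer(
    model=model,
    booster=booster,
    optim=optimizer,
    lr_scheduler=lr_scheduler,
    max_epochs=3,
    start_epoch=2,  # epochs 0 and 1 already completed
    accumulation_steps=4,
    save_interval=1000,
    save_dir="./checkpoints",
    coordinator=coordinator,
)
trainer.fit(train_dataloader=train_dataloader)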

Related Pages

  • Implements Principle
  • Requires Environment
