Implementation: hpcaitech ColossalAI Booster
| Knowledge Sources | Value |
|---|---|
| Domains | Distributed_Computing, Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
ColossalAI's concrete tool for orchestrating distributed training with pluggable parallelism strategies.
Description
The Booster class is ColossalAI's primary training orchestrator. It accepts a Plugin that defines the parallelism strategy, then wraps the model, optimizer, criterion, dataloader, and LR scheduler for distributed execution via boost(). All subsequent training operations (backward pass, pipeline execution, checkpointing) are performed through the Booster interface.
Usage
Use this after initializing the distributed environment and before the training loop. Create a Booster with the desired plugin, call boost() to wrap all training components, then use backward() or execute_pipeline() during training.
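A minimal sketch of that ordering, assuming the process is launched with torchrun (older ColossalAI versions require a config argument to launch_from_torch) and that model, optimizer, criterion, train_dataloader, and lr_scheduler are already defined:
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# 1. Initialize the distributed environment before creating the Booster
colossalai.launch_from_torch()

# 2. Choose a parallelism plugin and create the Booster
booster = Booster(plugin=TorchDDPPlugin())

# 3. Wrap all training components in one call, then enter the training loop
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    dataloader=train_dataloader,
    lr_scheduler=lr_scheduler,
)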
Code Reference
Source Location
- Repository: ColossalAI
- File: colossalai/booster/booster.py
- Lines: 33-434
Signature
class Booster:
    def __init__(
        self,
        device: Optional[str] = None,
        mixed_precision: Optional[Union[MixedPrecision, str]] = None,
        plugin: Optional[Plugin] = None,
    ) -> None:
        """
        Args:
            device: Device to run training (default: None -> 'cuda')
            mixed_precision: Mixed precision setting ('fp16', 'bf16', 'fp8')
            plugin: Plugin for distributed training strategy
        """

    def boost(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
        """Wrap model/optimizer/dataloader for distributed training."""

    def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
        """Execute backward pass with distributed gradient handling."""

    def execute_pipeline(
        self,
        data_iter: Iterator,
        model: nn.Module,
        criterion: Callable,
        optimizer: Optional[Optimizer] = None,
        return_loss: bool = True,
        return_outputs: bool = False,
    ) -> Dict[str, Any]:
        """Execute forward and backward with pipeline parallelism."""

    def save_model(
        self,
        model: Union[nn.Module, ModelWrapper],
        checkpoint: str,
        shard: bool = False,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
        use_async: bool = False,
    ) -> None:
        """Save model checkpoint (optionally sharded)."""
Import
from colossalai.booster import Booster
from colossalai.booster.plugin import (
    GeminiPlugin,
    HybridParallelPlugin,
    LowLevelZeroPlugin,
    TorchDDPPlugin,
)
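Each plugin maps to a different parallelism strategy. A hedged sketch of how each is typically constructed (the constructor arguments shown are common options, not an exhaustive list, and defaults may vary across ColossalAI versions):
plugin = TorchDDPPlugin()                              # plain PyTorch DDP (data parallel)
plugin = LowLevelZeroPlugin(stage=2)                   # ZeRO stage 1/2 optimizer + gradient sharding
plugin = GeminiPlugin(placement_policy="auto")         # ZeRO-3-style sharding with heterogeneous memory
plugin = HybridParallelPlugin(tp_size=2, pp_size=2)    # tensor + pipeline (+ data) parallelism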
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| plugin | Plugin | No | Parallelism strategy (DDP, ZeRO, Gemini, HybridParallel); optional in the constructor, but required for any distributed strategy |
| model | nn.Module | Yes | PyTorch model to wrap |
| optimizer | Optimizer | No | Optimizer to wrap |
| dataloader | DataLoader | No | DataLoader to wrap with distributed sampler |
| lr_scheduler | LRScheduler | No | Learning rate scheduler |
Outputs
| Name | Type | Description |
|---|---|---|
| boost() returns | List | Wrapped (model, optimizer, criterion, dataloader, lr_scheduler) |
| backward() | None | Executes distributed backward pass |
| execute_pipeline() | Dict | Returns {"loss": Tensor, "outputs": Any} for pipeline parallelism |
| save_model() | None | Saves checkpoint to disk |
Usage Examples
Standard Training with ZeRO-2
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Create plugin and booster
plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
booster = Booster(plugin=plugin)

# Wrap training components (pass criterion so the wrapped criterion is returned)
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    dataloader=train_dataloader,
    lr_scheduler=lr_scheduler,
)

# Training loop
for batch in dataloader:
    outputs = model(**batch)
    loss = criterion(outputs)
    booster.backward(loss, optimizer)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
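Checkpointing goes through the same Booster interface. A minimal sketch using the save_model() signature above (the path is illustrative, and size_per_shard is commonly interpreted as a per-file limit in MB):
# Save a sharded checkpoint after (or during) training
booster.save_model(
    model,
    checkpoint="./checkpoints/my_model",   # illustrative path
    shard=True,
    size_per_shard=1024,
    use_safetensors=True,
)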
3D Parallelism
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,          # Tensor parallel across 2 GPUs
    pp_size=2,          # Pipeline parallel across 2 stages
    sp_size=1,          # Sequence parallel disabled
    zero_stage=1,       # ZeRO stage 1 for data parallel
    precision="bf16",
    microbatch_size=4,
)
booster = Booster(plugin=plugin)
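With pp_size > 1, the forward and backward passes are driven by execute_pipeline() instead of calling the model and booster.backward() directly. A minimal sketch, assuming an HF-style model whose outputs carry a .loss attribute and a criterion of the form criterion(outputs, inputs); the loss entry is typically only populated on the last pipeline stage:
# Placeholder criterion for pipeline-style loss computation (assumes outputs.loss exists)
def criterion(outputs, inputs):
    return outputs.loss

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    dataloader=train_dataloader,
    lr_scheduler=lr_scheduler,
)

# Pipeline training loop: the Booster schedules micro-batches internally
data_iter = iter(dataloader)
for _ in range(len(dataloader)):
    result = booster.execute_pipeline(
        data_iter,
        model,
        criterion,
        optimizer=optimizer,
        return_loss=True,
    )
    loss = result["loss"]  # may be empty on stages other than the last
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()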