Principle: sail-sg LongSpec Distributed Training
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training, Optimization |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Principle for training large-scale models across multiple GPUs using DeepSpeed ZeRO optimization with gradient accumulation and mixed-precision computation.
Description
Distributed Training in the LongSpec pipeline uses DeepSpeed's ZeRO (Zero Redundancy Optimizer) to partition optimizer states, gradients, and optionally model parameters across GPUs. This enables training GLIDE draft models alongside frozen target LLMs that would not fit in single-GPU memory.
The training loop implements:
- DeepSpeed engine initialization wrapping the model with optimizer and learning rate scheduler
- Gradient accumulation over multiple micro-batches before each optimizer step
- Mixed-precision training (FP16/BF16) with dynamic loss scaling
- Multi-file dataset iteration loading different data files per epoch
- Periodic checkpointing with DeepSpeed state saving and draft weight extraction
- WandB logging for training metrics (loss, learning rate, throughput)
The pipeline supports ZeRO Stages 1-3 with configurable optimizer offloading to CPU memory.
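As an illustration, these options map onto a DeepSpeed configuration dictionary. A minimal sketch using standard DeepSpeed config keys, with illustrative values rather than the project's actual settings:

```python
# Sketch of a DeepSpeed config enabling ZeRO with CPU optimizer
# offload and BF16 mixed precision. Values are illustrative, not
# the LongSpec pipeline's actual configuration.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16,
    "zero_optimization": {
        "stage": 2,                 # ZeRO stage 1, 2, or 3
        "offload_optimizer": {      # optional: move optimizer states to CPU
            "device": "cpu",
        },
    },
    "bf16": {"enabled": True},      # or "fp16" with dynamic loss scaling
}
```

A dictionary like this (or its JSON-file equivalent) is what DeepSpeed's engine initialization consumes to decide how states are partitioned and which precision path the forward/backward passes take.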
Usage
Use this principle when training GLIDE draft models on multi-GPU setups. The training is launched via DeepSpeed's launcher:
deepspeed --num_gpus=8 trainer_base_ds_mul_fs_tp.py +exp=qwq_glide_8gpu_slim6b
ZeRO stage selection depends on model and memory constraints:
- ZeRO-1: Partitions optimizer states only (used for Stage 1 base training)
- ZeRO-3: Partitions optimizer states, gradients, and parameters (used for Stage 2/3 long-context training with 32k sequences)
Theoretical Basis
DeepSpeed ZeRO eliminates memory redundancy across data-parallel processes:
- Stage 1: Partitions optimizer states (e.g., Adam moments) — reduces memory by ~4x
- Stage 2: Also partitions gradients — reduces memory by ~8x
- Stage 3: Also partitions model parameters — enables training models larger than single-GPU memory
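These savings can be made concrete with the memory model from the ZeRO paper: with mixed precision and Adam, each parameter costs roughly 2 bytes (fp16 weights) + 2 bytes (fp16 gradients) + 12 bytes (fp32 master weights plus two Adam moments). A rough per-GPU sketch under each stage (the 6B-parameter / 8-GPU numbers below are illustrative, not taken from the LongSpec configs):

```python
def zero_memory_per_gpu(num_params, num_gpus, stage):
    """Approximate per-GPU model-state memory (bytes) under ZeRO,
    assuming fp16 weights/gradients and 12 bytes per parameter of
    fp32 optimizer state (master weights + two Adam moments)."""
    params = 2 * num_params        # fp16 parameters
    grads = 2 * num_params         # fp16 gradients
    optim = 12 * num_params        # fp32 master weights + Adam moments
    if stage >= 1:
        optim /= num_gpus          # Stage 1: partition optimizer states
    if stage >= 2:
        grads /= num_gpus          # Stage 2: also partition gradients
    if stage >= 3:
        params /= num_gpus         # Stage 3: also partition parameters
    return params + grads + optim

# Illustrative: a 6B-parameter model across 8 data-parallel GPUs.
gb = 1024 ** 3
baseline = zero_memory_per_gpu(6e9, 8, 0) / gb   # ~89.4 GB per GPU
stage3 = zero_memory_per_gpu(6e9, 8, 3) / gb     # ~11.2 GB per GPU
```

With many GPUs, Stage 1 approaches the ~4x reduction quoted above (16 bytes/param down toward 4), and Stage 3 divides the entire 16 bytes/param footprint by the GPU count.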
The training loop follows the standard distributed training pattern:
# Abstract training loop (not the actual implementation)
global_step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        loss = model(batch)                # Forward pass
        model.backward(loss)               # Backward pass (DeepSpeed handles loss scaling)
        model.step()                       # Optimizer step (DeepSpeed handles accumulation boundaries)
        global_step += 1
        if global_step % save_steps == 0:
            save_checkpoint(model)         # Periodic checkpoint
Gradient accumulation reduces communication overhead:
$$\text{effective\_batch} = \text{per\_gpu\_batch} \times \text{num\_gpus} \times \text{gradient\_accumulation\_steps}$$
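A quick numeric check of this relation (the batch sizes are illustrative, not LongSpec's actual settings):

```python
# Gradients are averaged over every micro-batch processed between two
# optimizer steps, across all data-parallel ranks, so the effective
# global batch is the product of all three factors.
per_gpu_batch = 1
num_gpus = 8
gradient_accumulation_steps = 16

effective_batch = per_gpu_batch * num_gpus * gradient_accumulation_steps
# -> 128 sequences contribute to each optimizer step, while the gradient
#    all-reduce (the dominant communication cost) runs once per optimizer
#    step rather than once per micro-batch.
```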