Heuristic:Sail sg LongSpec TF32 And FP16 Training
| Knowledge Sources | |
|---|---|
| Domains | Training, Optimization, Mixed_Precision |
| Last Updated | 2026-02-14 06:00 GMT |
Overview
Mixed precision training strategy that combines TF32 matmul acceleration, FP16 with dynamic loss scaling, and selective dtype casting (BF16 for LoRA layers, FP32 for normalization layers). It also uses gradient checkpointing with `use_cache=False`, which trades extra compute for lower VRAM usage.
Description
LongSpec training combines multiple precision optimization techniques:
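- TF32 matmul: Globally enabled via `torch.backends.cuda.matmul.allow_tf32 = True`, which lets FP32 matrix multiplications run on Tensor Cores in NVIDIA's TensorFloat-32 format on Ampere and newer GPUs. On an A100, peak TF32 Tensor Core throughput is about 8x the plain FP32 peak, with minimal accuracy loss for typical training workloads.
- FP16 mixed precision: DeepSpeed FP16 training with dynamic loss scaling (initial scale 2^16, growth window of 1000 steps, minimum scale 1). Scaling up the loss keeps small FP16 gradients from underflowing, and the scale is lowered when gradients overflow.
- BF16 for specific layers: LoRA adapter layers are cast to BF16 when the compute dtype is BF16, and `lm_head`/`embed_tokens` weights that are still FP32 are cast to BF16 when a compute dtype is set. Normalization layers are kept in FP32 for stability.
- Gradient checkpointing: When enabled, `use_cache` is set to `False` (the KV cache is not needed during training, and Hugging Face models treat it as incompatible with gradient checkpointing). Intermediate activations are then recomputed during the backward pass to reduce VRAM usage.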
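TF32 keeps FP32's 8-bit exponent but only 10 of its 23 mantissa bits. The stdlib sketch below emulates that precision loss on an FP32 bit pattern. The helper names are hypothetical, and round-to-nearest-even is an assumption: how cuBLAS and the Tensor Cores actually convert FP32 inputs is not modeled. It shows precision dropping to roughly three decimal digits (relative error at most 2^-11) while FP32's range is preserved, so a value like 1e30 stays finite.

```python
import struct

def fp32_bits(x: float) -> int:
    """IEEE-754 single-precision bit pattern of x (x is first rounded to FP32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def from_fp32_bits(bits: int) -> float:
    """Inverse of fp32_bits."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def round_to_tf32(x: float) -> float:
    """Emulate TF32 precision: 1 sign bit, 8 exponent bits, 10 mantissa bits.

    Clears the low 13 of FP32's 23 mantissa bits using round-to-nearest-even
    (an assumption; hardware conversion details are not modeled).
    Finite inputs only; NaN handling is not modeled.
    """
    bits = fp32_bits(x)
    dropped = 23 - 10
    lowest_kept = (bits >> dropped) & 1
    bits += (1 << (dropped - 1)) - 1 + lowest_kept  # round half to even
    bits &= ~((1 << dropped) - 1)                   # zero the dropped mantissa bits
    return from_fp32_bits(bits)

for value in (1.0 + 2**-11, 1.0 + 3 * 2**-11, 3.14159, 1e30):
    print(value, "->", round_to_tf32(value))
```

The two ties (`1 + 2^-11` and `1 + 3 * 2^-11`) round to the neighbour with an even last mantissa bit, giving `1.0` and `1 + 2^-9` respectively.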
Usage
Apply this heuristic when training GLIDE draft models to maximize throughput while maintaining training stability. The TF32 setting is hardcoded in the trainer; FP16 is configured via DeepSpeed YAML; gradient checkpointing is a model-level toggle.
The Insight (Rule of Thumb)
- Action 1: Set `torch.backends.cuda.matmul.allow_tf32 = True` globally before training.
  - Value: Up to ~8x higher peak matmul throughput than plain FP32 on Ampere/Hopper GPUs, with negligible accuracy loss in practice. Only matmuls that still run in FP32 are affected.
  - Trade-off: TF32 reduces mantissa bits from 23 (FP32) to 10 while keeping FP32's 8-bit exponent, so dynamic range is unchanged; the precision loss is generally acceptable for LLM training.
- Action 2: Use DeepSpeed FP16 with `initial_scale_power: 16` and `min_loss_scale: 1`.
  - Value: Halves the memory of activations and of the working (FP16) copies of weights and gradients during the forward/backward pass; the optimizer still keeps FP32 master weights.
  - Trade-off: Requires dynamic loss scaling to handle gradient underflow. `min_loss_scale: 1` puts a floor under the scale, so repeated overflows cannot shrink it below 1.
- Action 3: Keep normalization layers (`'norm' in name`) in FP32.
  - Value: Avoids overflow and precision loss in LayerNorm/RMSNorm statistics.
  - Trade-off: Slightly more memory for norm layer parameters, but these are small.
- Action 4: Set `gradient_checkpointing=True` and `use_cache=False` together.
  - Value: Reduces VRAM by keeping only checkpointed activations (the rest are recomputed) and by not building a KV cache during training.
  - Trade-off: Roughly 20-30% slower training due to recomputation of activations during the backward pass.
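The memory claim in Action 2 can be made concrete with the standard per-parameter accounting used in the ZeRO paper. The sketch below assumes an Adam-style optimizer with two FP32 moment buffers (the source does not name the optimizer), and the model and activation sizes are hypothetical. Under that assumption both modes need 16 bytes of model state per parameter, because FP16 mixed precision adds FP32 master weights; the halving shows up in the activations.

```python
GIB = 1024 ** 3

# Bytes per parameter of model state, assuming an Adam-style optimizer with
# two FP32 moment buffers (assumption: the source does not name the optimizer).
MODEL_STATE_BYTES = {
    "fp32": {"weights": 4, "grads": 4, "adam_m": 4, "adam_v": 4},
    "fp16_mixed": {"weights_fp16": 2, "grads_fp16": 2,
                   "master_weights_fp32": 4, "adam_m": 4, "adam_v": 4},
}
# Bytes per stored activation element.
ACTIVATION_BYTES = {"fp32": 4, "fp16_mixed": 2}

def estimate_gib(num_params: float, activation_elems: float, mode: str) -> dict:
    """Rough single-replica estimate; ignores ZeRO sharding, buffers, fragmentation."""
    return {
        "model_state": num_params * sum(MODEL_STATE_BYTES[mode].values()) / GIB,
        "activations": activation_elems * ACTIVATION_BYTES[mode] / GIB,
    }

# Hypothetical sizes, for illustration only.
for mode in ("fp32", "fp16_mixed"):
    print(mode, estimate_gib(num_params=1e9, activation_elems=5e9, mode=mode))
```

With these sizes, model state is about 14.9 GiB in both modes, while activations drop from about 18.6 GiB to about 9.3 GiB.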
Reasoning
Large model training (e.g., a QwQ-32B target plus a draft model) is constrained by GPU memory capacity. The combination of TF32, FP16, and gradient checkpointing makes training on 8x A100 GPUs feasible:
- TF32 lets matmuls that still run in FP32 use Tensor Cores without explicit dtype casting (a speed gain, not a memory saving).
- FP16 halves the storage for activations during the forward pass.
- Gradient checkpointing eliminates the need to store all intermediate activations.
- Keeping norms in FP32 helps avoid the "NaN loss" problem common in mixed-precision training.
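One concrete failure mode behind the FP32-norm rule: RMSNorm and LayerNorm compute a sum of squares, and FP16 cannot represent values above 65504. The stdlib sketch below emulates rounding every intermediate to FP16 or FP32 with `struct` (a simplification; real kernels may accumulate in higher precision), and the activation values are hypothetical. For activations of magnitude 300 the FP16 statistic overflows to `inf`; the normalized output then collapses to zero, and any downstream `inf - inf` or `0 * inf` yields NaN.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE half precision; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def to_fp32(x: float) -> float:
    """Round x to IEEE single precision."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def rms(xs, cast):
    """Root-mean-square with every intermediate rounded by `cast`."""
    acc = cast(0.0)
    for x in xs:
        sq = cast(cast(x) * cast(x))  # 300 * 300 = 90000 already exceeds FP16's max
        acc = cast(acc + sq)
    return cast(math.sqrt(acc / len(xs)))

hidden = [300.0] * 8  # hypothetical activation values
print(rms(hidden, to_fp32))  # 300.0
print(rms(hidden, to_fp16))  # inf
```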
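The DeepSpeed loss scaling configuration starts from a high scale (`initial_scale_power: 16`, i.e. 2^16 = 65536) and grows it cautiously: `loss_scale: 0` selects dynamic scaling, the scale is doubled only after 1000 consecutive overflow-free steps (`loss_scale_window`), and `hysteresis: 2` delays the first reduction by one overflow. This balances overflow protection (overflowing steps are skipped and the scale backs off) against training speed.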
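To make the interaction of these knobs concrete, here is a simplified pure-Python model of dynamic loss scaling. It is an illustration written for this page, not DeepSpeed's implementation; the class and method names are hypothetical, and it assumes that with `consecutive_hysteresis: False` the hysteresis budget is refilled only when the scale grows.

```python
class DynamicLossScaler:
    """Simplified dynamic loss scaler mirroring the fp16 YAML knobs.

    Illustrative model only, not DeepSpeed's code. Assumes
    consecutive_hysteresis: False, i.e. the hysteresis budget is
    refilled only when the scale grows.
    """

    def __init__(self, initial_scale_power=16, loss_scale_window=1000,
                 hysteresis=2, min_loss_scale=1.0):
        self.scale = float(2 ** initial_scale_power)
        self.window = loss_scale_window
        self.hysteresis = hysteresis
        self.min_scale = float(min_loss_scale)
        self._budget = hysteresis  # overflows tolerated before the next halving
        self._clean_steps = 0      # overflow-free steps since the last overflow

    def scale_loss(self, loss: float) -> float:
        """Multiply the loss before backward so small FP16 gradients do not underflow."""
        return loss * self.scale

    def unscale_grad(self, grad: float) -> float:
        """Divide gradients by the scale before the optimizer step."""
        return grad / self.scale

    def update(self, overflow: bool) -> bool:
        """Update the scale after a step; return True if the optimizer step should run."""
        if overflow:
            self._clean_steps = 0
            if self._budget > 1:
                self._budget -= 1  # hysteresis absorbs this overflow
            else:
                self.scale = max(self.scale / 2, self.min_scale)  # floor at min_loss_scale
            return False  # skip the step with inf/NaN gradients
        self._clean_steps += 1
        if self._clean_steps % self.window == 0:
            self.scale *= 2
            self._budget = self.hysteresis
        return True

scaler = DynamicLossScaler()
for overflow in (True, True, False):
    applied = scaler.update(overflow)
    print(f"overflow={overflow} applied={applied} scale={scaler.scale}")
```

With the YAML values, the first overflow is absorbed by hysteresis, the second halves the scale to 32768, and 1000 clean steps double it back to 65536.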
Code Evidence
TF32 enablement from `trainer_base_ds_mul_fs_tp.py:31`:
```python
torch.backends.cuda.matmul.allow_tf32 = True
```
Gradient checkpointing + use_cache pattern from `models/utils.py:89-92`:
```python
def enable_gradient_checkpointing(model: PreTrainedModel):
    model.config.use_cache = False
    model.gradient_checkpointing_enable()
    return model
```
Same pattern in `models/mixin.py:41-43`:
```python
if gradient_checkpointing:
    model.config.use_cache = False
    model.gradient_checkpointing_enable()
```
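The excerpts only toggle a flag; the mechanism it enables can be illustrated with a scalar toy. The stdlib sketch below (toy `tanh` layers, function names, and the segment size are all hypothetical) runs backprop either keeping every activation or keeping only segment-boundary checkpoints and recomputing each segment during the backward pass. It shows the trade described under Action 4: the same gradient, fewer stored activations, more forward calls. It does not model KV caches or how Hugging Face places checkpoints per decoder layer.

```python
import math

def layer(x: float, w: float) -> float:
    """Toy layer: y = tanh(w * x)."""
    return math.tanh(w * x)

def layer_backward(w: float, y: float, grad_y: float) -> float:
    """Gradient w.r.t. the layer input, using the saved output y = tanh(w * x)."""
    return grad_y * w * (1.0 - y * y)

def grad_full(x0: float, weights: list[float]):
    """Plain backprop: keep every activation. Returns (dy/dx0, peak_stored, forward_calls)."""
    acts = [x0]
    for w in weights:
        acts.append(layer(acts[-1], w))
    grad = 1.0
    for i in reversed(range(len(weights))):
        grad = layer_backward(weights[i], acts[i + 1], grad)
    return grad, len(acts), len(weights)

def grad_checkpointed(x0: float, weights: list[float], segment: int):
    """Keep only segment-boundary activations; recompute each segment during backward."""
    n = len(weights)
    checkpoints = {0: x0}
    x, forward_calls = x0, 0
    for i, w in enumerate(weights):
        x = layer(x, w)
        forward_calls += 1
        if (i + 1) % segment == 0 and i + 1 < n:
            checkpoints[i + 1] = x
    grad, peak, end = 1.0, len(checkpoints), n
    for start in sorted(checkpoints, reverse=True):
        acts = [checkpoints[start]]  # recompute this segment from its checkpoint
        for i in range(start, end):
            acts.append(layer(acts[-1], weights[i]))
            forward_calls += 1
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(start, end)):
            grad = layer_backward(weights[i], acts[i - start + 1], grad)
        end = start
    return grad, peak, forward_calls

weights = [0.9, 1.1, 0.8, 1.2] * 4  # 16 hypothetical layers
print(grad_full(0.5, weights)[1:])             # (17, 16)
print(grad_checkpointed(0.5, weights, 4)[1:])  # (8, 32)
```

Both functions return the same gradient; checkpointing cuts peak stored activations from 17 to 8 at the cost of a second forward pass over every segment.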
Layer-specific dtype casting from `models/utils.py:73-82`:
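```python
for name, module in model.named_modules():
    # nn.Module.to() casts in place and returns the same module,
    # so rebinding `module` below still changes the model.
    if isinstance(module, LoraLayer):
        if compute_dtype == torch.bfloat16:
            module = module.to(torch.bfloat16)
    # Normalization layers (matched by name) are kept in FP32.
    if 'norm' in name:
        module = module.to(torch.float32)
    if 'lm_head' in name or 'embed_tokens' in name:
        if hasattr(module, 'weight'):
            if compute_dtype and module.weight.dtype == torch.float32:
                module = module.to(torch.bfloat16)
```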
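Because the three rules run in sequence for every named module, a later matching rule overrides an earlier one. The torch-free sketch below restates the excerpt's rules as a pure function so the outcome for a given module name is easy to check. `planned_dtype` is a hypothetical helper, dtypes are plain strings, and every module is assumed to have a `weight`.

```python
def planned_dtype(name: str, weight_dtype: str, is_lora: bool, compute_dtype):
    """Dtype a module ends up in under the excerpted loop's rules.

    Hypothetical helper for illustration: dtypes are strings such as
    "float32" or "bfloat16", `weight_dtype` is the module's current weight
    dtype, and every module is assumed to have a weight. Rules run in the
    loop's order, so a later matching rule overrides an earlier one.
    """
    dtype = weight_dtype
    if is_lora and compute_dtype == "bfloat16":
        dtype = "bfloat16"  # rule 1: LoRA layers follow BF16 compute
    if "norm" in name:
        dtype = "float32"   # rule 2: any name containing "norm"
    if ("lm_head" in name or "embed_tokens" in name) and compute_dtype and dtype == "float32":
        dtype = "bfloat16"  # rule 3: FP32 head/embeddings -> BF16
    return dtype

examples = [
    ("model.layers.0.input_layernorm", "float16", False, "bfloat16"),
    ("model.layers.0.self_attn.q_proj", "float32", True, "bfloat16"),
    ("lm_head", "float32", False, "float16"),
]
for args in examples:
    print(args[0], "->", planned_dtype(*args))
```

As excerpted, rule 3 fires for any truthy `compute_dtype`, so an FP32 `lm_head` is cast to BF16 even when the compute dtype is FP16; the source does not say whether that combination occurs in practice.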
DeepSpeed FP16 config from `qwq_glide_8gpu_slim6b.yaml:166-174`:
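```yaml
fp16:
  enabled: True
  auto_cast: False
  loss_scale: 0              # 0 selects dynamic loss scaling
  initial_scale_power: 16    # initial scale = 2**16 = 65536
  loss_scale_window: 1000
  hysteresis: 2
  consecutive_hysteresis: False
  min_loss_scale: 1          # the scale never drops below 1
```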