Heuristic:Sail sg LongSpec TF32 And FP16 Training

From Leeroopedia
Knowledge Sources
Domains Training, Optimization, Mixed_Precision
Last Updated 2026-02-14 06:00 GMT

Overview

Mixed precision training strategy using TF32 matmul acceleration, FP16 with dynamic loss scaling, BF16 for LoRA/norm layers, and gradient checkpointing with `use_cache=False` to trade compute for VRAM.

Description

LongSpec training combines multiple precision optimization techniques:

  • TF32 matmul: Globally enabled via `torch.backends.cuda.matmul.allow_tf32 = True`, which lets FP32 matrix multiplications run on Tensor Cores in NVIDIA's TensorFloat-32 format on Ampere and newer GPUs. This gives up to 8x peak throughput over FP32 with minimal accuracy loss.
  • FP16 mixed precision: DeepSpeed FP16 training with dynamic loss scaling (initial scale 2^16, growth window of 1000 steps, hysteresis 2, minimum scale 1) to avoid gradient underflow and recover from overflow.
  • BF16 for specific layers: LoRA adapter layers are cast to BF16 when the compute dtype is BF16, normalization layers are kept in FP32 for stability, and FP32 `lm_head`/`embed_tokens` weights are downcast to BF16 when a compute dtype is set.
  • Gradient checkpointing: When enabled, `use_cache` is set to `False` and intermediate activations are recomputed during the backward pass instead of being stored, reducing VRAM usage.

Usage

Apply this heuristic when training GLIDE draft models to maximize throughput while maintaining training stability. The TF32 setting is hardcoded in the trainer; FP16 is configured via DeepSpeed YAML; gradient checkpointing is a model-level toggle.

The Insight (Rule of Thumb)

  • Action 1: Set `torch.backends.cuda.matmul.allow_tf32 = True` globally before training.
  • Value: Up to 8x faster matmul on Ampere/Hopper GPUs with negligible accuracy loss.
  • Trade-off: TF32 keeps FP32's 8-bit exponent (same dynamic range) but reduces the mantissa from 23 bits to 10; this precision loss is acceptable for LLM training.
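  • Action 2: Use DeepSpeed FP16 with dynamic loss scaling (`loss_scale: 0`), `initial_scale_power: 16`, and `min_loss_scale: 1`.
  • Value: Halves the memory of the FP16 weights, gradients, and activations used in the forward/backward pass (DeepSpeed's FP16 optimizer still keeps FP32 master weights for the update).
  • Trade-off: Requires dynamic loss scaling to avoid gradient underflow. `min_loss_scale: 1` sets a floor so that repeated overflows cannot shrink the scale below 1.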
  • Action 3: Keep normalization layers (`'norm' in name`) in FP32.
  • Value: Prevents numerical instability in LayerNorm/RMSNorm computations.
  • Trade-off: Slightly more memory for norm layer parameters, but these are small.
  • Action 4: Set `gradient_checkpointing=True` and `use_cache=False` together.
  • Value: Reduces VRAM because only checkpointed activations are stored for the backward pass, and `use_cache=False` stops the model from building a KV cache that training does not use.
  • Trade-off: ~20-30% slower training due to recomputation of activations during backward pass.
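To make the mantissa trade-off concrete, the sketch below emulates TF32 rounding in pure Python: it takes an FP32 bit pattern and keeps only the top 10 of its 23 mantissa bits. It assumes round-to-nearest-even on the 13 dropped bits, which may not match how a particular GPU converts FP32 inputs for Tensor Cores, and it does not handle NaN/Inf; `to_tf32` and `as_fp32` are hypothetical helpers, not PyTorch APIs. In a real TF32 matmul the products are still accumulated in FP32, so only the inputs lose precision.

```python
import struct

def as_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_tf32(x: float) -> float:
    """Emulate TF32 by keeping 10 of FP32's 23 mantissa bits.

    Assumption: round-to-nearest-even on the 13 dropped bits.
    NaN/Inf inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # raw FP32 bit pattern
    dropped = 23 - 10                                    # 13 low mantissa bits discarded
    lsb = (bits >> dropped) & 1                          # lowest kept bit, for ties-to-even
    bits += (1 << (dropped - 1)) - 1 + lsb               # round half to even
    bits &= ~((1 << dropped) - 1) & 0xFFFFFFFF           # zero the dropped bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Compare FP32 and emulated TF32 for a few values
for v in (1.0, 3.14159265, 1 / 3, 0.1, 12345.678):
    f32, t32 = as_fp32(v), to_tf32(v)
    rel = abs(t32 - f32) / abs(f32)
    print(f"{v:<12.8g} fp32={f32:.10g} tf32={t32:.10g} rel_err={rel:.2e}")
```

For normal numbers the relative rounding error is at most 2^-11 (about 4.9e-4), compared with 2^-24 for FP32 itself, while the exponent range is unchanged.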

Reasoning

Large model training (e.g., a QwQ-32B target plus a draft model) is limited by GPU memory. The combination of TF32, FP16, and gradient checkpointing allows training on 8x A100 GPUs because:

  • TF32 lets matmuls that still run in FP32 use Tensor Cores without explicit dtype casting.
  • FP16 halves the storage for activations relative to FP32.
  • Gradient checkpointing removes the need to store every intermediate activation.
  • Keeping norms in FP32 helps avoid the NaN losses common in mixed-precision training.
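The compute-for-memory trade behind gradient checkpointing can be seen on a toy scalar network. The sketch below chains tanh "layers", backpropagates d(output)/d(input), and compares storing every activation against keeping only every k-th one and recomputing each segment during the backward pass. Everything here (`layer`, `grad_full`, `grad_checkpointed`, the segment size `k`) is hypothetical; it is not how `torch.utils.checkpoint` or Hugging Face's `gradient_checkpointing_enable()` work internally and only illustrates the principle.

```python
import math

FORWARD_CALLS = 0  # counts evaluations of `layer`, i.e. forward compute

def layer(w: float, x: float) -> float:
    """One toy layer: y = tanh(w * x)."""
    global FORWARD_CALLS
    FORWARD_CALLS += 1
    return math.tanh(w * x)

def backprop_segment(ws, acts, grad):
    """Chain rule through layers ws; acts[i] is layer i's input, acts[i + 1] its output."""
    for i in reversed(range(len(ws))):
        y = acts[i + 1]
        grad *= ws[i] * (1.0 - y * y)  # d/dx tanh(w*x) = w * (1 - tanh(w*x)**2)
    return grad

def grad_full(ws, x):
    """No checkpointing: keep every activation. Returns (grad, peak values stored)."""
    acts = [x]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    return backprop_segment(ws, acts, 1.0), len(acts)

def grad_checkpointed(ws, x, k):
    """Keep only every k-th activation and recompute each segment in the backward pass."""
    checkpoints, h = [x], x
    for i, w in enumerate(ws):
        h = layer(w, h)
        if (i + 1) % k == 0 and i + 1 < len(ws):
            checkpoints.append(h)  # all other activations are dropped
    grad, peak = 1.0, len(checkpoints)
    for s in reversed(range(len(checkpoints))):
        seg = ws[s * k:(s + 1) * k]
        acts = [checkpoints[s]]
        for w in seg:  # recomputation: the extra forward work
            acts.append(layer(w, acts[-1]))
        peak = max(peak, len(checkpoints) + len(seg))
        grad = backprop_segment(seg, acts, grad)
    return grad, peak

ws = [0.9 + 0.01 * i for i in range(16)]
FORWARD_CALLS = 0
g_full, peak_full = grad_full(ws, 0.5)
calls_full = FORWARD_CALLS
FORWARD_CALLS = 0
g_ckpt, peak_ckpt = grad_checkpointed(ws, 0.5, k=4)
calls_ckpt = FORWARD_CALLS
print(f"full:         grad={g_full:.6e} stored={peak_full} forward_calls={calls_full}")
print(f"checkpointed: grad={g_ckpt:.6e} stored={peak_ckpt} forward_calls={calls_ckpt}")
```

With 16 layers and k = 4, the checkpointed run stores at most 8 values instead of 17 but calls `layer` 32 times instead of 16, and both runs produce the same gradient. In real networks the backward pass costs more than the forward pass, so one extra forward pass translates into a smaller relative slowdown than the doubled forward count seen here.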

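The dynamic loss scale starts high (`initial_scale_power: 16`, i.e. 2^16 = 65536) and is halved on overflow, with `hysteresis: 2` delaying the first reduction by one overflow; it is doubled again only after a window of 1000 overflow-free steps (`loss_scale_window: 1000`). This balances overflow protection against training speed: a large scale keeps small gradients from underflowing in FP16, while the long window stops the scaler from repeatedly growing into overflows that would skip optimizer steps.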
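The interaction of `initial_scale_power`, `loss_scale_window`, `hysteresis`, and `min_loss_scale` can be traced with a small simulation. The class below is a simplified, hypothetical model of DeepSpeed-style dynamic loss scaling (`loss_scale: 0` is what selects dynamic scaling in DeepSpeed); it follows our reading of DeepSpeed's update rule, and the exact behaviour may differ between DeepSpeed versions. `DynamicLossScaler` and `has_overflow` are illustrative names, and in real training the overflow flag comes from checking the scaled FP16 gradients for Inf/NaN.

```python
import math

class DynamicLossScaler:
    """Simplified model of DeepSpeed-style dynamic loss scaling (hypothetical sketch)."""

    def __init__(self, initial_scale_power=16, scale_window=1000, hysteresis=2,
                 consecutive_hysteresis=False, min_scale=1.0, scale_factor=2.0):
        self.cur_scale = 2.0 ** initial_scale_power  # 2**16 = 65536
        self.scale_window = scale_window
        self.hysteresis = hysteresis
        self.cur_hysteresis = hysteresis
        self.consecutive_hysteresis = consecutive_hysteresis
        self.min_scale = min_scale
        self.scale_factor = scale_factor
        self.cur_iter = 0
        self.last_overflow_iter = -1

    def update(self, overflow: bool) -> None:
        """Adjust the scale after one step (an overflowing step's update would be skipped)."""
        if overflow:
            if self.hysteresis == 1 or self.cur_hysteresis == 1:
                # Shrink the scale, but never below min_scale (min_loss_scale: 1)
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
            else:
                # Hysteresis: tolerate this overflow without shrinking the scale
                self.cur_hysteresis -= 1
            self.last_overflow_iter = self.cur_iter
        else:
            if self.consecutive_hysteresis:
                self.cur_hysteresis = self.hysteresis
            # Grow every scale_window overflow-free steps since the last overflow
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                if not self.consecutive_hysteresis:
                    self.cur_hysteresis = self.hysteresis
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

def has_overflow(grads):
    """True if any (scaled) gradient is Inf or NaN."""
    return not all(math.isfinite(g) for g in grads)

# Trace: two overflowing steps, then 1000 clean steps
scaler = DynamicLossScaler()
for step_grads in ([float("inf")], [float("nan")]):
    scaler.update(has_overflow(step_grads))
    print("after overflow:", scaler.cur_scale)
for _ in range(1000):
    scaler.update(has_overflow([0.1]))
print("after 1000 clean steps:", scaler.cur_scale)
```

In this trace the scale stays at 65536 after the first overflow (absorbed by the hysteresis), halves to 32768 on the second, and doubles back to 65536 once 1000 overflow-free steps have passed.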

Code Evidence

TF32 enablement from `trainer_base_ds_mul_fs_tp.py:31`:

import torch

# Let FP32 matmuls run on Tensor Cores in TF32 (Ampere and newer GPUs)
torch.backends.cuda.matmul.allow_tf32 = True

Gradient checkpointing + use_cache pattern from `models/utils.py:89-92`:

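from transformers import PreTrainedModel

def enable_gradient_checkpointing(model: PreTrainedModel):
    # use_cache=True is incompatible with gradient checkpointing during training
    model.config.use_cache = False
    # Recompute activations in the backward pass instead of storing them
    model.gradient_checkpointing_enable()
    return model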

Same pattern in `models/mixin.py:41-43`:

if gradient_checkpointing:
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

Layer-specific dtype casting from `models/utils.py:73-82`:

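import torch
from peft.tuners.lora import LoraLayer

for name, module in model.named_modules():
    # LoRA adapters follow the compute dtype when it is BF16
    if isinstance(module, LoraLayer):
        if compute_dtype == torch.bfloat16:
            module = module.to(torch.bfloat16)
    # Normalization layers stay in FP32 for numerical stability
    if 'norm' in name:
        module = module.to(torch.float32)
    # Output head and embeddings: FP32 weights are downcast to BF16.
    # Note: any non-None compute_dtype is truthy, so this also fires for FP16.
    if 'lm_head' in name or 'embed_tokens' in name:
        if hasattr(module, 'weight'):
            if compute_dtype and module.weight.dtype == torch.float32:
                module = module.to(torch.bfloat16)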
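Because `nn.Module.to(...)` casts a module in place, the checks in the loop above act as a sequence of overrides. The pure-Python sketch below replays that order on plain dtype strings so it runs without PyTorch or PEFT; `final_dtype` is a hypothetical helper, and it ignores the fact that casting a parent module also casts its children. It makes the effect of the `compute_dtype and ...` check visible: an FP32 `lm_head` is downcast to BF16 even when the compute dtype is FP16.

```python
from typing import Optional

def final_dtype(name: str, is_lora: bool, weight_dtype: str,
                compute_dtype: Optional[str]) -> str:
    """Replay the casting rules in order; a later match overrides an earlier one."""
    dtype = weight_dtype
    if is_lora and compute_dtype == "bfloat16":      # LoRA follows BF16 compute
        dtype = "bfloat16"
    if "norm" in name:                                # norms pinned to FP32
        dtype = "float32"
    if ("lm_head" in name or "embed_tokens" in name) and compute_dtype and dtype == "float32":
        dtype = "bfloat16"                            # any truthy compute_dtype triggers this
    return dtype

for args in [
    ("model.layers.0.input_layernorm", False, "float16", "bfloat16"),
    ("model.layers.0.self_attn.q_proj", True, "float32", "bfloat16"),
    ("model.layers.0.self_attn.q_proj", True, "float32", "float16"),
    ("lm_head", False, "float32", "float16"),
    ("model.embed_tokens", False, "float32", None),
]:
    print(args[0], args[3], "->", final_dtype(*args))
```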

DeepSpeed FP16 config from `qwq_glide_8gpu_slim6b.yaml:166-174`:

fp16:
  enabled: True
  auto_cast: False
  loss_scale: 0
  initial_scale_power: 16
  loss_scale_window: 1000
  hysteresis: 2
  consecutive_hysteresis: False
  min_loss_scale: 1
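DeepSpeed accepts the same `fp16` keys from a JSON file or a Python dict, for example via `deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)`. The sketch below writes the YAML block above as such a dict and derives the starting scale from `initial_scale_power`. `ds_config` is a hypothetical name, other required sections (such as batch size) are omitted, and how the LongSpec trainer converts its YAML into a DeepSpeed config is not shown here.

```python
import json

# fp16 block from the YAML above as a DeepSpeed config dict (other sections omitted)
ds_config = {
    "fp16": {
        "enabled": True,
        "auto_cast": False,             # inputs are not auto-cast to FP16
        "loss_scale": 0,                # 0 selects dynamic loss scaling
        "initial_scale_power": 16,      # starting scale = 2**16
        "loss_scale_window": 1000,      # overflow-free steps before the scale grows
        "hysteresis": 2,                # overflows tolerated before the scale first shrinks
        "consecutive_hysteresis": False,
        "min_loss_scale": 1,            # floor for the dynamic scale
    }
}

initial_scale = 2 ** ds_config["fp16"]["initial_scale_power"]
print("initial loss scale:", initial_scale)
print(json.dumps(ds_config, indent=2))
```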
