Heuristic: FSDP Mixed Precision with BFloat16 (Eric Mitchell's Direct Preference Optimization codebase)
| Knowledge Sources | Details |
|---|---|
| Domains | Optimization, Distributed_Training |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Enable FSDP mixed precision with bfloat16 to achieve approximately 50% training speedup and reduced VRAM usage on Ampere+ GPUs.
Description
PyTorch FSDP supports mixed precision training through the `MixedPrecision` policy, which controls the dtypes used for parameters, gradient reductions, and buffers. When configured with `bfloat16`, parameters are cast to bfloat16 for the forward and backward passes (and gradients are reduced in bfloat16), while the sharded master parameters and optimizer state remain in float32. The DPO codebase implements this in the FSDPTrainer class, controlled by the `model.fsdp_policy_mp` config parameter.
Usage
Use this heuristic when training with FSDPTrainer and you want to reduce VRAM usage and increase training throughput. The README reports approximately 50% speedup with bfloat16 mixed precision. Enable by passing `model.fsdp_policy_mp=bfloat16` on the command line. Only supported for FSDPTrainer; BasicTrainer and TensorParallelTrainer do not implement mixed precision.
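A sketch of what such an invocation might look like, following the repo's Hydra-style overrides. The model name, dataset, and loss settings below are illustrative placeholders; only the last two overrides are the ones this heuristic is about:

```shell
# Illustrative only -- model/dataset/loss choices are placeholders.
# The relevant additions are trainer=FSDPTrainer and model.fsdp_policy_mp=bfloat16.
python -u train.py \
    model=pythia69 \
    datasets=[hh] \
    loss=dpo \
    trainer=FSDPTrainer \
    model.fsdp_policy_mp=bfloat16
```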
The Insight (Rule of Thumb)
- Action: Pass `model.fsdp_policy_mp=bfloat16` when running with FSDPTrainer.
- Value: `bfloat16` (recommended) or `float16` (less tested).
- Trade-off: ~50% speedup and ~50% memory reduction for parameters in exchange for reduced numerical precision during forward/backward passes. Optimizer states remain in FP32.
- Compatibility: Requires Ampere+ GPUs for native bfloat16 support. Only works with FSDPTrainer.
Reasoning
The README explicitly states: "We'll take advantage of FSDP's mixed precision in bfloat16 to speed up training; we usually see about a 50% speedup." BFloat16 is preferred over float16 because it maintains the same dynamic range as float32 (8 exponent bits), avoiding the need for loss scaling.
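The dynamic-range claim can be verified with a quick stdlib calculation (no torch required). For an IEEE-style binary float, the largest finite value is determined by the exponent and mantissa bit counts; bfloat16 shares float32's 8 exponent bits, so its range matches float32's, while float16's 5 exponent bits cap it at 65504:

```python
def max_finite(exp_bits: int, mantissa_bits: int) -> float:
    """Largest finite value for an IEEE-style float:
    (2 - 2**-mantissa_bits) * 2**bias, where bias = 2**(exp_bits-1) - 1."""
    bias = 2 ** (exp_bits - 1) - 1
    return (2 - 2 ** -mantissa_bits) * 2.0 ** bias

fp32_max = max_finite(8, 23)   # ~3.40e38
bf16_max = max_finite(8, 7)    # ~3.39e38 -- same range as float32
fp16_max = max_finite(5, 10)   # 65504.0  -- overflows easily without loss scaling

print(f"fp32={fp32_max:.3e}  bf16={bf16_max:.3e}  fp16={fp16_max}")
```

Because bfloat16 activations and gradients cannot overflow at magnitudes float32 can represent, no loss-scaling machinery is needed, which is why the codebase recommends it over float16.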
Code evidence from `trainers.py:457-459`:
```python
mp_dtype = getattr(torch, config.model.fsdp_policy_mp) if config.model.fsdp_policy_mp is not None else None
policy_mp_policy = MixedPrecision(param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype)
self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
```
Note that all three MixedPrecision fields (`param_dtype`, `reduce_dtype`, `buffer_dtype`) are set to the same dtype. The logits are cast back to float32 for loss computation in `trainers.py:216`:
```python
all_logits = model(concatenated_batch['concatenated_input_ids'], attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32)
```