Heuristic:Eric Mitchell Direct Preference Optimization FSDP Mixed Precision BFloat16

From Leeroopedia




Knowledge Sources
Domains Optimization, Distributed_Training
Last Updated 2026-02-08 02:00 GMT

Overview

Enable FSDP mixed precision with bfloat16 to achieve roughly a 50% training speedup and lower VRAM usage on Ampere or newer GPUs.

Description

PyTorch FSDP supports mixed precision training through the `MixedPrecision` policy, which controls the dtype for parameters, gradient reductions, and buffers. When configured with `bfloat16`, the model parameters are cast to bfloat16 during forward and backward passes while gradient accumulation and optimizer state remain in full precision. The DPO codebase implements this in the FSDPTrainer class, controlled by the `model.fsdp_policy_mp` config parameter.
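A minimal sketch of how such a policy is assembled, mirroring the config-string pattern described above (the `fsdp_policy_mp` variable is a stand-in for the Hydra config value; actually wrapping a model with FSDP additionally requires an initialized process group, which is omitted here):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Resolve the dtype from a config string, as the trainer does.
fsdp_policy_mp = "bfloat16"  # stand-in for config.model.fsdp_policy_mp
mp_dtype = getattr(torch, fsdp_policy_mp) if fsdp_policy_mp is not None else None

# All three fields use the same dtype: parameters are cast for the
# forward/backward passes, gradient reduce-scatter runs in this dtype,
# and buffers are kept in it as well.
mp_policy = MixedPrecision(
    param_dtype=mp_dtype,
    reduce_dtype=mp_dtype,
    buffer_dtype=mp_dtype,
)

print(mp_policy.param_dtype)  # torch.bfloat16
```

The resulting `MixedPrecision` object is what gets passed as the `mixed_precision` argument when wrapping the model with `FSDP(...)`.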

Usage

Use this heuristic when training with FSDPTrainer to reduce VRAM usage and increase training throughput. The README reports roughly a 50% speedup with bfloat16 mixed precision. Enable it by passing `model.fsdp_policy_mp=bfloat16` on the command line. This option is only supported by FSDPTrainer; BasicTrainer and TensorParallelTrainer do not implement mixed precision.
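An illustrative launch command: only `trainer=FSDPTrainer` and `model.fsdp_policy_mp=bfloat16` come from this page; the other arguments are placeholders, so consult the repository README for the exact Hydra overrides your run needs.

```shell
# Hypothetical training launch -- model/dataset/loss overrides are
# illustrative placeholders, not taken from this page.
python -u train.py \
    model=pythia69 \
    loss=dpo \
    trainer=FSDPTrainer \
    model.fsdp_policy_mp=bfloat16
```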

The Insight (Rule of Thumb)

  • Action: Pass `model.fsdp_policy_mp=bfloat16` when running with FSDPTrainer.
  • Value: `bfloat16` (recommended) or `float16` (less tested).
  • Trade-off: ~50% speedup and ~50% memory reduction for parameters in exchange for reduced numerical precision during forward/backward passes. Optimizer states remain in FP32.
  • Compatibility: Requires Ampere+ GPUs for native bfloat16 support. Only works with FSDPTrainer.
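The compatibility constraint above can be guarded in code. A sketch, assuming the real PyTorch helper `torch.cuda.is_bf16_supported()` (the `resolve_mp_dtype` function itself is hypothetical, not part of the DPO codebase):

```python
from typing import Optional

import torch


def resolve_mp_dtype(requested: Optional[str]) -> Optional[torch.dtype]:
    """Pick a mixed-precision dtype, falling back to None (full FP32)
    when bfloat16 is requested but no bf16-capable GPU is present."""
    if requested is None:
        return None
    if requested == "bfloat16":
        # Short-circuit so is_bf16_supported() is only queried when CUDA exists.
        if not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
            return None
    return getattr(torch, requested)


# On an Ampere+ GPU this returns torch.bfloat16; elsewhere it returns None,
# which disables mixed precision rather than crashing mid-run.
print(resolve_mp_dtype("bfloat16"))
```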

Reasoning

The README explicitly states: "We'll take advantage of FSDP's mixed precision in bfloat16 to speed up training; we usually see about a 50% speedup." BFloat16 is preferred over float16 because it maintains the same dynamic range as float32 (8 exponent bits), avoiding the need for loss scaling.
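The dynamic-range argument can be checked directly with `torch.finfo`. A small sketch:

```python
import torch

# float16 has 5 exponent bits -> small dynamic range; bfloat16 keeps
# float32's 8 exponent bits, trading mantissa precision instead.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38, same order as float32
print(torch.finfo(torch.float32).max)   # ~3.40e38

# A value that is unremarkable during training already overflows float16,
# which is why float16 training needs loss scaling and bfloat16 does not.
x = torch.tensor(70000.0)
print(x.to(torch.float16))   # inf
print(x.to(torch.bfloat16))  # finite
```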

Code evidence from `trainers.py:457-459`:

mp_dtype = getattr(torch, config.model.fsdp_policy_mp) if config.model.fsdp_policy_mp is not None else None
policy_mp_policy = MixedPrecision(param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype)
self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)

Note that all three MixedPrecision fields (`param_dtype`, `reduce_dtype`, `buffer_dtype`) are set to the same dtype. The logits are cast back to float32 for loss computation in `trainers.py:216`:

all_logits = model(concatenated_batch['concatenated_input_ids'], attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32)
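A minimal sketch of why that final cast matters: under bfloat16 mixed precision the model emits bfloat16 logits, and casting to float32 before `log_softmax` keeps the loss math in full precision (the tensor shapes here are arbitrary, for illustration only).

```python
import torch

# Simulated model output under bfloat16 mixed precision.
logits = torch.randn(2, 8, dtype=torch.bfloat16)

# Cast to float32 before computing log-probabilities, as the trainer does.
logprobs = logits.to(torch.float32).log_softmax(dim=-1)

print(logprobs.dtype)  # torch.float32
# Exponentiating recovers a valid distribution: each row sums to 1.
print(logprobs.exp().sum(dim=-1))
```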
