Heuristic:Hiyouga LLaMA Factory Mixed Precision Training Tips
| Knowledge Sources | |
|---|---|
| Domains | Training_Stability, Performance_Optimization |
| Last Updated | 2026-02-06 20:00 GMT |
Overview
Best practices for mixed precision training including dtype selection, layernorm upcasting, and LM head output upcasting for numerical stability.
Description
LLaMA Factory automatically infers the compute dtype from hardware capability and the model's dtype. bf16 is chosen when the device supports it and the model dtype is bf16 or unspecified. Otherwise fp16 is used if the device supports it (e.g. CUDA/NPU), and fp32 is the final fallback. The framework warns when mixed precision is not enabled and provides several upcasting options for numerical stability. Upcasting layernorm weights to fp32 is recommended for quantized training to prevent gradient instability. Upcasting the LM head output to fp32 helps prevent loss divergence in low-precision training.
Usage
Use this heuristic whenever you train models, especially with quantization or on hardware that supports bf16 (Ampere-generation or newer GPUs such as A100, H100, and RTX 3090). Enable bf16 when possible, and fall back to fp16 only when bf16 is unavailable. Enable `upcast_layernorm` for quantized training.
The Insight (Rule of Thumb)
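- Action 1: Always enable mixed precision training (`bf16=True` preferred, `fp16=True` as fallback).
- Action 2: Enable `upcast_layernorm=True` when training quantized models.
- Action 3: Enable `upcast_lmhead_output=True` if you observe loss instability.
- Action 4: For the GaLore or APOLLO optimizers, use `pure_bf16=True` to avoid the significant GPU memory increase caused by mixed precision.
- Trade-off: bf16 keeps fp32's exponent range but has far fewer significant bits (8 vs. 24), so it is less precise than fp32. In exchange it halves the storage of every tensor kept in bf16 (2 bytes vs. 4) and can substantially raise throughput on GPUs with bf16 tensor cores. fp16 is more precise than bf16 (11 significant bits) but has a much narrower exponent range (largest finite value 65504), so it requires loss scaling and is more prone to overflow. Upcasting adds minor memory overhead but improves stability.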
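The range and precision trade-off can be checked without a GPU. This sketch uses only the standard library: `struct`'s `e` format rounds to IEEE fp16, and bf16 is emulated by rounding an fp32 bit pattern to its upper 16 bits with round-to-nearest-even. The helpers `to_fp16` and `to_bf16` are illustrative names, not LLaMA Factory functions, and the emulation ignores details such as NaN payloads.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to IEEE fp16 and back; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values beyond fp16's maximum (65504); a tensor cast would give inf.
        return math.copysign(math.inf, x)


def to_bf16(x):
    """Round a Python float to bfloat16 (via its fp32 bit pattern) and back."""
    if math.isnan(x):
        return x
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Keep the top 16 bits, rounding the dropped 16 bits half-to-even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Range: fp16 overflows past 65504; bf16 keeps fp32's exponent range, but coarsely.
print(to_fp16(70000.0), to_bf16(70000.0))  # inf 70144.0
# Precision: fp16 keeps 11 significant bits, bf16 only 8.
print(to_fp16(1 + 2**-8), to_bf16(1 + 2**-8))  # 1.00390625 1.0
# Underflow: a tiny gradient vanishes in fp16 but survives in bf16.
grad = 1e-8
print(to_fp16(grad), to_bf16(grad) > 0)  # 0.0 True
# Loss scaling: multiply before storing in fp16, divide afterwards in higher precision.
scale = 1024.0
print(to_fp16(grad * scale) / scale > 0)  # True
```

The last line is the idea behind fp16 loss scaling: scaling the loss, and therefore the gradients, lifts tiny values out of fp16's underflow region, and the scale is divided back out in higher precision before the optimizer step.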
Reasoning
The source code encodes these best practices in several warnings and helpers:
Warning when mixed precision is disabled from src/llamafactory/hparams/parser.py:374-375:
```python
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
    logger.warning_rank0("We recommend enable mixed precision training.")
```
Warning about quantized training from src/llamafactory/hparams/parser.py:371-372:
```python
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
    logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
```
Warning about GaLore/APOLLO with mixed precision from src/llamafactory/hparams/parser.py:377-384:
```python
if (
    training_args.do_train
    and (finetuning_args.use_galore or finetuning_args.use_apollo)
    and not finetuning_args.pure_bf16
):
    logger.warning_rank0(
        "Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
    )
```
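The three warnings above can be reproduced as a pre-flight check on a config before launching a run. In this sketch, `precision_warnings` is a hypothetical helper and `SimpleNamespace` objects stand in for LLaMA Factory's argument dataclasses; only the field names and messages come from the quoted code.

```python
from types import SimpleNamespace


def precision_warnings(training_args, model_args, finetuning_args):
    """Return the messages the quoted parser checks would log (hypothetical helper)."""
    messages = []
    # Quantized training without layernorm upcasting.
    if training_args.do_train and model_args.quantization_bit is not None and not model_args.upcast_layernorm:
        messages.append("We recommend enable `upcast_layernorm` in quantized training.")
    # Neither fp16 nor bf16 mixed precision is enabled.
    if training_args.do_train and not training_args.fp16 and not training_args.bf16:
        messages.append("We recommend enable mixed precision training.")
    # GaLore / APOLLO without pure_bf16.
    if (
        training_args.do_train
        and (finetuning_args.use_galore or finetuning_args.use_apollo)
        and not finetuning_args.pure_bf16
    ):
        messages.append(
            "Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
        )
    return messages


# A 4-bit GaLore run with no precision flags trips all three checks.
risky = precision_warnings(
    SimpleNamespace(do_train=True, fp16=False, bf16=False),
    SimpleNamespace(quantization_bit=4, upcast_layernorm=False),
    SimpleNamespace(use_galore=True, use_apollo=False, pure_bf16=False),
)
for msg in risky:
    print("WARNING:", msg)
```

Running the check on the config you are about to launch surfaces these messages before any GPU time is spent, rather than in the training log.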
Dtype inference logic from src/llamafactory/extras/misc.py:241-248:
```python
def infer_optim_dtype(model_dtype):
    r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
    if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None):
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32
```
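Because `infer_optim_dtype` reads module-level capability flags, its decision order is easiest to see with those flags passed in explicitly. The sketch below is a hypothetical, torch-free mirror of the quoted function that represents dtypes as strings. It shows that bf16 is picked only when the model dtype is bf16 or unspecified, so an fp16 or fp32 checkpoint resolves to fp16 whenever fp16 is available, even on bf16-capable hardware.

```python
def infer_optim_dtype_sketch(model_dtype, bf16_available, fp16_available):
    """Torch-free mirror of the quoted decision order (hypothetical helper).

    Dtypes are plain strings ("bfloat16", "float16", "float32"); model_dtype=None
    means no dtype was requested.
    """
    if bf16_available and (model_dtype == "bfloat16" or model_dtype is None):
        return "bfloat16"
    elif fp16_available:
        return "float16"
    else:
        return "float32"


# On a bf16-capable GPU, the model dtype decides between bf16 and fp16.
for requested in (None, "bfloat16", "float16", "float32"):
    print(requested, "->", infer_optim_dtype_sketch(requested, bf16_available=True, fp16_available=True))
```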
Layernorm upcasting from src/llamafactory/model/model_utils/checkpointing.py:151-155:
```python
if model_args.upcast_layernorm:
    logger.info_rank0("Upcasting layernorm weights in float32.")
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
            param.data = param.data.to(torch.float32)
```
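The selection rule in the layernorm loop combines two filters: the parameter must be 1-D, and its name must contain one of the `LAYERNORM_NAMES` fragments as a substring. The torch-free sketch below applies the same rule to `(name, ndim)` pairs. The value of `LAYERNORM_NAMES` here is an assumption for illustration (LLaMA Factory defines its own constant), `layernorm_params_to_upcast` is a hypothetical helper, and the parameter names follow a typical Llama-style layout.

```python
# Assumed value for illustration; LLaMA Factory defines its own LAYERNORM_NAMES constant.
LAYERNORM_NAMES = {"norm", "ln"}


def layernorm_params_to_upcast(named_ndims):
    """Return the parameter names the quoted loop would cast to float32 (hypothetical helper).

    named_ndims: iterable of (name, ndim) pairs, e.g. built from
    [(n, p.ndim) for n, p in model.named_parameters()].
    """
    return [
        name
        for name, ndim in named_ndims
        if ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES)
    ]


params = [
    ("model.embed_tokens.weight", 2),
    ("model.layers.0.input_layernorm.weight", 1),
    ("model.layers.0.self_attn.q_proj.weight", 2),
    ("model.layers.0.self_attn.q_proj.bias", 1),  # 1-D, but not a norm: left as is
    ("model.layers.0.post_attention_layernorm.weight", 1),
    ("model.norm.weight", 1),
    ("lm_head.weight", 2),
]
print(layernorm_params_to_upcast(params))
```

Because the name test is a substring match, any 1-D parameter whose name happens to contain one of the fragments is also upcast. For an unfamiliar architecture, it can be worth printing the selected names once to confirm that only norm weights are affected.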