# Heuristic: NaN/Inf Fallback with FP32 Recovery
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Debugging, Deep_Learning |
| Last Updated | 2026-02-14 16:00 GMT |
## Overview
Automatic recovery from NaN/Inf loss during mixed-precision training by disabling AMP, casting to FP32, and applying a learning rate backoff.
## Description
During LLaVA fine-tuning with mixed precision (FP16/BF16), numerical instability can produce NaN or Inf loss values. Rather than crashing the training run, this heuristic implements a graceful fallback: when a NaN/Inf loss is detected, training automatically disables Automatic Mixed Precision (AMP), casts the model and optimizer state to FP32, and reduces the learning rate by a configurable backoff factor (default 0.5). The fallback fires at most once, so a transient instability does not waste the compute already invested in the run.
## Usage
Use this heuristic when fine-tuning large vision-language models (such as LLaVA-1.5-7b) with mixed-precision training. It is particularly relevant when using FP16 (which is more prone to overflow than BF16) or when training on hardware that does not fully support BF16.
## The Insight (Rule of Thumb)
- Action: Monitor `loss` for NaN/Inf at every training step. On detection, disable AMP, cast model + optimizer to FP32, and reduce LR.
- Value: Default LR backoff factor is `0.5` (halves the learning rate). Configurable via `--lr_backoff`.
- Trade-off: Slower training after fallback (FP32 uses more memory and compute than FP16/BF16), but the training run survives instead of crashing.
- One-shot: The fallback triggers at most once (`state.fallback_done` flag prevents repeated triggers).
## Reasoning
FP16 has a narrow dynamic range (max ~65,504) that can cause overflow in gradient computations, especially with large models. BF16 has a wider range but is not available on all GPUs. The fallback ensures training robustness across hardware configurations. The LR reduction is critical because the optimizer states have accumulated momentum at the original scale; a sudden precision change can amplify gradient noise, so a conservative LR helps re-stabilize training.
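The narrow FP16 range is easy to demonstrate: any magnitude above 65,504 overflows to `inf` on the cast, and the overflow then propagates as NaN through ordinary arithmetic (e.g. `inf - inf`). A minimal illustration:

```python
import torch

# FP16's largest finite value is 65504; casting anything beyond it overflows.
x = torch.tensor(70000.0).half()
print(torch.isinf(x))   # tensor(True)

# Overflow propagates as NaN through common operations, e.g. inf - inf.
print(torch.isnan(x - x))  # tensor(True)
```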
The NaN-detection-then-skip pattern also prevents corrupted gradients from polluting the optimizer state: the step is explicitly skipped via `optim.zero_grad(set_to_none=True)` before the fallback.
## Code Evidence
NaN/Inf detection and skip from `train/llava-ft.py:348-355`:

```python
if torch.isnan(loss) or torch.isinf(loss):
    print(f"[WARN] Step {global_step}: loss={loss.item()} -> skip & fallback")
    optim.zero_grad(set_to_none=True)
    maybe_fallback_to_fp32(model, optim, state, args.lr_backoff)
    lr_scheduler.step()
    global_step += 1
    pbar.update(1)
    continue
```
The fallback function from `train/llava-ft.py:66-84`:

```python
def maybe_fallback_to_fp32(model, optimizer, state, lr_backoff):
    if state.fallback_done:
        return
    print("[WARN] Detected NaN/Inf. Disabling AMP and casting model/optimizer to FP32. "
          f"Applying LR backoff x{lr_backoff:.3f}.")
    state.use_amp = False
    state.amp_dtype = None
    state.scaler = torch.amp.GradScaler("cuda", enabled=False)
    model.float()
    for st in optimizer.state.values():
        for k, v in list(st.items()):
            if torch.is_tensor(v):
                st[k] = v.float()
    for g in optimizer.param_groups:
        g["lr"] = g["lr"] * lr_backoff
    state.fallback_done = True
```
BF16 availability check from `train/llava-ft.py:312-322`:

```python
if args.amp_dtype == "bf16" and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
    use_amp = True
elif args.amp_dtype == "fp16":
    amp_dtype = torch.float16
    use_amp = True
else:
    print("[INFO] AMP not supported as requested dtype. Using FP32.")
```
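The same selection logic can be wrapped in a small helper that degrades gracefully on CPU-only machines (a sketch: `pick_amp_dtype` is a hypothetical name, and the extra `torch.cuda.is_available()` guard is an assumption not present in the snippet above, which enables FP16 unconditionally):

```python
import torch
from typing import Optional, Tuple

def pick_amp_dtype(requested: str) -> Tuple[Optional[torch.dtype], bool]:
    # Returns (amp_dtype, use_amp); falls back to FP32 when the requested
    # dtype cannot be honored on the current hardware.
    if requested == "bf16" and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16, True
    if requested == "fp16" and torch.cuda.is_available():
        return torch.float16, True
    return None, False  # FP32 fallback

print(pick_amp_dtype("fp32"))  # (None, False) on any machine
```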