

Heuristic:Norrrrrrr lyn WAInjectBench NaN Inf Fallback FP32 Recovery

From Leeroopedia
Domains: Optimization, Debugging, Deep_Learning
Last Updated: 2026-02-14 16:00 GMT

Overview

Automatic recovery from NaN/Inf loss during mixed-precision training: skip the bad step, disable AMP, cast the model and optimizer state to FP32, and apply a learning-rate backoff.

Description

During LLaVA fine-tuning with mixed precision (FP16/BF16), numerical instability can produce NaN or Inf loss values. Rather than letting the run crash, this heuristic implements a graceful fallback: when a NaN/Inf loss is detected, the training loop skips that step, disables Automatic Mixed Precision (AMP), casts the model parameters and optimizer state tensors to FP32, and multiplies the learning rate by a configurable backoff factor (default 0.5). The fallback fires at most once; its purpose is to keep the run alive instead of losing the compute already spent on it.
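The control flow described above can be illustrated with a minimal pure-Python replay. This is a sketch, not the script's code: `FallbackState` and `replay` are hypothetical names, the loss values are made up, and the LR schedule is assumed constant so that only the backoff changes the LR.

```python
import math
from dataclasses import dataclass


@dataclass
class FallbackState:
    # Hypothetical stand-in for the training script's `state` object.
    use_amp: bool = True
    fallback_done: bool = False


def replay(losses, lr, lr_backoff=0.5):
    """Replay per-step losses; return the LR applied at each step (None = step skipped)."""
    state = FallbackState()
    applied = []
    for loss in losses:
        if not math.isfinite(loss):
            # Non-finite loss: skip the update, and fall back at most once.
            if not state.fallback_done:
                state.use_amp = False
                lr *= lr_backoff
                state.fallback_done = True
            applied.append(None)
            continue
        applied.append(lr)
    return applied, state


applied, state = replay([1.2, float("nan"), 1.1, float("inf"), 1.0], lr=2e-5)
print(applied)  # [2e-05, None, 1e-05, None, 1e-05]
```

The second non-finite loss (the `inf`) is still skipped but does not halve the LR again, which is the one-shot behaviour.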

Usage

Use this heuristic when fine-tuning large vision-language models (such as LLaVA-1.5-7b) with mixed-precision training. It is particularly relevant when using FP16, which is more prone to overflow than BF16, or when training on GPUs without native BF16 support (e.g., pre-Ampere NVIDIA GPUs such as the V100 or T4).

The Insight (Rule of Thumb)

  • Action: Monitor `loss` for NaN/Inf at every training step. On detection, skip the step, disable AMP, cast the model and optimizer state to FP32, and reduce the LR.
  • Value: Default LR backoff factor is `0.5` (halves the learning rate). Configurable via `--lr_backoff`.
  • Trade-off: Slower training after fallback (FP32 uses more memory and compute than FP16/BF16), but the training run survives instead of crashing.
  • One-shot: The fallback triggers at most once (the `state.fallback_done` flag prevents repeated triggers). Later NaN/Inf steps are still skipped, but the LR is not reduced again.
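To put the memory trade-off in numbers, here is a back-of-the-envelope estimate. It assumes full fine-tuning of roughly 7 × 10^9 parameters with an Adam-style optimizer (two moment tensors per parameter), weights loaded in FP16 before the fallback, and gradients in the same dtype as the weights; activations, the CUDA context, and LoRA or frozen-tower setups are ignored. The helper name is hypothetical.

```python
def training_state_gb(n_params, weight_bytes, moment_bytes):
    """Weights + gradients (same dtype as weights) + two Adam moments, in GB (1e9 bytes)."""
    per_param = 2 * weight_bytes + 2 * moment_bytes
    return n_params * per_param / 1e9


N_PARAMS = 7e9  # assumption: ~7B trainable parameters

fp16_gb = training_state_gb(N_PARAMS, weight_bytes=2, moment_bytes=2)  # before the fallback
fp32_gb = training_state_gb(N_PARAMS, weight_bytes=4, moment_bytes=4)  # after model.float() + state cast
print(fp16_gb, fp32_gb)  # 56.0 112.0
```

Under these assumptions the fallback doubles the persistent training state, so a run that fits in FP16 may run out of memory after the cast. If the weights were already FP32 (the usual setup with `torch.autocast`), `model.float()` changes nothing and the increase is smaller.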

Reasoning

FP16 has a narrow dynamic range (largest finite value 65,504), so activations, losses, or gradients can overflow to Inf, especially with large models, and Inf readily turns into NaN in later arithmetic (e.g., Inf - Inf). BF16 keeps FP32's 8-bit exponent and therefore a far wider range, but not all GPUs support it. The fallback lets the same script run safely across hardware configurations. The LR reduction is a precaution: the optimizer's momentum estimates were accumulated before the precision switch, and a more conservative LR is intended to give training room to re-stabilize afterwards.
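The range difference can be checked directly with `torch.finfo` and a cast. This sketch only illustrates the number formats; the value 70,000 is an arbitrary example, not a value taken from the training run.

```python
import torch

fp16_max = torch.finfo(torch.float16).max    # 65504.0
bf16_max = torch.finfo(torch.bfloat16).max   # ~3.39e38, close to FP32's ~3.40e38

x = torch.tensor(70000.0)                     # FP32 value just above the FP16 maximum
as_fp16 = x.to(torch.float16)                 # overflows to inf
as_bf16 = x.to(torch.bfloat16)                # stays finite, rounded to 70144

print(fp16_max, torch.isinf(as_fp16).item(), as_bf16.item())
```

BF16 pays for this range with precision: it keeps only 8 significant bits versus 11 for FP16, which is why 70,000 is rounded to 70,144.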

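The detect-then-skip pattern also keeps corrupted gradients out of the optimizer state. For the offending batch, `optim.zero_grad(set_to_none=True)` discards any stored gradients before the fallback runs, and `continue` then bypasses the rest of the loop body, so (assuming `optim.step()` comes later in the loop, as is usual) no optimizer update is applied. The LR scheduler, step counter, and progress bar are still advanced, keeping the schedule aligned with `global_step`.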
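The effect of skipping can be seen on a toy problem. The sketch below uses a hypothetical three-element parameter vector, Adam, and an input that is all NaN at the second step; it trains twice, once backpropagating every loss and once applying the same skip as the script. It assumes the check runs before `loss.backward()`.

```python
import torch


def train(inputs, skip_nonfinite):
    p = torch.nn.Parameter(torch.ones(3))
    opt = torch.optim.Adam([p], lr=0.1)
    for x in inputs:
        loss = (p * x).sum()
        if skip_nonfinite and not torch.isfinite(loss):
            opt.zero_grad(set_to_none=True)  # drop any stale gradients
            continue                          # no backward, no optimizer step
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
    return p.detach(), opt.state[p]["exp_avg"]


inputs = [torch.ones(3), torch.full((3,), float("nan")), torch.ones(3)]
p_raw, m_raw = train(inputs, skip_nonfinite=False)   # NaN reaches Adam's moments
p_skip, m_skip = train(inputs, skip_nonfinite=True)  # bad step skipped
```

In the unguarded run a single bad batch leaves both the parameters and the first-moment estimate at NaN permanently, because every later update blends the NaN moment into the new one.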

Code Evidence

NaN/Inf detection and skip from `train/llava-ft.py:348-355`:

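if torch.isnan(loss) or torch.isinf(loss):
    print(f"[WARN] Step {global_step}: loss={loss.item()} -> skip & fallback")
    # Discard any gradients stored for this step (including NaN/Inf ones if backward already ran).
    optim.zero_grad(set_to_none=True)
    # One-shot switch to FP32 with LR backoff (a no-op after the first trigger).
    maybe_fallback_to_fp32(model, optim, state, args.lr_backoff)
    # Keep the LR schedule and progress bar aligned with the step counter.
    lr_scheduler.step()
    global_step += 1
    pbar.update(1)
    # Skip the rest of the loop body, so no optimizer update is applied for this batch.
    continue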
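One interaction worth checking: right after the fallback, the loop calls `lr_scheduler.step()`. Which scheduler `llava-ft.py` uses is not shown here. If it is a `LambdaLR`-based schedule (Hugging Face's `get_linear_schedule_with_warmup` and `get_cosine_schedule_with_warmup` both return one), each `step()` recomputes the LR as `base_lr * lr_lambda(step)`, overwriting a direct edit of `param_groups[i]["lr"]`. The sketch below demonstrates this with a constant lambda and shows one possible remedy, scaling `scheduler.base_lrs` as well; `backoff` is a hypothetical helper, not the script's function.

```python
import torch

model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
# Constant multiplier standing in for a warmup/decay lambda.
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0)


def backoff(optimizer, factor, scheduler=None):
    for g in optimizer.param_groups:
        g["lr"] *= factor
    if scheduler is not None:
        # LambdaLR derives each LR from base_lrs, so scale those too.
        scheduler.base_lrs = [b * factor for b in scheduler.base_lrs]


backoff(opt, 0.5)            # edit param_groups only, as the fallback does
opt.step()                   # no gradients yet, so this updates nothing
sched.step()
lr_groups_only = opt.param_groups[0]["lr"]   # back to 1e-3: backoff lost

backoff(opt, 0.5, sched)     # also scale the scheduler's base LRs
opt.step()
sched.step()
lr_with_base = opt.param_groups[0]["lr"]     # 5e-4: backoff kept
```

Chainable schedulers such as `StepLR` or `ExponentialLR` multiply the current `group["lr"]` on each step, so for them the direct edit already persists.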

The fallback function from `train/llava-ft.py:66-84`:

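def maybe_fallback_to_fp32(model, optimizer, state, lr_backoff):
    # One-shot guard: after the first fallback, later NaN/Inf steps are only skipped.
    if state.fallback_done:
        return
    print("[WARN] Detected NaN/Inf. Disabling AMP and casting model/optimizer to FP32. "
          f"Applying LR backoff x{lr_backoff:.3f}.")
    # Turn off autocast and replace the GradScaler with a disabled one.
    state.use_amp = False
    state.amp_dtype = None
    state.scaler = torch.amp.GradScaler("cuda", enabled=False)
    # Cast parameters to FP32 in place; the Parameter objects are kept,
    # so the optimizer's references and state keys remain valid.
    model.float()
    # Cast optimizer state tensors (e.g., Adam moments) to FP32 to match the parameters.
    for st in optimizer.state.values():
        for k, v in list(st.items()):
            if torch.is_tensor(v):
                st[k] = v.float()
    # Scale the LR of every parameter group by the backoff factor.
    for g in optimizer.param_groups:
        g["lr"] = g["lr"] * lr_backoff
    state.fallback_done = True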
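A quick CPU check of the function's behaviour on a toy model. This is a trimmed copy (the `GradScaler` line and the log message are omitted so it runs without CUDA); `nn.Linear(4, 2)` stands in for LLaVA, and the FP16 Adam moments are inserted by hand rather than produced by a real FP16 optimizer step.

```python
import torch
from types import SimpleNamespace


def fallback_to_fp32(model, optimizer, state, lr_backoff):
    # Trimmed copy of maybe_fallback_to_fp32 (no GradScaler, no logging).
    if state.fallback_done:
        return
    state.use_amp = False
    state.amp_dtype = None
    model.float()
    for st in optimizer.state.values():
        for k, v in list(st.items()):
            if torch.is_tensor(v):
                st[k] = v.float()
    for g in optimizer.param_groups:
        g["lr"] = g["lr"] * lr_backoff
    state.fallback_done = True


model = torch.nn.Linear(4, 2).half()          # weights held in FP16
opt = torch.optim.AdamW(model.parameters(), lr=2e-5)
# Simulated FP16 Adam moments (no real FP16 step is run on CPU).
for p in model.parameters():
    opt.state[p] = {"exp_avg": torch.zeros_like(p), "exp_avg_sq": torch.zeros_like(p)}
state = SimpleNamespace(use_amp=True, amp_dtype=torch.float16, fallback_done=False)

fallback_to_fp32(model, opt, state, lr_backoff=0.5)
fallback_to_fp32(model, opt, state, lr_backoff=0.5)  # second call is a no-op
```

Because `Module.float()` (under PyTorch's default settings) replaces each parameter's data in place rather than creating new `Parameter` objects, the optimizer's `param_groups` and `state` keys still refer to the model's parameters after the cast.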

BF16 availability check from `train/llava-ft.py:312-322`:

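# use_amp / amp_dtype are presumably initialized to False / None earlier
# in the script (not shown in this excerpt).
if args.amp_dtype == "bf16" and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
    use_amp = True
elif args.amp_dtype == "fp16":
    # No hardware check for FP16; this is the mode most prone to overflow.
    amp_dtype = torch.float16
    use_amp = True
else:
    # Reached for "bf16" on unsupported hardware and for any other value: plain FP32.
    print("[INFO] AMP not supported as requested dtype. Using FP32.")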
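The branch logic can be tested without a GPU by passing the hardware check in as an argument. `select_amp_dtype` is a hypothetical refactor, not a function in the script; it mirrors the excerpt's branches and makes the FP32 default explicit.

```python
import torch


def select_amp_dtype(requested, bf16_supported):
    """Return (use_amp, amp_dtype) for a requested dtype string."""
    if requested == "bf16" and bf16_supported:
        return True, torch.bfloat16
    if requested == "fp16":
        return True, torch.float16
    # Explicit FP32 default, including "bf16 requested but unsupported".
    return False, None
```

In the script, the flag would come from the real check, e.g. `use_amp, amp_dtype = select_amp_dtype(args.amp_dtype, torch.cuda.is_bf16_supported())`. The sketch inherits the excerpt's design choice: an unsupported BF16 request drops straight to FP32 rather than trying FP16.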
