Heuristic:Mlfoundations Open flamingo Gradient Clipping Max Norm
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning |
| Last Updated | 2026-02-08 03:30 GMT |
Overview
Clip the gradient norm at `max_norm=1.0` to stabilize training. FSDP and DDP use separate implementations: FSDP clips per submodule, while DDP clips across the whole model.
Description
OpenFlamingo clips gradient norms to a maximum of 1.0 during training. This guards against exploding gradients, which can cause NaN losses or unstable training dynamics. The implementation differs between FSDP and DDP modes: under FSDP, gradients are clipped per submodule (the perceiver, each gated cross-attention layer, and the input embeddings), while under DDP a single norm is computed over all model parameters and clipped once.
Usage
Apply this heuristic during any distributed training run. The training loop applies it automatically after the backward pass and before the optimizer step.
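The ordering (backward pass, then clip, then optimizer step) can be illustrated with a toy example. The following is a minimal pure-Python sketch, not OpenFlamingo code: `clip_by_norm` and `train` are hypothetical helpers, the loss `w**4` and the learning rate are chosen only to provoke divergence, and the scaling rule with a small `eps` in the denominator is assumed to mirror what `torch.nn.utils.clip_grad_norm_` does.

```python
import math


def clip_by_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale grads in place so their L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Scale down only when the norm exceeds max_norm; otherwise leave as-is.
    coef = min(1.0, max_norm / (total_norm + eps))
    for i in range(len(grads)):
        grads[i] *= coef
    return total_norm  # norm measured before clipping


def train(steps, lr, clip):
    """Minimize the toy loss w**4 with plain gradient descent."""
    w = 3.0
    for _ in range(steps):
        # "Backward pass": d/dw of w**4. Written as products so that an
        # overflow yields inf instead of raising OverflowError.
        grads = [4.0 * w * w * w]
        if clip:
            clip_by_norm(grads, max_norm=1.0)  # clip after backward
        w -= lr * grads[0]                     # "optimizer step" comes last
    return w


print("without clipping:", train(steps=10, lr=0.1, clip=False))
print("with clipping:   ", train(steps=50, lr=0.1, clip=True))
```

In this sketch the unclipped run overshoots further on every step until the parameter stops being a finite number, while the clipped run moves by at most `lr * max_norm` per step and settles toward the minimum.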
The Insight (Rule of Thumb)
- Action: Clip gradient norm to 1.0 after backward pass.
- Value: `max_norm=1.0`
- Trade-off: May slightly slow convergence by limiting large gradient updates, but prevents training instabilities and NaN losses.
- Note: FSDP clips per-submodule rather than globally. Per the developers, "At least for OPT-125M, this didn't seem to make a difference in performance."
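The difference between global and per-submodule clipping in the note above can be made concrete with a small numeric sketch. It uses plain Python lists rather than tensors. The group names and gradient values are invented for illustration, and `clip_by_norm` is a hypothetical helper that is assumed to follow the same scaling rule as `torch.nn.utils.clip_grad_norm_`.

```python
import math


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def clip_by_norm(values, max_norm=1.0, eps=1e-6):
    """Return a copy of values rescaled so its L2 norm is at most max_norm."""
    coef = min(1.0, max_norm / (l2_norm(values) + eps))
    return [v * coef for v in values]


# Hypothetical gradients for three parameter groups (illustrative values).
groups = {
    "perceiver": [3.0, 4.0],    # norm 5.0
    "cross_attn": [0.3, 0.4],   # norm 0.5, already under the limit
    "embeddings": [6.0, 8.0],   # norm 10.0
}

# Global clipping (the DDP path): one norm over every parameter.
flat = [g for grads in groups.values() for g in grads]
global_clipped = clip_by_norm(flat)

# Per-group clipping (the FSDP path): each group is clipped on its own.
per_group = {name: clip_by_norm(grads) for name, grads in groups.items()}
per_group_flat = [g for grads in per_group.values() for g in grads]

print("global clip, overall norm:   ", l2_norm(global_clipped))   # close to 1.0
print("per-group clip, overall norm:", l2_norm(per_group_flat))   # about 1.5
```

Two things differ under these assumptions. Per-group clipping bounds each group's norm at `max_norm`, so the overall norm can reach `sqrt(k) * max_norm` for `k` groups. It also changes the relative scale between groups, whereas global clipping shrinks every gradient by the same factor.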
Reasoning
Large gradient norms can occur in multimodal training, for example when vision and language loss scales differ or during early training. Clipping to 1.0 is a common practice: when the total gradient norm exceeds 1.0, the gradients are rescaled so that the norm matches the limit, which prevents catastrophic updates, while gradients already within the limit pass through unmodified. The FSDP-specific implementation is a consequence of the FSDP API, which only provides `clip_grad_norm_` on individual FSDP-wrapped modules.
Code Evidence
Gradient clipping from `open_flamingo/train/train_utils.py:198-208`:
# clip gradient norm
if args.fsdp:
    """
    The way we clip gradients with FSDP is different than the non-FSDP case,
    because during FSDP, gradient norms are computed over certain submodules,
    rather than the entire model.
    At least for OPT-125M, this didn't seem to make a difference in performance.
    """
    model.clip_grad_norm_(1.0)
else:
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
FSDP per-submodule clip function from `open_flamingo/src/flamingo.py:294-301`:
def clip_grad_norm_(max_norm):
    self.perceiver.clip_grad_norm_(max_norm)
    for layer in self.lang_encoder.gated_cross_attn_layers:
        if layer is not None:
            layer.clip_grad_norm_(max_norm)
    self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)
self.clip_grad_norm_ = clip_grad_norm_
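The excerpt above defines a closure over `self` and attaches it to the model as an instance attribute, so that `model.clip_grad_norm_(1.0)` fans out to each wrapped submodule. The delegation pattern can be sketched without PyTorch. `StubWrapped` and `StubModel` are hypothetical stand-ins, not OpenFlamingo classes, and the `None` entries are assumed to represent decoder layers that have no gated cross-attention block.

```python
class StubWrapped:
    """Hypothetical stand-in for an FSDP-wrapped submodule; records clip calls."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def clip_grad_norm_(self, max_norm):
        self.log.append((self.name, max_norm))


class StubModel:
    """Hypothetical model that mimics the structure used in the excerpt."""

    def __init__(self):
        self.log = []
        self.perceiver = StubWrapped("perceiver", self.log)
        # Assumption: None marks layers without a cross-attention block.
        self.gated_cross_attn_layers = [
            None,
            StubWrapped("xattn_1", self.log),
            None,
            StubWrapped("xattn_3", self.log),
        ]
        self.input_embeddings = StubWrapped("embeddings", self.log)

    def wrap(self):
        # The closure captures self, so it can be called like a method
        # even though it is stored as a plain instance attribute.
        def clip_grad_norm_(max_norm):
            self.perceiver.clip_grad_norm_(max_norm)
            for layer in self.gated_cross_attn_layers:
                if layer is not None:
                    layer.clip_grad_norm_(max_norm)
            self.input_embeddings.clip_grad_norm_(max_norm)

        self.clip_grad_norm_ = clip_grad_norm_


model = StubModel()
model.wrap()
model.clip_grad_norm_(1.0)
print([name for name, _ in model.log])
# prints ['perceiver', 'xattn_1', 'xattn_3', 'embeddings']
```

Each stub records one independent clip call, which reflects why the FSDP path computes a norm per submodule instead of one norm for the whole model.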