Heuristic:Mlfoundations Open flamingo FSDP Manual Wrapping For Mixed Parameters
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Optimization, Memory_Management |
| Last Updated | 2026-02-08 03:30 GMT |
Overview
Manual double-wrapping FSDP strategy for models with mixed frozen and trainable parameters. It works around PyTorch FSDP limitations in how parameters are grouped into wrappers and how gathered parameters are freed.
Description
OpenFlamingo needs a custom FSDP wrapping strategy because it mixes frozen parameters (vision encoder, LM decoder blocks) with trainable ones (perceiver, gated cross-attention layers, input embeddings). Standard FSDP auto-wrapping fails because all parameters within one FSDP wrapper must share the same `requires_grad` value. The workaround has three parts: (1) double-wrap each submodule with `wrap(wrap(module))` so that its parameters sit in a non-root FSDP module, the only kind whose post-forward/backward hooks free gathered parameters; (2) temporarily unfreeze the frozen decoder layers so that FSDP's memory-management hooks fire for them; and (3) keep those decoder layers, which are unfrozen only for FSDP's benefit, out of the optimizer by setting a custom `exclude_from_optimizer` flag on their parameters.
Usage
Apply this heuristic when using `--fsdp` for distributed training. It is automatically applied by `model.wrap_fsdp()`. It is critical for models with >1B parameters that do not fit in single-GPU memory under DDP.
The Insight (Rule of Thumb)
- Action: Use `wrap(wrap(module))` for each submodule. Temporarily unfreeze frozen decoder layers. Exclude decoder layers from optimizer via `p.exclude_from_optimizer = True`.
- Value: Enables FSDP training of large multimodal models with mixed frozen/trainable parameters.
- Trade-off: The unfrozen decoder layers take part in gradient computation, which costs memory, but they are never updated, so no optimizer step or optimizer state is spent on them. The strategy is incompatible with tied embeddings.
- Known Issue: With FSDP + gradient checkpointing, unreasonably large batch sizes (e.g., an MMC4 batch size of 100 for OPT-125M) degrade downstream performance even though the training curves look normal.
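The optimizer-side half of this rule can be sketched as follows. This is a minimal illustration, not OpenFlamingo's training code: the helper name `select_params_to_optimize` and the `SimpleNamespace` stand-ins for parameters are hypothetical, and only the `requires_grad` attribute and the `exclude_from_optimizer` flag come from the text above.

```python
from types import SimpleNamespace


def select_params_to_optimize(named_params):
    """Return the (name, param) pairs that should receive optimizer updates.

    A parameter is kept only if it requires grad and has not been flagged
    with `exclude_from_optimizer = True`. Parameters that lack the flag are
    treated as not excluded.
    """
    return [
        (name, p)
        for name, p in named_params
        if p.requires_grad and not getattr(p, "exclude_from_optimizer", False)
    ]


# Hypothetical stand-ins for parameters; real code would pass
# model.named_parameters() instead of this hand-built list.
perceiver_w = SimpleNamespace(requires_grad=True)
decoder_w = SimpleNamespace(requires_grad=True, exclude_from_optimizer=True)
vision_w = SimpleNamespace(requires_grad=False)

named_params = [
    ("perceiver.weight", perceiver_w),
    ("old_decoder_blocks.0.weight", decoder_w),  # unfrozen only for FSDP
    ("vision_encoder.weight", vision_w),  # still frozen
]

selected_names = [name for name, _ in select_params_to_optimize(named_params)]
print(selected_names)
```

Because the filter is duck-typed, the same function should accept the output of `model.named_parameters()` on a wrapped model, and the selected parameters can then be handed to whichever optimizer the training script uses.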
Reasoning
Three PyTorch FSDP limitations drive this design:
1. Mixed `requires_grad`: FSDP flattens the parameters inside each wrapper together, so every parameter in a wrapper must have the same `requires_grad`. OpenFlamingo has both frozen (vision encoder, LM layers) and trainable (perceiver, cross-attention) parameters. Solution: wrap each group individually.
2. Post-hooks on non-root only: As of torch==2.0.1, FSDP's `_post_forward_hook` and `_post_backward_hook` only free gathered parameters if the module is NOT the FSDP root. Solution: double-wrap with `wrap(wrap(module))`.
3. Post-backward hook requires grad: As of torch==2.0.1, FSDP's `_post_backward_hook` is registered only when `requires_grad=True`. Without this hook, gathered parameters are never freed, which causes OOM. Solution: temporarily unfreeze the decoder layers and exclude them from the optimizer instead.
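Limitation 1 suggests a cheap sanity check before wrapping: confirm that every planned wrapping unit is uniform in `requires_grad`. The sketch below is an assumption-laden illustration, not part of OpenFlamingo or PyTorch. The helper `find_mixed_units`, the `fake_param` stand-in, and the unit names are all hypothetical, and the check relies only on each parameter exposing a boolean `requires_grad`.

```python
from types import SimpleNamespace


def find_mixed_units(units):
    """Return names of wrapping units whose parameters disagree on requires_grad.

    `units` maps a unit name to an iterable of parameter-like objects that
    expose a boolean `requires_grad` attribute.
    """
    mixed = []
    for name, params in units.items():
        flags = {bool(p.requires_grad) for p in params}
        if len(flags) > 1:  # both True and False appear in the same unit
            mixed.append(name)
    return mixed


def fake_param(requires_grad):
    """Hypothetical stand-in for a parameter, used only for this demo."""
    return SimpleNamespace(requires_grad=requires_grad)


units = {
    "perceiver": [fake_param(True), fake_param(True)],
    "vision_encoder": [fake_param(False)],
    # Frozen decoder blocks and trainable cross-attention in ONE unit:
    "lang_encoder": [fake_param(False), fake_param(True)],
}

mixed_units = find_mixed_units(units)
print(mixed_units)
```

With real modules, the mapping could be built as `{name: list(module.parameters()) for name, module in planned_units.items()}`, where `planned_units` is whatever grouping the wrapping code intends to use. Any unit the check reports would need to be split into separate wrappers, or made uniform as the decoder blocks are in limitation 3.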
Code Evidence
Double-wrapping and temporary unfreezing from `open_flamingo/src/flamingo.py:256-292`:
# unfreeze the decoder layers
for block in self.lang_encoder.old_decoder_blocks:
    block.requires_grad_(True)
# wrap in FSDP
with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
    self.perceiver = wrap(wrap(self.perceiver))
    self.lang_encoder.old_decoder_blocks = nn.ModuleList(
        wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
    )
    self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
        wrap(wrap(layer)) if layer is not None else None
        for layer in self.lang_encoder.gated_cross_attn_layers
    )
    ...
# exclude the original decoder layers from the optimizer
for block in self.lang_encoder.old_decoder_blocks:
    for p in block.parameters():
        p.exclude_from_optimizer = True
Known issues documented in `open_flamingo/src/flamingo.py:224-229`:
"""
Known issues:
- Our FSDP strategy is not compatible with tied embeddings. If the LM
embeddings are tied, train with DDP or set --freeze_lm_embeddings.
- With FSDP + gradient ckpting, one can increase the batch size with
seemingly no upper bound. Although the training curves look okay,
we found that downstream performance dramatically degrades if the
batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).
"""
Hybrid sharding bug warning from `open_flamingo/train/train.py:239-245`:
if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
    print(
        "Warning: As of torch=2.0.1, the FSDP logic for optim_state_dict() "
        "is broken for hybrid sharding."
        "To make this method work, we need to modify "
        "torch.distributed.fsdp._optim_utils.py"
    )