

Heuristic:Mlfoundations Open flamingo FSDP Manual Wrapping For Mixed Parameters

From Leeroopedia



Knowledge Sources
Domains Distributed_Training, Optimization, Memory_Management
Last Updated 2026-02-08 03:30 GMT

Overview

Manual double-wrapping FSDP strategy for models with mixed frozen and trainable parameters, which works around PyTorch FSDP limitations in parameter grouping and memory management.

Description

OpenFlamingo requires a custom FSDP wrapping strategy because it mixes frozen parameters (vision encoder, LM decoder blocks) with trainable ones (perceiver, gated cross-attention layers, input embeddings). Standard FSDP auto-wrapping fails here because all parameters within one FSDP wrapper must share the same `requires_grad` value. The solution has three parts: (1) double-wrap each submodule as `wrap(wrap(module))` so that the inner wrapper is never the FSDP root, because the post-forward and post-backward hooks only free gathered parameters for non-root modules; (2) temporarily unfreeze the decoder layers so that FSDP's memory-management hooks fire for them; and (3) keep those originally frozen decoder layers out of the optimizer through a custom `exclude_from_optimizer` flag on their parameters.
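
To see why one wrapper around the whole model violates the uniform-`requires_grad` rule while per-submodule wrappers satisfy it, the sketch below classifies each submodule by the `requires_grad` flags of its parameters. It runs on CPU without FSDP. `ToyFlamingo` and `requires_grad_status` are hypothetical names, plain `nn.Linear` layers stand in for the real components, and the freeze pattern (freeze everything, then unfreeze the perceiver and cross-attention) is an assumption modeled on the description above.

```python
import torch.nn as nn


def requires_grad_status(module: nn.Module) -> str:
    """Classify a module's parameters as 'trainable', 'frozen', 'mixed', or 'no-params'."""
    flags = {p.requires_grad for p in module.parameters()}
    if flags == {True}:
        return "trainable"
    if flags == {False}:
        return "frozen"
    return "mixed" if flags else "no-params"


# Hypothetical toy stand-in for a Flamingo-style model (not OpenFlamingo code).
class ToyFlamingo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.vision_encoder = nn.Linear(8, 8)  # stays frozen
        self.perceiver = nn.Linear(8, 8)       # trainable
        self.decoder_block = nn.Linear(8, 8)   # stays frozen
        self.gated_xattn = nn.Linear(8, 8)     # trainable


model = ToyFlamingo()
model.requires_grad_(False)            # freeze everything first
model.perceiver.requires_grad_(True)   # then unfreeze the newly added modules
model.gated_xattn.requires_grad_(True)

print("root:", requires_grad_status(model))  # mixed: cannot share one wrapper
for name, child in model.named_children():
    print(f"{name}:", requires_grad_status(child))  # each child is uniform
```

Every child reports a single status, so wrapping each child separately keeps each wrapper uniform, while the root reports `mixed`.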

Usage

Apply this heuristic when training with `--fsdp`; it is applied automatically by `model.wrap_fsdp()`. It is critical for models with more than 1B parameters, which exceed single-GPU memory under DDP.

The Insight (Rule of Thumb)

  • Action: Use `wrap(wrap(module))` for each submodule, temporarily unfreeze the frozen decoder layers, and exclude those decoder layers from the optimizer by setting `p.exclude_from_optimizer = True` on each of their parameters.
  • Value: Enables FSDP training of large multimodal models with mixed frozen/trainable parameters.
  • Trade-off: Unfrozen decoder layers take part in gradient computation, so their weight gradients cost extra memory and compute. Because they are excluded from the optimizer, however, they receive no optimizer state and are never updated. The strategy is also incompatible with tied embeddings.
  • Known Issue: With FSDP + gradient checkpointing, unreasonably large batch sizes (e.g., an MMC4 batch size of 100 for OPT-125M) degrade downstream performance even though the training curves look normal.
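
The Action and Trade-off bullets can be checked on a toy CPU model. The sketch below tags an "unfrozen for FSDP" layer with `exclude_from_optimizer`, builds the optimizer only from untagged trainable parameters, and confirms that the tagged layer still receives gradients (the memory cost) while its weights stay fixed after a step. `optimizer_params` is a hypothetical helper, not the OpenFlamingo training code, and the two `nn.Linear` layers are assumed stand-ins for a decoder block and a cross-attention layer.

```python
import torch
import torch.nn as nn


def optimizer_params(model: nn.Module) -> list:
    """Hypothetical filter: trainable parameters not tagged exclude_from_optimizer."""
    return [
        p
        for p in model.parameters()
        if p.requires_grad and not getattr(p, "exclude_from_optimizer", False)
    ]


torch.manual_seed(0)
xattn = nn.Linear(4, 4)    # genuinely trainable
decoder = nn.Linear(4, 4)  # stand-in for an originally frozen decoder block
model = nn.Sequential(xattn, decoder)

decoder.requires_grad_(False)  # originally frozen
decoder.requires_grad_(True)   # temporarily unfrozen so FSDP's hooks would fire
for p in decoder.parameters():
    p.exclude_from_optimizer = True  # ...but never handed to the optimizer

opt = torch.optim.AdamW(optimizer_params(model), lr=1e-2)
decoder_before = decoder.weight.detach().clone()
xattn_before = xattn.weight.detach().clone()

loss = model(torch.randn(2, 4)).sum()
loss.backward()
opt.step()

print(decoder.weight.grad is not None)              # True: gradient memory is still paid
print(torch.equal(decoder.weight, decoder_before))  # True: the weights did not move
```

Only `xattn` changes after `opt.step()`; the decoder behaves as frozen from the optimizer's point of view.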

Reasoning

Three PyTorch FSDP limitations drive this design:

1. Mixed requires_grad: FSDP requires all parameters flattened into one wrapper's flat parameter to share the same `requires_grad` value. OpenFlamingo has frozen (vision encoder, LM layers) and trainable (perceiver, cross-attention) parameters. Solution: wrap each group individually.

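2. Post-hooks on non-root only: As of torch==2.0.1, FSDP's `_post_forward_hook` and `_post_backward_hook` only free gathered parameters if the module is NOT the FSDP root. A submodule wrapped once, with no FSDP-wrapped parent, becomes its own root, so these hooks would not free its gathered parameters. Solution: double-wrap with `wrap(wrap(module))` so that the inner wrapper is a non-root FSDP module.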

3. Post-backward hook requires grad: As of torch==2.0.1, FSDP registers `_post_backward_hook` only for flat parameters with `requires_grad=True`. Without this hook, gathered parameters are never freed, causing OOM. Solution: temporarily unfreeze the decoder layers and exclude them from the optimizer instead.
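
The precondition in item 3 has a plain-autograd analogue, shown below with no FSDP involved: a frozen tensor produces no autograd node to hang a backward hook on, while a trainable one does. This is an illustrative analogy under the assumption stated above (that FSDP's post-backward hook needs the flat parameter to require grad); it is not FSDP's own code.

```python
import torch

# Plain-autograd analogy, not FSDP source code.
frozen = torch.nn.Parameter(torch.ones(3), requires_grad=False)
trainable = torch.nn.Parameter(torch.ones(3))

# A frozen tensor produces no autograd node, so there is nothing to attach
# a backward hook to.
print(frozen.expand_as(frozen).grad_fn)  # None

# A trainable tensor does get a node; its hook fires during backward.
fired = []
trainable.register_hook(lambda grad: fired.append(grad.sum().item()))
(trainable * 2).sum().backward()
print(fired)  # [6.0]

# Registering a hook directly on the frozen tensor is rejected outright.
try:
    frozen.register_hook(lambda grad: grad)
    frozen_hook_rejected = False
except RuntimeError:
    frozen_hook_rejected = True
print(frozen_hook_rejected)  # True
```

Flipping `requires_grad` to `True` is what makes the hook possible, which is why the decoder layers are unfrozen and then kept out of the optimizer rather than left frozen.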

Code Evidence

Double-wrapping and temporary unfreezing from `open_flamingo/src/flamingo.py:256-292`:

# unfreeze the decoder layers
for block in self.lang_encoder.old_decoder_blocks:
    block.requires_grad_(True)

# wrap in FSDP
with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
    self.perceiver = wrap(wrap(self.perceiver))
    self.lang_encoder.old_decoder_blocks = nn.ModuleList(
        wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
    )
    self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
        wrap(wrap(layer)) if layer is not None else None
        for layer in self.lang_encoder.gated_cross_attn_layers
    )
    ...

# exclude the original decoder layers from the optimizer
for block in self.lang_encoder.old_decoder_blocks:
    for p in block.parameters():
        p.exclude_from_optimizer = True

Known issues documented in `open_flamingo/src/flamingo.py:224-229`:

"""
Known issues:
- Our FSDP strategy is not compatible with tied embeddings. If the LM
  embeddings are tied, train with DDP or set --freeze_lm_embeddings.
- With FSDP + gradient ckpting, one can increase the batch size with
  seemingly no upper bound. Although the training curves look okay,
  we found that downstream performance dramatically degrades if the
  batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).
"""

Hybrid sharding bug warning from `open_flamingo/train/train.py:239-245`:

if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
    print(
        "Warning: As of torch=2.0.1, the FSDP logic for optim_state_dict() "
        "is broken for hybrid sharding."
        "To make this method work, we need to modify "
        "torch.distributed.fsdp._optim_utils.py"
    )
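
For context on what "hybrid" selects, the sketch below maps a strategy string to PyTorch's `ShardingStrategy` enum and raises the same caveat as a Python warning. The table and `resolve_sharding_strategy` are hypothetical: only the `"hybrid"` string appears in the excerpt above, and the `"full"` key is an assumption for illustration.

```python
import warnings

from torch.distributed.fsdp import ShardingStrategy

# Hypothetical CLI-string -> enum table. Only "hybrid" appears in the excerpt
# above; the "full" key is an assumption for illustration.
SHARDING_STRATEGIES = {
    "full": ShardingStrategy.FULL_SHARD,      # shard across all ranks
    "hybrid": ShardingStrategy.HYBRID_SHARD,  # shard within a node, replicate across nodes
}


def resolve_sharding_strategy(name: str) -> ShardingStrategy:
    """Map a strategy name to FSDP's enum, warning about the hybrid caveat."""
    if name not in SHARDING_STRATEGIES:
        raise ValueError(f"unknown FSDP sharding strategy: {name!r}")
    strategy = SHARDING_STRATEGIES[name]
    if strategy is ShardingStrategy.HYBRID_SHARD:
        warnings.warn(
            "As of torch==2.0.1, optim_state_dict() is reported to be broken "
            "for hybrid sharding; saving optimizer state may fail."
        )
    return strategy
```

Raising a `warnings.warn` instead of a bare `print` lets callers filter or escalate the message with the standard `warnings` machinery.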
