Principle:Mlfoundations Open flamingo FSDP Model Wrapping
Overview
A memory optimization technique that manually wraps model submodules for Fully Sharded Data Parallel (FSDP) training, working around the constraint that frozen and trainable parameters cannot be mixed within the same FSDP unit.
Description
FSDP shards model parameters across GPUs but, in its default flat-parameter mode (use_orig_params=False), requires all parameters within a wrapped module to have the same requires_grad status. OpenFlamingo's architecture combines frozen backbones (CLIP, LM) with trainable modules (Perceiver, cross-attention), so individual submodules must be wrapped manually instead of wrapping the entire model as a single unit. The wrap_fsdp method wraps the vision encoder, the Perceiver, each gated cross-attention layer, and the language model decoder layers separately, each with an appropriate mixed precision policy.
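The wrapping rule described above can be illustrated with a toy planner. This is a minimal sketch in plain Python with no PyTorch dependency, not the wrap_fsdp implementation: the Node class, the plan_units function, and every submodule name in the example tree are hypothetical stand-ins. It walks a module tree and returns the largest subtrees whose parameters all share one requires_grad value, which are the candidates for separate FSDP units.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    """Toy stand-in for a module: requires_grad flags of its direct
    parameters plus its named child modules (hypothetical structure)."""
    own_flags: List[bool] = field(default_factory=list)
    children: Dict[str, "Node"] = field(default_factory=dict)


def all_flags(node: Node) -> List[bool]:
    """Collect the requires_grad flags of every parameter under node."""
    flags = list(node.own_flags)
    for child in node.children.values():
        flags.extend(all_flags(child))
    return flags


def plan_units(node: Node, prefix: str = "") -> List[str]:
    """Return names of the largest subtrees with a uniform requires_grad.

    Direct parameters of a mixed node are not assigned to any unit here;
    in this sketch they are simply left for an outer wrapper.
    """
    flags = all_flags(node)
    if not flags:
        return []  # nothing to shard under this node
    if len(set(flags)) == 1:
        return [prefix or "<root>"]  # uniform: safe to wrap as one unit
    units: List[str] = []
    for name, child in node.children.items():
        path = f"{prefix}.{name}" if prefix else name
        units.extend(plan_units(child, path))
    return units


# Hypothetical tree loosely mirroring the frozen/trainable split above.
model = Node(children={
    "vision_encoder": Node(own_flags=[False, False]),      # frozen
    "perceiver": Node(own_flags=[True, True]),             # trainable
    "lang_encoder": Node(children={
        "decoder_layer_0": Node(own_flags=[False]),        # frozen
        "gated_cross_attn_0": Node(own_flags=[True]),      # trainable
        "decoder_layer_1": Node(own_flags=[False]),        # frozen
        "gated_cross_attn_1": Node(own_flags=[True]),      # trainable
    }),
})

units = plan_units(model)
for unit in units:
    print(unit)
```

In this toy tree the root and lang_encoder are both mixed, so the planner descends into them and yields six separate units, one per frozen or trainable submodule. A fully trainable or fully frozen tree would instead collapse to a single unit.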
Usage
Use this technique when training Flamingo models that are too large to fit in a single GPU's memory. FSDP shards parameters, gradients, and optimizer states across GPUs.
Theoretical Basis
FSDP implements ZeRO-3 style sharding: parameters, gradients, and optimizer states are sharded across ranks, and the full parameters of a unit are gathered only while that unit runs its forward or backward pass. The difficulty with mixing frozen and trainable parameters is that FSDP flattens all parameters of a unit into a single flat parameter, which carries one requires_grad setting for the whole unit. Manual wrapping places frozen parameters in separate FSDP units from trainable ones, which avoids allocating and computing gradients for frozen weights.
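The memory effect of this sharding can be estimated with simple arithmetic. The sketch below is a rough back-of-the-envelope model, assuming fp32 parameters and gradients (4 bytes each) and an Adam-style optimizer with two fp32 states per trainable parameter (8 bytes). The function per_gpu_bytes and the parameter counts are illustrative assumptions, not OpenFlamingo's actual sizes, and the estimate ignores activations, temporarily gathered units, and mixed precision.

```python
def per_gpu_bytes(frozen, trainable, world_size,
                  param_bytes=4, optim_bytes=8):
    """Estimate steady-state sharded memory per GPU, in bytes.

    frozen / trainable: parameter counts (hypothetical inputs).
    param_bytes: bytes per parameter and per gradient (fp32 assumed).
    optim_bytes: optimizer state bytes per trainable parameter
                 (two fp32 Adam moments assumed).
    """
    # Every parameter is sharded, frozen or not.
    params = (frozen + trainable) * param_bytes / world_size
    # Only trainable parameters carry gradients and optimizer state.
    grads = trainable * param_bytes / world_size
    optim = trainable * optim_bytes / world_size
    return {
        "params": params,
        "grads": grads,
        "optimizer": optim,
        "total": params + grads + optim,
    }


# Illustrative numbers: 8B frozen and 1B trainable parameters.
sharded = per_gpu_bytes(8_000_000_000, 1_000_000_000, world_size=8)
single = per_gpu_bytes(8_000_000_000, 1_000_000_000, world_size=1)

print(f"1 GPU : {single['total'] / 1e9:.1f} GB")
print(f"8 GPUs: {sharded['total'] / 1e9:.1f} GB per GPU")
```

Under these assumptions the frozen weights contribute only their sharded parameter storage, while gradient and optimizer memory scale with the trainable count alone. That is the saving that keeping frozen parameters in their own units is meant to preserve.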
Related Pages
Implementation:Mlfoundations_Open_flamingo_Flamingo_wrap_fsdp