Implementation:Mlfoundations Open flamingo Flamingo wrap fsdp
Overview
Concrete tool for manually wrapping OpenFlamingo submodules with FSDP while preserving the frozen/trainable parameter boundary provided by the Flamingo class.
Description
The Flamingo.wrap_fsdp() method individually wraps:
- vision_encoder.visual with FSDP
- perceiver with FSDP
- each gated cross-attention layer in lang_encoder with FSDP
- each decoder layer in lang_encoder with FSDP
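The per-submodule wrapping above can be sketched on a toy model. This is a minimal sketch, not OpenFlamingo's implementation. ToyFlamingo, IdentityWrapper, wrap_submodules, and has_uniform_requires_grad are hypothetical names. The wrapper is passed in as wrap_fn so the sketch runs on CPU without a process group; in real training, wrap_fn could be functools.partial(FSDP, **wrapper_kwargs).

```python
from typing import Callable

import torch.nn as nn


class ToyFlamingo(nn.Module):
    """Hypothetical stand-in with Flamingo-like submodules (not OpenFlamingo code)."""

    def __init__(self, dim: int = 8, n_layers: int = 2):
        super().__init__()
        self.vision_encoder = nn.Linear(dim, dim)
        self.perceiver = nn.Linear(dim, dim)
        self.gated_cross_attn_layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.decoder_layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        # Freeze everything, then unfreeze the perceiver and cross-attention layers
        self.requires_grad_(False)
        self.perceiver.requires_grad_(True)
        self.gated_cross_attn_layers.requires_grad_(True)


class IdentityWrapper(nn.Module):
    """Stand-in for FSDP so the sketch runs on CPU without a process group."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


def has_uniform_requires_grad(module: nn.Module) -> bool:
    """True if the parameters in `module` are all frozen or all trainable."""
    return len({p.requires_grad for p in module.parameters()}) <= 1


def wrap_submodules(model: ToyFlamingo, wrap_fn: Callable[[nn.Module], nn.Module]) -> None:
    """Wrap each submodule individually, in place, so no unit mixes frozen and trainable params."""
    model.vision_encoder = wrap_fn(model.vision_encoder)
    model.perceiver = wrap_fn(model.perceiver)
    model.gated_cross_attn_layers = nn.ModuleList(wrap_fn(m) for m in model.gated_cross_attn_layers)
    model.decoder_layers = nn.ModuleList(wrap_fn(m) for m in model.decoder_layers)


model = ToyFlamingo()
print(has_uniform_requires_grad(model))  # False: the full model mixes both kinds
wrap_submodules(model, wrap_fn=IdentityWrapper)
units = [m for m in model.modules() if isinstance(m, IdentityWrapper)]
print(len(units), all(has_uniform_requires_grad(u) for u in units))  # 6 True
```

The model as a whole mixes frozen and trainable parameters, but each unit created by wrap_submodules holds only one kind.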
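Parameters outside these FSDP-wrapped submodules are not managed by FSDP, so the method moves them to device_id itself. The wrapping is done by hand because an automatic wrapping policy could put frozen and trainable parameters into the same FSDP unit, which FSDP does not allow in its default flat-parameter mode (use_orig_params=False).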
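Moving the leftover parameters can be sketched as a recursive walk that stops at wrapped units and moves only the leaf modules outside them. This is an illustrative sketch under stated assumptions. WrappedUnit is a hypothetical stand-in for FSDP, move_unwrapped is not an OpenFlamingo function, and the walk assumes every parameter is registered on a leaf module, so a parent module's own direct parameters would be skipped.

```python
import torch
import torch.nn as nn


class WrappedUnit(nn.Module):
    """Hypothetical stand-in for an FSDP unit, which places its own parameters."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


def move_unwrapped(module: nn.Module, device: torch.device) -> None:
    """Move leaf modules to `device` without descending into wrapped units.

    Assumes every parameter is registered on a leaf module.
    """
    if isinstance(module, WrappedUnit):
        return  # stopping condition: the wrapper manages these parameters
    children = list(module.children())
    if not children:
        module.to(device)  # leaf outside any wrapper: move it explicitly
        return
    for child in children:
        move_unwrapped(child, device)


# "meta" stands in for the GPU so the sketch runs anywhere
model = nn.Sequential(WrappedUnit(nn.Linear(4, 4)), nn.Sequential(nn.LayerNorm(4)))
move_unwrapped(model, torch.device("meta"))
print(model[0].module.weight.device.type)  # cpu: left to the wrapper
print(model[1][0].weight.device.type)  # meta: moved by move_unwrapped
```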
Usage
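Call it after creating the model and before setting up the optimizer when training with FSDP. The distributed process group must already be initialized, because FSDP uses it to shard parameters.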
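The ordering matters because, per the PyTorch FSDP documentation, FSDP may not preserve the original parameter objects. An optimizer built before wrapping can therefore hold parameters that the wrapped model no longer uses. The toy below imitates that effect. replace_parameters is a hypothetical stand-in for FSDP that simply re-registers each parameter as a new object.

```python
import torch
import torch.nn as nn


def replace_parameters(module: nn.Module) -> None:
    """Hypothetical stand-in for FSDP: re-register each direct parameter as a new object."""
    for name, p in list(module.named_parameters(recurse=False)):
        setattr(module, name, nn.Parameter(p.detach().clone(), requires_grad=p.requires_grad))


layer = nn.Linear(4, 4)
early = torch.optim.AdamW(layer.parameters(), lr=1e-4)  # built before "wrapping"
replace_parameters(layer)
late = torch.optim.AdamW(layer.parameters(), lr=1e-4)  # built after "wrapping"

current = {id(p) for p in layer.parameters()}
print(any(id(p) in current for p in early.param_groups[0]["params"]))  # False: stale params
print(all(id(p) in current for p in late.param_groups[0]["params"]))  # True
```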
Code Reference
Source: Repository https://github.com/mlfoundations/open_flamingo, File: open_flamingo/src/flamingo.py Lines L202-301
Signature:
```python
def wrap_fsdp(self, wrapper_kwargs: dict, device_id: torch.device):
    """
    Manually wraps submodules for FSDP and moves other parameters to device_id.

    Args:
        wrapper_kwargs: dict of kwargs to pass to the FSDP constructor (e.g. mixed_precision, sharding_strategy)
        device_id: torch.device for non-FSDP parameters
    """
```
Import: wrap_fsdp is a method of the Flamingo model, called as model.wrap_fsdp(...). Flamingo models are created with from open_flamingo import create_model_and_transforms.
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| wrapper_kwargs | dict | Yes | Keyword arguments forwarded to the FSDP constructor, e.g. mixed_precision and sharding_strategy |
| device_id | torch.device | Yes | Target device for non-FSDP parameters |
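Because wrapper_kwargs is forwarded to the FSDP constructor, a misspelled key raises a TypeError only when the first FSDP unit is constructed. A pre-check can catch it earlier. check_wrapper_kwargs below is a hypothetical helper, not part of OpenFlamingo. It reads the accepted keyword names from the constructor signature with inspect.signature, so it needs no process group.

```python
import inspect

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy


def check_wrapper_kwargs(wrapper_kwargs: dict) -> None:
    """Hypothetical pre-check: reject keys the FSDP constructor does not accept."""
    accepted = set(inspect.signature(FSDP.__init__).parameters) - {"self", "module"}
    unknown = sorted(set(wrapper_kwargs) - accepted)
    if unknown:
        raise TypeError(f"not FSDP constructor arguments: {unknown}")


# Passes: both keys are FSDP constructor arguments
check_wrapper_kwargs({
    "mixed_precision": MixedPrecision(param_dtype=torch.bfloat16),
    "sharding_strategy": ShardingStrategy.FULL_SHARD,
})
```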
Outputs
| Output | Type | Description |
|---|---|---|
| model (self) | Flamingo, modified in place | Submodules wrapped with FSDP; remaining parameters moved to device_id |
Usage Examples
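FSDP wrapping with a bf16 mixed-precision policy (one process per GPU under torchrun; checkpoint names from the OpenFlamingo README):
```python
import os
import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from open_flamingo import create_model_and_transforms

# torchrun sets LOCAL_RANK for each process
torch.distributed.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device_id = torch.device("cuda", local_rank)

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
)

# bf16 for parameters, gradient reduction, and buffers
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
wrapper_kwargs = {
    "mixed_precision": mp_policy,
    "sharding_strategy": ShardingStrategy.FULL_SHARD,
    "device_id": device_id,  # FSDP moves each wrapped unit onto this GPU
}

# Wrap submodules with FSDP in place; other parameters are moved to device_id
model.wrap_fsdp(wrapper_kwargs=wrapper_kwargs, device_id=device_id)

# Build the optimizer only after wrapping, since FSDP replaces the parameters.
# Also skip params flagged exclude_from_optimizer, as OpenFlamingo's train.py does.
params = [p for p in model.parameters()
          if p.requires_grad and not getattr(p, "exclude_from_optimizer", False)]
optimizer = torch.optim.AdamW(params, lr=1e-4)
```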