

Implementation:Mlfoundations Open flamingo Flamingo wrap fsdp

From Leeroopedia



Overview

Concrete tool for manually wrapping OpenFlamingo submodules with FSDP while preserving the frozen/trainable parameter boundary provided by the Flamingo class.

Description

The Flamingo.wrap_fsdp() method individually wraps:

  1. vision_encoder.visual with FSDP
  2. perceiver with FSDP
  3. Each gated cross-attention layer in lang_encoder with FSDP
  4. Each decoder layer in lang_encoder with FSDP

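Parameters that are not inside any FSDP-wrapped submodule are moved to device_id manually. Wrapping by hand is necessary because an automatic wrapping policy would group frozen and trainable parameters into the same FSDP unit, and with its default use_orig_params=False, FSDP requires every parameter flattened into one unit to have the same requires_grad value.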
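The requires_grad constraint and the manual device move can both be illustrated without FSDP. This is a sketch, not the OpenFlamingo implementation: `has_uniform_requires_grad`, `move_unwrapped_leaves`, `Wrapped`, and `Block` are hypothetical, `Wrapped` stands in for an FSDP unit, and CPU stands in for the CUDA device_id. The move walks the module tree and stops at wrapped units, since FSDP places its own parameters.

```python
import torch
import torch.nn as nn


def has_uniform_requires_grad(module: nn.Module) -> bool:
    """True if all of the module's parameters are trainable, or all are frozen."""
    return len({p.requires_grad for p in module.parameters()}) <= 1


def move_unwrapped_leaves(module, device, wrapped_types, prefix=""):
    """Move leaf modules to `device` without descending into wrapped units.

    Returns the dotted names of the leaf modules that were moved.
    """
    if isinstance(module, wrapped_types):
        return []  # in real use, FSDP manages placement of its own parameters
    children = list(module.named_children())
    if not children:
        module.to(device)
        return [prefix]
    moved = []
    for name, child in children:
        moved += move_unwrapped_leaves(
            child, device, wrapped_types, f"{prefix}.{name}" if prefix else name
        )
    return moved


class Wrapped(nn.Module):
    """Hypothetical stand-in for an FSDP-wrapped unit."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


class Block(nn.Module):
    """Toy language-model block mixing a frozen and a trainable layer."""

    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(4, 4).requires_grad_(False)  # frozen
        self.cross_attn = nn.Linear(4, 4)  # trainable
        self.norm = nn.LayerNorm(4)  # trainable, never wrapped


block = Block()
print(has_uniform_requires_grad(block))  # False: one unit here would mix both kinds
print(has_uniform_requires_grad(block.decoder))  # True
print(has_uniform_requires_grad(block.cross_attn))  # True

# Wrap the two layers individually, then place whatever is left over.
block.decoder = Wrapped(block.decoder)
block.cross_attn = Wrapped(block.cross_attn)
moved = move_unwrapped_leaves(block, torch.device("cpu"), wrapped_types=(Wrapped,))
print(moved)  # ['norm']
```

This simplification only moves leaf modules, so a parameter registered directly on a non-leaf module would be skipped.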

Usage

Call this method after creating the model and before setting up the optimizer when training with FSDP.
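The ordering matters because, with FSDP's default use_orig_params=False, wrapping replaces each wrapped unit's parameters with new flattened parameter objects, so an optimizer built beforehand would keep updating tensors the model no longer uses. The sketch below loosely mimics that effect without FSDP by re-registering one weight as a new nn.Parameter; `optimizer_tracks_model` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def optimizer_tracks_model(optimizer: torch.optim.Optimizer, model: nn.Module) -> bool:
    """True if every tensor the optimizer updates is still a parameter of the model."""
    live = {id(p) for p in model.parameters()}
    return all(id(p) in live for group in optimizer.param_groups for p in group["params"])


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
early = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Simulate a wrapper that swaps in new parameter objects (FSDP's flattening has this effect).
model[0].weight = nn.Parameter(model[0].weight.detach().clone())

late = torch.optim.AdamW(model.parameters(), lr=1e-4)
print(optimizer_tracks_model(early, model))  # False: still points at the old weight
print(optimizer_tracks_model(late, model))  # True
```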

Code Reference

Source: Repository https://github.com/mlfoundations/open_flamingo, File: open_flamingo/src/flamingo.py Lines L202-301

Signature:

def wrap_fsdp(self, wrapper_kwargs: dict, device_id: torch.device):
    """
    Manually wraps submodules for FSDP and moves other parameters to device_id.

    Args:
        wrapper_kwargs: dict of kwargs to pass to FSDP constructor (e.g. mixed_precision, sharding_strategy)
        device_id: torch.device for non-FSDP parameters
    """

Import: method on the Flamingo model, called as model.wrap_fsdp(...). Flamingo instances are created with create_model_and_transforms (from open_flamingo import create_model_and_transforms).

I/O Contract

Inputs

Parameter | Type | Required | Description
wrapper_kwargs | dict | Yes | Keyword arguments forwarded to the FSDP constructor, e.g. mixed_precision and sharding_strategy
device_id | torch.device | Yes | Target device for parameters that are not inside an FSDP-wrapped submodule
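One way to assemble wrapper_kwargs from a precision setting is sketched below. `build_wrapper_kwargs` is a hypothetical helper, not part of OpenFlamingo. MixedPrecision, ShardingStrategy, and the mixed_precision, sharding_strategy, and device_id keywords come from PyTorch's FSDP API. Forwarding device_id into the FSDP constructor as well is an assumption about how the two arguments are typically combined.

```python
import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy


def build_wrapper_kwargs(precision: str, device_id: torch.device, shard: bool = True) -> dict:
    """Hypothetical helper: map a precision name to FSDP constructor kwargs."""
    dtypes = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    if precision not in dtypes:
        raise ValueError(f"unsupported precision: {precision!r}")
    dtype = dtypes[precision]
    return {
        # Same dtype for parameters, gradient reduction, and buffers
        "mixed_precision": MixedPrecision(
            param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype
        ),
        # FULL_SHARD shards params, grads, and optimizer state; NO_SHARD behaves like DDP
        "sharding_strategy": ShardingStrategy.FULL_SHARD if shard else ShardingStrategy.NO_SHARD,
        # Lets FSDP place each wrapped submodule on this device
        "device_id": device_id,
    }


kwargs = build_wrapper_kwargs("bf16", torch.device("cpu"))
print(kwargs["mixed_precision"].param_dtype)  # torch.bfloat16
```

In a real run, device_id would be torch.device("cuda", local_rank), and the same device would be passed as the device_id argument of model.wrap_fsdp.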

Outputs

Output | Type | Description
model (modified in place) | Flamingo | Model whose listed submodules are FSDP-wrapped, with all remaining parameters on device_id
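To confirm the in-place result, the model can be walked once, recording the outermost wrapped submodules and any unwrapped parameter that is not on the expected device. `summarize_wrapping` and `Wrapped` are hypothetical; in a real run you would pass FullyShardedDataParallel as wrapper_cls and an indexed device such as torch.device("cuda", local_rank), since torch.device("cuda") does not compare equal to cuda:0.

```python
import torch
import torch.nn as nn


def summarize_wrapping(model: nn.Module, wrapper_cls: type, device: torch.device):
    """Return (outermost wrapped submodule names, unwrapped param names not on `device`)."""
    wrapped, misplaced = [], []

    def visit(module: nn.Module, prefix: str) -> None:
        if isinstance(module, wrapper_cls):
            wrapped.append(prefix)
            return  # parameters inside are managed by the wrapper
        for name, p in module.named_parameters(recurse=False):
            if p.device != device:
                misplaced.append(f"{prefix}.{name}" if prefix else name)
        for name, child in module.named_children():
            visit(child, f"{prefix}.{name}" if prefix else name)

    visit(model, "")
    return wrapped, misplaced


class Wrapped(nn.Module):
    """Stand-in for FSDP so the sketch runs without a process group."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


model = nn.Module()
model.perceiver = Wrapped(nn.Linear(4, 4))
model.layers = nn.ModuleList([Wrapped(nn.Linear(4, 4)), Wrapped(nn.Linear(4, 4))])
model.norm = nn.LayerNorm(4)

wrapped, misplaced = summarize_wrapping(model, Wrapped, torch.device("cpu"))
print(wrapped)  # ['perceiver', 'layers.0', 'layers.1']
print(misplaced)  # []
```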

Usage Examples

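FSDP wrapping with a bf16 mixed-precision policy. The example assumes the script is launched with torchrun; the checkpoint names passed to create_model_and_transforms are illustrative:

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from open_flamingo import create_model_and_transforms

# torchrun sets LOCAL_RANK; FSDP needs an initialized process group
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
device_id = torch.device("cuda", local_rank)

# Build the model (arguments as in the open_flamingo README; substitute your own checkpoints)
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
)

# Define mixed precision policy with bf16
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

wrapper_kwargs = {
    "mixed_precision": mp_policy,
    "sharding_strategy": ShardingStrategy.FULL_SHARD,
    "device_id": device_id,  # FSDP moves each wrapped submodule to this GPU
}

# Wrap the model submodules with FSDP (modifies model in place)
model.wrap_fsdp(wrapper_kwargs=wrapper_kwargs, device_id=device_id)

# Set up the optimizer only after wrapping, so it holds FSDP's flattened parameters.
# The getattr check skips parameters tagged exclude_from_optimizer (a no-op if none are).
optimizer = torch.optim.AdamW(
    [
        p
        for p in model.parameters()
        if p.requires_grad and not getattr(p, "exclude_from_optimizer", False)
    ],
    lr=1e-4,
)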

Related Pages

Principle:Mlfoundations_Open_flamingo_FSDP_Model_Wrapping
