Workflow:Mlfoundations Open flamingo Distributed Training
| Knowledge Sources | |
|---|---|
| Domains | Vision_Language_Models, Distributed_Training, LLM_Ops |
| Last Updated | 2026-02-08 03:30 GMT |
Overview
End-to-end process for training an OpenFlamingo vision-language model on dual datasets (LAION image-text pairs and MMC4 interleaved image-text documents) using distributed data parallelism or fully sharded data parallelism.
Description
This workflow covers the complete training pipeline for OpenFlamingo models, from environment setup through distributed training to checkpoint saving. The training process uses a dual-dataset strategy: LAION provides single image-caption pairs for basic vision-language alignment, while MMC4 (Multimodal C4) provides interleaved multi-image documents for learning in-context capabilities. Both datasets are served through WebDataset pipelines. The model's selective freezing strategy trains only the Perceiver Resampler, gated cross-attention layers, and optionally the new special token embeddings, while keeping the vision encoder and language model backbone frozen. The training supports both DDP and FSDP distributed strategies with mixed precision, gradient checkpointing, and W&B logging.
Usage
Execute this workflow when you want to train an OpenFlamingo model from scratch or continue training from a checkpoint. You need access to LAION image-text shards and MMC4 interleaved document shards in WebDataset tar format, along with multi-GPU compute resources. The workflow is typically launched via torchrun or SLURM.
Execution Steps
Step 1: Prepare Environment And Data
Set up the training environment with all required dependencies including PyTorch, OpenCLIP, Transformers, WebDataset, and optionally W&B for logging. Prepare the training data as WebDataset tar shards: LAION shards containing image-text pairs, and MMC4 shards containing interleaved multi-image documents with base64-encoded images. Configure SLURM or torchrun for multi-GPU distributed execution.
Key considerations:
- Install the training extras via pip install open-flamingo[training]
- LAION shards should be in standard WebDataset format (image + text per sample)
- MMC4 shards need to be pre-converted using the convert_mmc4_to_wds script
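- Ensure LAION and MMC4 yield the same number of batches per epoch: samples per epoch divided by the global batch size (per-GPU batch size × number of GPUs) must match across the two datasets, so adjust the per-dataset sample counts and batch sizes together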
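Each training step draws one LAION batch and one MMC4 batch, so both loaders must run out at the same step. The sketch below is a minimal pre-flight check, assuming the per-epoch sample counts and per-GPU batch sizes are known before launch. The helper names are hypothetical and are not part of the OpenFlamingo codebase.

```python
def batches_per_epoch(num_samples: int, batch_size: int, world_size: int) -> int:
    """Number of full global batches one dataset yields per epoch."""
    global_batch = batch_size * world_size
    return num_samples // global_batch


def check_epoch_balance(
    samples_laion: int,
    batch_size_laion: int,
    samples_mmc4: int,
    batch_size_mmc4: int,
    world_size: int,
) -> int:
    """Raise if LAION and MMC4 would run out of batches at different steps.

    Returns the shared number of steps per epoch when the two are balanced.
    """
    laion = batches_per_epoch(samples_laion, batch_size_laion, world_size)
    mmc4 = batches_per_epoch(samples_mmc4, batch_size_mmc4, world_size)
    if laion != mmc4:
        raise ValueError(
            f"unbalanced epoch: LAION gives {laion} batches, MMC4 gives {mmc4}"
        )
    return laion


# 8 GPUs, LAION batch 32, MMC4 batch 16: LAION needs twice as many samples.
steps = check_epoch_balance(
    samples_laion=2_048_000, batch_size_laion=32,
    samples_mmc4=1_024_000, batch_size_mmc4=16,
    world_size=8,
)
print(steps)  # 8000
```

Running this check before submitting a job avoids discovering an imbalance only after one loader is exhausted mid-epoch.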
Step 2: Initialize Model
Create the OpenFlamingo model using the factory function with the desired vision encoder (e.g., ViT-L-14) and language model (e.g., MPT-1B). The factory function handles CLIP loading, language model loading, mixin injection for cross-attention layers, special token addition, and parameter freezing. Optionally enable gradient checkpointing for memory efficiency.
Key considerations:
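- The cross_attn_every_n_layers parameter sets how often a gated cross-attention layer is inserted into the language model: smaller values add more cross-attention layers (more trainable parameters and compute, typically better performance), while larger values make the model cheaper to train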
- Gradient checkpointing trades compute for memory, enabling larger batch sizes
- The freeze_lm_embeddings flag controls whether the <image> and <|endofchunk|> embeddings are trained
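- Trainable parameters: only the Perceiver Resampler and the gated cross-attention layers (plus the special-token embeddings when freeze_lm_embeddings is off); the vision encoder and language model backbone stay frozen. The trainable share grows as cross_attn_every_n_layers decreases, since each added layer is a full cross-attention plus feed-forward block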
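To make the effect of cross_attn_every_n_layers concrete, the sketch below lists which decoder layers would receive a gated cross-attention block. It assumes a block is attached after every n-th decoder layer counting from 1, that is, to layers whose zero-based index i satisfies (i + 1) % n == 0. Treat this placement rule as an assumption and check it against the model code. The function name is hypothetical.

```python
def cross_attn_layer_indices(num_decoder_layers: int, every_n: int) -> list[int]:
    """Zero-based indices of decoder layers that get a gated cross-attention block.

    Assumes a block is inserted whenever (layer_idx + 1) is a multiple of every_n.
    """
    if every_n < 1:
        raise ValueError("cross_attn_every_n_layers must be >= 1")
    return [i for i in range(num_decoder_layers) if (i + 1) % every_n == 0]


# Compare a 24-layer decoder under three settings.
for n in (1, 2, 4):
    layers = cross_attn_layer_indices(num_decoder_layers=24, every_n=n)
    print(f"every_n={n}: {len(layers)} cross-attention blocks, first few {layers[:3]}")
```

With an instantiated PyTorch model, sum(p.numel() for p in model.parameters() if p.requires_grad) gives the actual trainable parameter count for a given setting.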
Step 3: Set Up Distributed Training
Initialize the distributed training backend and wrap the model for distributed execution. For DDP, the model is moved to the assigned GPU and wrapped with DistributedDataParallel. For FSDP, the model undergoes custom manual wrapping that respects the mixed frozen/unfrozen parameter requirements: each submodule (Perceiver, cross-attention layers, decoder layers, embeddings) is individually wrapped with appropriate sharding. Configure mixed precision policies for memory-efficient training.
Key considerations:
- Under FSDP, setting fsdp_use_orig_params=True is what enables separate weight-decay parameter groups and gradient masking; this option does not work with OPT models
- FSDP uses double-wrapping to ensure proper memory management in post-forward/backward hooks
- Frozen decoder layers are unfrozen for FSDP compatibility but excluded from the optimizer
- Hybrid sharding strategy requires patched _optim_utils.py for optimizer state management
- Supports torchrun, SLURM, Horovod, and manual DDP initialization
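Each launcher tells a process its position in the job through environment variables. The sketch below assumes the standard torchrun variables (RANK, LOCAL_RANK, WORLD_SIZE) and the SLURM task variables (SLURM_PROCID, SLURM_LOCALID, SLURM_NTASKS). It resolves them into a (rank, local_rank, world_size) triple that would then be passed to torch.distributed.init_process_group and used to select the GPU. The helper is hypothetical, omits Horovod, and is not the project's own initialization code.

```python
import os
from typing import Mapping, NamedTuple, Optional


class DistInfo(NamedTuple):
    rank: int        # global process index across all nodes
    local_rank: int  # process index on this node (selects the GPU)
    world_size: int  # total number of processes


def resolve_dist_info(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read rank information set by torchrun or SLURM; default to one process."""
    env = os.environ if env is None else env
    if "WORLD_SIZE" in env and "RANK" in env:  # torchrun sets these for every worker
        return DistInfo(
            rank=int(env["RANK"]),
            local_rank=int(env.get("LOCAL_RANK", 0)),
            world_size=int(env["WORLD_SIZE"]),
        )
    if "SLURM_PROCID" in env:  # tasks launched directly by srun
        return DistInfo(
            rank=int(env["SLURM_PROCID"]),
            local_rank=int(env.get("SLURM_LOCALID", 0)),
            world_size=int(env.get("SLURM_NTASKS", 1)),
        )
    return DistInfo(rank=0, local_rank=0, world_size=1)  # plain single-process run


print(resolve_dist_info({"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8"}))
print(resolve_dist_info({"SLURM_PROCID": "3", "SLURM_LOCALID": "3", "SLURM_NTASKS": "16"}))
```

Giving torchrun variables precedence means a torchrun job started inside a SLURM allocation still uses torchrun's rank assignment.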
Step 4: Configure Optimizer And Scheduler
Set up the AdamW optimizer with differential weight decay: cross-attention layer parameters receive the configured weight decay, while the other trainable parameters (Perceiver, embeddings) receive zero weight decay. Initialize the learning rate scheduler (constant, linear, or cosine, each with warmup). Optionally resume optimizer and scheduler states from a checkpoint.
Key considerations:
- Weight decay is restricted to gated_cross_attn parameters with DDP, or with FSDP when fsdp_use_orig_params is set
- Under FSDP without fsdp_use_orig_params, the same weight decay is applied uniformly to all trainable parameters (suboptimal)
- Warmup steps help stabilize early training
- Total training steps are computed as (samples per epoch ÷ global batch size) × number of epochs, where the global batch size is the per-GPU batch size times the number of GPUs
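The parameter grouping and the step count can be sketched without PyTorch by working on parameter names alone. This assumes, as described in this step, that decay applies only to parameters whose name contains gated_cross_attn, and that the run length is steps per epoch times epochs with no gradient accumulation. The helper names and the example parameter names are hypothetical. In a real run the groups would hold tensors and be passed to torch.optim.AdamW, whose constructor accepts a list of parameter-group dicts with per-group weight_decay.

```python
def group_for_weight_decay(param_names, weight_decay: float):
    """Split trainable parameter names into AdamW-style parameter groups.

    Only gated cross-attention parameters receive weight decay; everything else
    (Perceiver Resampler, special-token embeddings) gets weight_decay=0.0.
    """
    decay, no_decay = [], []
    for name in param_names:
        (decay if "gated_cross_attn" in name else no_decay).append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


def total_training_steps(samples_per_epoch: int, batch_size: int,
                         world_size: int, num_epochs: int) -> int:
    """Optimizer steps for the whole run, one step per global batch."""
    return (samples_per_epoch // (batch_size * world_size)) * num_epochs


# Illustrative (hypothetical) parameter names.
names = [
    "perceiver.latents",
    "lang_encoder.gated_cross_attn_layers.3.attn.to_q.weight",
    "lang_encoder.transformer.wte.weight",
]
groups = group_for_weight_decay(names, weight_decay=0.1)
print([len(g["params"]) for g in groups])  # [1, 2]
print(total_training_steps(1_024_000, 16, 8, num_epochs=10))  # 80000
```

The warmup length for the scheduler is then chosen as a number of steps within this total.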
Step 5: Load Data Pipelines
Initialize the dual WebDataset data pipelines for LAION and MMC4. The LAION pipeline loads image-text pairs, applies image augmentation (random horizontal flip), and tokenizes captions with the Flamingo prompt format. The MMC4 pipeline loads interleaved documents, selects images based on text-similarity thresholds (using the Hungarian algorithm for alignment), and constructs multi-image sequences with proper special token formatting.
Key considerations:
- LAION captions are formatted as "<image>{caption}<|endofchunk|>{eos_token}" and truncated to 32 tokens
- MMC4 sequences contain multiple images with text between them, bounded by max_num_images
- Image-text similarity filtering in MMC4 uses a configurable threshold (mmc4_textsim_threshold)
- Both datasets support deterministic shard shuffling for reproducibility across epochs
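- S3 shards can be streamed through WebDataset "pipe:" URLs, which run a shell command (for example, aws s3 cp followed by the shard URL and -) and read the shard from its standard output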
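The sketch below shows three string- and list-level pieces of the data pipelines: the LAION prompt format, a similarity-threshold filter for MMC4 images, and an S3 "pipe:" shard URL. It assumes that each MMC4 image record carries a "matched_sim" score and that kept images stay in document order, capped at max_num_images. The real pipeline also aligns images to sentences and tokenizes the result, both of which are omitted here. All function names are hypothetical.

```python
def format_laion_caption(caption: str, eos_token: str) -> str:
    """Wrap one caption in the Flamingo prompt format (before tokenization)."""
    return f"<image>{caption.strip()}<|endofchunk|>{eos_token}"


def select_mmc4_images(image_info, sim_threshold: float, max_num_images: int):
    """Keep images whose text-similarity score clears the threshold.

    image_info is assumed to be a list of dicts with a "matched_sim" score;
    images are kept in document order and capped at max_num_images.
    """
    kept = [img for img in image_info if img["matched_sim"] >= sim_threshold]
    return kept[:max_num_images]


def s3_pipe_url(s3_url: str) -> str:
    """WebDataset 'pipe:' URL that streams a shard from S3 via the AWS CLI."""
    return f"pipe:aws s3 cp {s3_url} -"


print(format_laion_caption("  a dog on a beach ", "</s>"))
# <image>a dog on a beach<|endofchunk|></s>

images = [
    {"id": 0, "matched_sim": 0.31},
    {"id": 1, "matched_sim": 0.18},
    {"id": 2, "matched_sim": 0.27},
    {"id": 3, "matched_sim": 0.40},
]
print([img["id"] for img in select_mmc4_images(images, 0.24, max_num_images=2)])  # [0, 2]
```

The formatted LAION string would then be tokenized with truncation to 32 tokens. The cap is applied after filtering, so a document can contribute fewer than max_num_images images but never more.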
Step 6: Execute Training Loop
Run the training loop over epochs, processing one LAION batch and one MMC4 batch in each step. For each pair of batches, compute the LAION loss on single image-caption pairs, then compute the MMC4 loss on interleaved sequences with careful label masking: loss is taken only on the text that follows each image, not on text between an <|endofchunk|> and the next image, and not on padding. Accumulate gradients, apply gradient clipping, mask embedding gradients so that only the special tokens are updated, and step the optimizer. Log metrics to W&B, including per-dataset losses and throughput.
Key considerations:
- Labels are masked so loss is computed only on text tokens following <image> tokens
- In MMC4, text between <|endofchunk|> and the next <image> is also masked from loss
- Gradient clipping is applied with max_norm=1.0 (FSDP uses per-module clipping)
- Embedding gradients are masked to only update <image> and <|endofchunk|> tokens
- NaN loss detection skips problematic MMC4 batches (not FSDP-compatible)
- Loss multipliers control the relative weight of LAION vs. MMC4 objectives
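The MMC4 masking rules above can be sketched on a single list of token ids. The sketch marks ignored positions with -100, the default ignore_index of PyTorch's cross-entropy loss. It masks padding, every <image> token, all text before the first <image>, and any text after an <|endofchunk|> up to the next <image>; the <|endofchunk|> token itself stays as a target. The token ids in the example are made up, and the function names are hypothetical. The loss helper sums the weighted objectives before a single backward pass. By linearity, this yields the same gradients as calling backward on each scaled loss separately.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_mmc4_labels(input_ids, media_id, eoc_id, pad_id):
    """Return labels for one interleaved sequence with non-target tokens ignored."""
    labels = list(input_ids)
    n = len(labels)
    # Text before the first <image> has no image to condition on.
    i = 0
    while i < n and input_ids[i] != media_id:
        labels[i] = IGNORE_INDEX
        i += 1
    # Text after each <|endofchunk|> is masked until the next <image>.
    for j, tok in enumerate(input_ids):
        if tok == eoc_id:
            k = j + 1
            while k < n and input_ids[k] != media_id:
                labels[k] = IGNORE_INDEX
                k += 1
    # Padding and the <image> placeholders never contribute to the loss.
    for j, tok in enumerate(input_ids):
        if tok in (pad_id, media_id):
            labels[j] = IGNORE_INDEX
    return labels


def combined_loss(loss_laion, loss_mmc4, mult_laion=1.0, mult_mmc4=1.0, accum_steps=1):
    """Weighted sum of the two objectives, scaled for gradient accumulation."""
    return (mult_laion * loss_laion + mult_mmc4 * loss_mmc4) / accum_steps


# Hypothetical ids: pad=0, <image>=1, <|endofchunk|>=2, eos=3, other ids are text.
ids = [5, 1, 7, 8, 2, 9, 1, 10, 2, 3, 0, 0]
print(mask_mmc4_labels(ids, media_id=1, eoc_id=2, pad_id=0))
# [-100, -100, 7, 8, 2, -100, -100, 10, 2, -100, -100, -100]
```

Only the caption-like chunks (7, 8 and 10) and their closing <|endofchunk|> tokens remain as targets.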
Step 7: Save Checkpoints
Save model, optimizer, and learning rate scheduler state at the end of each epoch. For FSDP, use the full state dict protocol to gather sharded parameters to rank 0 with CPU offloading. Filter the state dict to only include trainable parameters (plus embeddings for consistency). Optionally save checkpoints to W&B and delete previous checkpoints to save disk space.
Key considerations:
- Only rank 0 saves the checkpoint to avoid redundant writes
- The state dict is filtered to exclude frozen parameters (reduces checkpoint size significantly)
- FSDP checkpointing uses FullStateDictConfig with rank0_only and offload_to_cpu
- Checkpoints include epoch number for automatic resume detection
- Resume logic searches for existing checkpoints in the run directory
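The two bookkeeping pieces of this step, trimming the state dict and locating the latest checkpoint for resume, can be sketched in plain Python. This assumes checkpoints are named checkpoint_<epoch>.pt and that the set of trainable parameter names is known. Both the naming scheme and the helper names are assumptions made for illustration. Under FSDP, the dict passed to the filter would first be gathered inside FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)).

```python
import os
import re


def filter_state_dict(state_dict, trainable, always_keep=()):
    """Drop frozen entries; always_keep lists extra names (e.g. embeddings) to retain."""
    keep = set(trainable) | set(always_keep)
    return {k: v for k, v in state_dict.items() if k in keep}


_CKPT_RE = re.compile(r"^checkpoint_(\d+)\.pt$")


def most_recent_checkpoint(filenames):
    """Return the checkpoint_<epoch>.pt path with the highest epoch, or None."""
    best, best_epoch = None, -1
    for name in filenames:
        m = _CKPT_RE.match(os.path.basename(name))
        # Compare epochs numerically so that epoch 10 beats epoch 9.
        if m and int(m.group(1)) > best_epoch:
            best, best_epoch = name, int(m.group(1))
    return best


files = ["run/checkpoint_2.pt", "run/checkpoint_10.pt", "run/checkpoint_9.pt", "run/log.txt"]
print(most_recent_checkpoint(files))  # run/checkpoint_10.pt
```

Parsing the epoch as an integer matters: a plain lexicographic sort would rank checkpoint_9.pt above checkpoint_10.pt and resume from the wrong epoch.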