Implementation:Mlfoundations Open flamingo Train one epoch
Overview
Concrete tool, provided by the OpenFlamingo training module, for running one epoch of dual-dataset training with loss masking, gradient accumulation, and mixed precision.
Description
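The train_one_epoch() function iterates over paired LAION and MMC4 batches. At each step it:
- Runs a forward pass on the LAION batch under autocast and computes the masked loss on the caption tokens.
- Runs a forward pass on the MMC4 batch under autocast and computes the masked loss on the interleaved text.
- Backpropagates both losses, accumulating gradients across steps.
- Clips the gradient norm to a maximum of 1.0.
- Steps the optimizer (and learning-rate scheduler) once gradients have been accumulated for the configured number of steps.
- Logs the losses to wandb.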
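The accumulate, clip, and step schedule above can be sketched without PyTorch as a toy loop over lists of floats. Everything here is hypothetical: `train_one_epoch_sketch`, `batch_grads` (standing in for the gradients that `backward()` on the LAION and MMC4 losses would produce), and `accum_steps` (standing in for the gradient-accumulation setting). Plain SGD replaces AdamW, and forward passes, autocast, and the learning-rate scheduler are omitted. The clipping rule follows `torch.nn.utils.clip_grad_norm_`, which rescales by `max_norm / (total_norm + 1e-6)` when that factor is below 1.

```python
import math
from typing import List, Sequence, Tuple


def clip_grad_norm(grads: List[float], max_norm: float = 1.0, eps: float = 1e-6) -> float:
    """Rescale grads in place so their global L2 norm is at most max_norm.

    Same rule as torch.nn.utils.clip_grad_norm_:
    coef = max_norm / (total_norm + eps), applied only when coef < 1.
    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + eps)
    if coef < 1.0:
        for i in range(len(grads)):
            grads[i] *= coef
    return total_norm


def train_one_epoch_sketch(
    params: List[float],
    batch_grads: Sequence[Tuple[List[float], List[float]]],
    accum_steps: int,
    lr: float,
    max_norm: float = 1.0,
) -> List[int]:
    """Hypothetical toy epoch: plain SGD on a flat list of parameters.

    batch_grads[k] is a (laion_grad, mmc4_grad) pair for step k.
    Updates params in place and returns the step indices where it stepped.
    """
    grads = [0.0] * len(params)
    stepped_at = []
    for step, (g_laion, g_mmc4) in enumerate(batch_grads):
        # "Backward" both losses, each scaled by 1/accum_steps so the
        # accumulated gradient averages over the accumulation window.
        for i in range(len(params)):
            grads[i] += (g_laion[i] + g_mmc4[i]) / accum_steps
        # Step at the end of each window, and on the last batch so a partial
        # window is not dropped (an assumption of this sketch).
        if (step + 1) % accum_steps == 0 or step == len(batch_grads) - 1:
            clip_grad_norm(grads, max_norm)  # clip once, right before the update
            for i in range(len(params)):
                params[i] -= lr * grads[i]
            grads = [0.0] * len(params)  # equivalent of optimizer.zero_grad()
            stepped_at.append(step)
    return stepped_at
```

With five batches and `accum_steps=2`, the sketch steps after batches 1, 3 and 4 (zero-based), the last update closing a partial window.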
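Label masking sets the label to -100 (the ignore_index that PyTorch's cross-entropy loss skips by default) for tokens before the first <image> token and for <PAD> tokens, so those positions do not contribute to the loss.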
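The masking rule can be sketched in plain Python on a single token sequence. The function names `mask_labels` and `masked_mean_loss`, the token ids in the example, and the list representation are hypothetical: the real code works on batched tensors, and the one-position shift a causal language model applies to labels is omitted here. `masked_mean_loss` shows why -100 works as a mask: positions carrying that label are simply left out of the average.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_labels(input_ids: List[int], media_token_id: int, pad_token_id: int) -> List[int]:
    """Build labels from token ids, masking positions excluded from the loss."""
    labels = list(input_ids)
    # Mask every position before the first <image> token
    # (the whole sequence if it contains no <image> token).
    for i, tok in enumerate(input_ids):
        if tok == media_token_id:
            break
        labels[i] = IGNORE_INDEX
    # Mask <PAD> positions anywhere in the sequence.
    return [IGNORE_INDEX if tok == pad_token_id else lab
            for tok, lab in zip(input_ids, labels)]


def masked_mean_loss(token_losses: List[float], labels: List[int]) -> float:
    """Average per-position losses, skipping positions labelled IGNORE_INDEX.

    This sketch returns 0.0 when every position is masked.
    """
    kept = [loss for loss, lab in zip(token_losses, labels) if lab != IGNORE_INDEX]
    return sum(kept) / len(kept) if kept else 0.0
```

For example, `mask_labels([5, 7, 32000, 9, 11, 0, 0], media_token_id=32000, pad_token_id=0)` yields `[-100, -100, 32000, 9, 11, -100, -100]`, so only the last three non-padding positions enter the average.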
Usage
Called once per epoch from the main training loop, after the data loaders and optimizer have been set up.
Code Reference
- Source
- Repository: https://github.com/mlfoundations/open_flamingo
- File: open_flamingo/train/train_utils.py, lines 46-278
- Signature
```python
def train_one_epoch(
    args,
    model,
    epoch,
    laion_loader,
    mmc4_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    """
    Train for one epoch on LAION + MMC4 datasets.
    """
```
- Import
```python
from open_flamingo.train.train_utils import train_one_epoch
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| args | argparse.Namespace | Yes | Training config |
| model | Flamingo | Yes | FSDP/DDP wrapped model |
| epoch | int | Yes | Current epoch number |
| laion_loader | DataLoader | Yes | LAION WebLoader |
| mmc4_loader | DataLoader | Yes | MMC4 WebLoader |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer for loss masking |
| optimizer | AdamW | Yes | Optimizer |
| lr_scheduler | LRScheduler | Yes | Learning rate scheduler |
| device_id | int | Yes | GPU device ID |
| wandb | module | Yes | Weights and Biases logging module or None |
Outputs
- Model weights updated in-place.
- LAION and MMC4 losses logged to wandb per step (when a wandb module is passed rather than None).
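Because `wandb` may be None, the logging call has to be guarded. A minimal sketch of that guard follows; the helper `log_step_losses` and the metric keys `loss_laion` and `loss_mmc4` are hypothetical, and only `wandb.log(data, step=...)` is assumed from the wandb API.

```python
from typing import Any, Optional


def log_step_losses(
    wandb: Optional[Any], global_step: int, loss_laion: float, loss_mmc4: float
) -> bool:
    """Hypothetical helper: log both losses if a wandb module was passed.

    Returns True if something was logged, False when logging is disabled.
    """
    if wandb is None:
        return False  # logging disabled; training proceeds unchanged
    # Metric keys are illustrative, not necessarily those OpenFlamingo uses.
    wandb.log({"loss_laion": loss_laion, "loss_mmc4": loss_mmc4}, step=global_step)
    return True
```

Passing any object with a compatible `log` method (for example, a test stub) in place of the wandb module exercises the same path.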
Usage Examples
```python
for epoch in range(args.num_epochs):
    train_one_epoch(
        args=args,
        model=model,
        epoch=epoch,
        laion_loader=laion_loader,
        mmc4_loader=mmc4_loader,
        tokenizer=tokenizer,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        device_id=device_id,
        wandb=wandb,
    )
```
Related Pages
Principle:Mlfoundations_Open_flamingo_Dual_Dataset_Training_Loop