Heuristic:Mlfoundations Open flamingo Loss Masking Strategy
| Knowledge Sources | |
|---|---|
| Domains | LLMs, Optimization, Vision_Language |
| Last Updated | 2026-02-08 03:30 GMT |
Overview
A training loss masking strategy that supervises only the text tokens between an `<image>` token and the next `<|endofchunk|>` token, so the model is not trained to predict text that has no visual context.
Description
During OpenFlamingo training, the loss is computed only on specific text tokens so that the model learns vision-conditioned generation rather than generic text generation. For LAION (single image-text pairs), the labels of pad tokens and `<image>` tokens are set to -100. For MMC4 (interleaved image-text sequences), the masking goes further: all tokens before the first `<image>` token are masked, and tokens between an `<|endofchunk|>` and the next `<image>` are masked as well. As a result, loss is computed only on text that is conditioned on a preceding image.
Usage
Apply this heuristic when training multimodal models with interleaved image-text data. It prevents the model from learning to generate text tokens that appear before any visual context has been provided, which would otherwise degrade vision-conditioned generation quality.
The Insight (Rule of Thumb)
- Action: Set labels to -100 for all tokens before the first `<image>` token, all `<image>` tokens, all pad tokens, and all tokens between `<|endofchunk|>` and the next `<image>`.
- Value: Use -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so masked positions contribute nothing to the loss.
- Trade-off: Reduces the effective number of supervised tokens per batch, but ensures only visually grounded text receives supervision.
- Compatibility: Applied differently per dataset. LAION needs only pad and `<image>` masking; MMC4 additionally needs the interleaved masking described in the Action.
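The rule above can be read as a single pass with one flag that records whether the current position follows an image within the same chunk. The sketch below is a plain-Python illustration, not the OpenFlamingo implementation; the token ids and the function name `mask_labels` are hypothetical.

```python
IGNORE_INDEX = -100

# Hypothetical token ids, chosen only for this illustration.
PAD_ID, MEDIA_ID, EOC_ID = 0, 1, 2


def mask_labels(input_ids):
    """Return labels where only image-conditioned text keeps its token id."""
    labels = []
    after_image = False  # True between an <image> and the next <|endofchunk|>
    for tok in input_ids:
        if tok == MEDIA_ID:
            after_image = True
            labels.append(IGNORE_INDEX)  # <image> itself is never a target
        elif tok == PAD_ID or not after_image:
            labels.append(IGNORE_INDEX)
        else:
            labels.append(tok)
            if tok == EOC_ID:
                after_image = False  # chunk ended; wait for the next <image>
    return labels


# text, text, <image>, text, text, <eoc>, text, <image>, text, <eoc>, pad, pad
example = [10, 11, 1, 12, 13, 2, 14, 1, 15, 2, 0, 0]
labels = mask_labels(example)
supervised_fraction = sum(l != IGNORE_INDEX for l in labels) / len(labels)
```

In this toy sequence 5 of the 12 positions keep a label, which illustrates the trade-off in supervised tokens. The first `<|endofchunk|>` after an image stays supervised, consistent with the MMC4 code cited under Code Evidence, which starts masking one position after each `<|endofchunk|>`.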
Reasoning
In interleaved image-text sequences, not all text is conditioned on an image. Text appearing before the first `<image>` token has no visual context, so supervising it would teach the model to generate text without visual grounding. Similarly, text between chunks (after `<|endofchunk|>` and before the next `<image>`) is not tied to the preceding image and may describe a different one. The Flamingo architecture conditions text generation on the most recent preceding image, so loss should only be computed on tokens that have a valid visual context.
One edge case is handled separately: if a sample's only image sits at the very end of the sequence, every label would be -100 and the sample would contribute no loss, so such samples are rejected during preprocessing (data.py:254-263).
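To make the edge case concrete: under the masking rule, a sample keeps at least one label only if some token that is neither `<image>` nor padding appears after its first `<image>`. The check below is a minimal sketch of that condition, not the code in data.py; the function name and token ids are hypothetical.

```python
# Hypothetical token ids, chosen only for this illustration.
PAD_ID, MEDIA_ID = 0, 1


def has_supervised_tokens(input_ids):
    """True if at least one label would survive the interleaved masking."""
    if MEDIA_ID not in input_ids:
        return False  # no image at all, so everything is masked
    first_media = input_ids.index(MEDIA_ID)
    # Any non-image, non-pad token after the first image is supervised.
    return any(tok not in (MEDIA_ID, PAD_ID) for tok in input_ids[first_media + 1:])


image_at_end = [10, 11, 12, 1]      # text, text, text, <image>
image_then_text = [10, 1, 12, 2]    # text, <image>, text, <eoc>
```

A sample like `image_at_end` would produce a batch row with no loss signal, which is why filtering it out before training is cheaper than letting it through.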
Code Evidence
LAION loss masking from `open_flamingo/train/train_utils.py:102-106`:
# set up labels; language model is expected to handle shifting
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
labels[labels == media_token_id] = -100
labels = labels.to(device_id)
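The comment about shifting matters for how the -100 labels are consumed. The sketch below shows one common way a causal language model head might apply them, assuming it shifts logits and labels by one position and uses cross-entropy with `ignore_index=-100`; the helper name `shifted_lm_loss` and the toy tensors are hypothetical and not taken from OpenFlamingo.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def shifted_lm_loss(logits, labels):
    """Next-token loss: position t predicts label t+1; ignored labels add nothing."""
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=IGNORE_INDEX,
    )


torch.manual_seed(0)
vocab_size = 8
logits = torch.randn(1, 5, vocab_size)
labels = torch.tensor([[-100, -100, 3, 4, -100]])  # only two supervised targets
loss_masked = shifted_lm_loss(logits, labels)

# The same value computed by hand over the two supervised positions only.
log_probs = F.log_softmax(logits[0, :-1], dim=-1)
loss_manual = -(log_probs[1, 3] + log_probs[2, 4]) / 2
```

With the default mean reduction, the loss is averaged over the non-ignored targets only, so masked positions change neither the numerator nor the denominator.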
MMC4 loss masking from `open_flamingo/train/train_utils.py:127-149`:
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
for i in range(labels.shape[0]):
    # remove loss for any token before the first <image> token
    label_idx = 0
    while (
        label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id
    ):
        labels[i][label_idx] = -100
        label_idx += 1

    # get index of all endofchunk tokens in the sequence
    endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
    for endofchunk_idx in endofchunk_idxs:
        token_idx = endofchunk_idx + 1
        while (
            token_idx < labels.shape[1]
            and labels[i][token_idx] != media_token_id
        ):
            labels[i][token_idx] = -100
            token_idx += 1

labels[labels == media_token_id] = -100
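The per-token Python loops above are easy to follow but run once per position. As an alternative, the same masking rule can be expressed without loops by comparing, at each position, the index of the most recent `<image>` with the index of the most recent earlier `<|endofchunk|>`. This is a sketch of a different technique (running maxima via `torch.cummax`), not the OpenFlamingo code; the function name and token ids are hypothetical, and `input_ids` is assumed to be an int64 tensor of shape (batch, seq_len).

```python
import torch

# Hypothetical token ids, chosen only for this illustration.
PAD_ID, MEDIA_ID, EOC_ID = 0, 1, 2
IGNORE_INDEX = -100


def mask_interleaved_labels(input_ids: torch.Tensor) -> torch.Tensor:
    """Loop-free variant of the interleaved masking rule (illustrative only)."""
    batch, seq_len = input_ids.shape
    pos = torch.arange(seq_len, device=input_ids.device).expand(batch, seq_len)
    never = torch.full_like(input_ids, -1)  # sentinel index: "not seen yet"

    is_media = input_ids == MEDIA_ID
    is_eoc = input_ids == EOC_ID

    # Index of the most recent <image> at or before each position (-1 if none).
    last_media = torch.cummax(torch.where(is_media, pos, never), dim=1).values
    # Index of the most recent <|endofchunk|> strictly before each position.
    last_eoc = torch.cummax(torch.where(is_eoc, pos, never), dim=1).values
    last_eoc_before = torch.cat([never[:, :1], last_eoc[:, :-1]], dim=1)

    # Supervise a token only if an image came after the last finished chunk.
    supervised = (last_media > last_eoc_before) & ~is_media & (input_ids != PAD_ID)
    return torch.where(supervised, input_ids, torch.full_like(input_ids, IGNORE_INDEX))


batch_ids = torch.tensor([
    [10, 11, 1, 12, 13, 2, 14, 1, 15, 2, 0, 0],
    [1, 20, 21, 2, 22, 2, 23, 0, 0, 0, 0, 0],
])
batch_labels = mask_interleaved_labels(batch_ids)
```

Comparing against the strictly earlier `<|endofchunk|>` keeps the chunk-closing token itself supervised, which matches the loop starting at `endofchunk_idx + 1`. Whether this variant is faster in practice would need to be measured.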