
Implementation:Mlfoundations Open flamingo Hf hub download load state dict

From Leeroopedia



Overview

Wrapper pattern combining HuggingFace Hub checkpoint download with PyTorch partial state dict loading for OpenFlamingo models.

Description

This is a Wrapper Doc. The pattern uses huggingface_hub.hf_hub_download() to download checkpoint files from HuggingFace model repositories, then model.load_state_dict(torch.load(path), strict=False) to load only the Perceiver and cross-attention weights. The strict=False flag is essential because the checkpoint only contains the trainable parameters, not the frozen CLIP or LM backbone weights.

The two-step workflow operates as follows:

  1. Download: hf_hub_download() retrieves a specific file from a HuggingFace model repository, caching it locally and returning the local file path.
  2. Load: torch.load() deserializes the checkpoint into a state dictionary, which is then passed to model.load_state_dict() with strict=False to perform partial weight restoration.

This separation ensures that checkpoint acquisition and weight initialization are independent operations, enabling caching, offline usage, and flexible device mapping.
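The effect of strict=False can be seen without downloading anything. The sketch below uses a hypothetical toy module (not the real Flamingo class) in which `backbone` stands in for the frozen CLIP/LM weights absent from the checkpoint and `head` stands in for the trainable Perceiver/cross-attention weights. load_state_dict returns a named tuple whose missing_keys and unexpected_keys fields let you audit exactly which parameters were left untouched:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for Flamingo: `backbone` plays the role of the
# frozen CLIP/LM weights, `head` the trainable Perceiver + cross-attention.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # frozen; NOT in the checkpoint
        self.head = nn.Linear(4, 2)      # trainable; present in the checkpoint

model = Toy()

# A "checkpoint" containing only the trainable subset of weights
partial_ckpt = {"head.weight": torch.zeros(2, 4), "head.bias": torch.zeros(2)}

# strict=False tolerates the absent backbone keys and reports them instead
# of raising; with strict=True this call would fail.
result = model.load_state_dict(partial_ckpt, strict=False)
print(result.missing_keys)     # keys in the model but not the checkpoint
print(result.unexpected_keys)  # keys in the checkpoint but not the model
```

Inspecting missing_keys after loading is a cheap sanity check: for a correct OpenFlamingo checkpoint, only frozen backbone parameters should appear there, and unexpected_keys should be empty.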

Usage

After creating a model with create_model_and_transforms, load trained OpenFlamingo weights by downloading the checkpoint from HuggingFace Hub and applying the partial state dict to the model.

Code Reference

Source: Repository https://github.com/mlfoundations/open_flamingo, File: open_flamingo/src/flamingo.py, Lines: L17-339 (Flamingo class inherits nn.Module with standard load_state_dict)

Signature (two-step pattern):

# Step 1: Download checkpoint
checkpoint_path = huggingface_hub.hf_hub_download(
    repo_id: str,        # HuggingFace model repo e.g. "openflamingo/OpenFlamingo-3B-vitl-mpt1b"
    filename: str,       # Checkpoint filename e.g. "checkpoint.pt"
) -> str  # Returns local path to downloaded file

# Step 2: Load weights
model.load_state_dict(
    torch.load(checkpoint_path, map_location=device),
    strict=False,  # Allow partial loading (only Perceiver + cross-attention weights)
)

Import:

from huggingface_hub import hf_hub_download
import torch

External reference: https://huggingface.co/docs/huggingface_hub/package_reference/file_download

I/O Contract

Inputs

Parameter Type Required Description
repo_id str Yes HuggingFace model repository ID
filename str Yes Checkpoint filename
device torch.device No Device for loading weights

Outputs

Output Type Description
checkpoint_path str Local path to downloaded file
model state in-place Model weights updated with trained Perceiver + cross-attention parameters

Usage Examples

from open_flamingo import create_model_and_transforms
from huggingface_hub import hf_hub_download
import torch

# Create model architecture
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
)

# Step 1: Download the checkpoint from HuggingFace Hub
checkpoint_path = hf_hub_download(
    repo_id="openflamingo/OpenFlamingo-3B-vitl-mpt1b",
    filename="checkpoint.pt",
)

# Step 2: Load the trained Perceiver + cross-attention weights
model.load_state_dict(
    torch.load(checkpoint_path, map_location="cpu"),
    strict=False,
)

model.eval()

Related Pages

Principle:Mlfoundations_Open_flamingo_Pretrained_Weight_Loading
