Implementation:Mlfoundations_Open_flamingo_Hf_hub_download_load_state_dict
Overview
Wrapper pattern combining HuggingFace Hub checkpoint download with PyTorch partial state dict loading for OpenFlamingo models.
Description
This is a Wrapper Doc. The pattern uses huggingface_hub.hf_hub_download() to download checkpoint files from HuggingFace model repositories, then model.load_state_dict(torch.load(path), strict=False) to load only the Perceiver and cross-attention weights. The strict=False flag is essential because the checkpoint only contains the trainable parameters, not the frozen CLIP or LM backbone weights.
The two-step workflow operates as follows:
- Download — hf_hub_download() retrieves a specific file from a HuggingFace model repository, caching it locally and returning the local file path.
- Load — torch.load() deserializes the checkpoint into a state dictionary, which is then passed to model.load_state_dict() with strict=False to perform partial weight restoration.
This separation ensures that checkpoint acquisition and weight initialization are independent operations, enabling caching, offline usage, and flexible device mapping.
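The strict=False mechanics can be sketched with a toy module; ToyModel, backbone, and adapter below are illustrative names, not OpenFlamingo classes, and an in-memory buffer stands in for the downloaded checkpoint file:

```python
import io
import torch
import torch.nn as nn

# Toy stand-ins: a "frozen" backbone plus a small trainable subset.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # plays the role of the frozen CLIP/LM weights
        self.adapter = nn.Linear(4, 4)   # plays the role of Perceiver/cross-attention

# Save a checkpoint containing only the trainable subset.
src = ToyModel()
partial_ckpt = {k: v for k, v in src.state_dict().items() if k.startswith("adapter")}
buffer = io.BytesIO()
torch.save(partial_ckpt, buffer)
buffer.seek(0)

# strict=False tolerates model keys absent from the checkpoint and reports
# them in the returned NamedTuple instead of raising a RuntimeError.
dst = ToyModel()
result = dst.load_state_dict(torch.load(buffer), strict=False)
print(result.missing_keys)     # the backbone.* keys not present in the checkpoint
print(result.unexpected_keys)  # checkpoint keys with no matching model parameter
```

Inspecting missing_keys after loading is a useful sanity check: for OpenFlamingo it should list only frozen backbone parameters, and unexpected_keys should be empty.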
Usage
After creating a model with create_model_and_transforms, load trained OpenFlamingo weights by downloading the checkpoint from HuggingFace Hub and applying the partial state dict to the model.
Code Reference
Source: Repository https://github.com/mlfoundations/open_flamingo, File: open_flamingo/src/flamingo.py, Lines: L17-339 (Flamingo class inherits nn.Module with standard load_state_dict)
Signature (two-step pattern):
# Step 1: Download checkpoint
checkpoint_path = huggingface_hub.hf_hub_download(
repo_id: str, # HuggingFace model repo e.g. "openflamingo/OpenFlamingo-3B-vitl-mpt1b"
filename: str, # Checkpoint filename e.g. "checkpoint.pt"
) -> str # Returns local path to downloaded file
# Step 2: Load weights
model.load_state_dict(
torch.load(checkpoint_path, map_location=device),
strict=False, # Allow partial loading (only Perceiver + cross-attention weights)
)
Import:
from huggingface_hub import hf_hub_download
import torch
External reference: https://huggingface.co/docs/huggingface_hub/package_reference/file_download
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| repo_id | str | Yes | HuggingFace model repository ID |
| filename | str | Yes | Checkpoint filename |
| device | torch.device | No | Device for loading weights |
Outputs
| Output | Type | Description |
|---|---|---|
| checkpoint_path | str | Local path to downloaded file |
| model state | in-place | Model weights updated with trained Perceiver + cross-attention parameters |
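The optional device input corresponds to torch.load's map_location argument, which controls where the deserialized tensors are materialized. A minimal sketch, using an in-memory buffer and a placeholder tensor name in place of a real checkpoint:

```python
import io
import torch

# Serialize a small state dict to an in-memory buffer (stands in for checkpoint.pt).
buffer = io.BytesIO()
torch.save({"adapter.weight": torch.randn(4, 4)}, buffer)
buffer.seek(0)

# map_location="cpu" materializes every tensor on CPU regardless of where the
# checkpoint was written; a torch.device or device string like "cuda:0" also works.
state_dict = torch.load(buffer, map_location="cpu")
print(state_dict["adapter.weight"].device)  # cpu
```

Loading to CPU first and moving the model afterwards avoids GPU out-of-memory spikes when the checkpoint was saved from a different device layout.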
Usage Examples
from open_flamingo import create_model_and_transforms
from huggingface_hub import hf_hub_download
import torch
# Create model architecture
model, image_processor, tokenizer = create_model_and_transforms(
clip_vision_encoder_path="ViT-L-14",
clip_vision_encoder_pretrained="openai",
lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
cross_attn_every_n_layers=1,
)
# Step 1: Download the checkpoint from HuggingFace Hub
checkpoint_path = hf_hub_download(
repo_id="openflamingo/OpenFlamingo-3B-vitl-mpt1b",
filename="checkpoint.pt",
)
# Step 2: Load the trained Perceiver + cross-attention weights
model.load_state_dict(
torch.load(checkpoint_path, map_location="cpu"),
strict=False,
)
model.eval()
Related Pages
Principle:Mlfoundations_Open_flamingo_Pretrained_Weight_Loading