Principle:Haotian liu LLaVA Checkpoint Extraction
| Knowledge Sources | |
|---|---|
| Domains | |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Technique for extracting specific weight subsets from large model checkpoints for modular reuse and validation. In LLaVA's two-stage training pipeline, checkpoint extraction isolates the multimodal projector weights from full model checkpoints, enabling them to be loaded in subsequent training stages or used to validate model integrity.
Description
After training, model checkpoints may contain billions of parameters spread across three components: the LLM backbone (~13B parameters), the vision encoder (~300M parameters), and the multimodal projector (~30M parameters). Checkpoint extraction provides two essential capabilities:
- Modular weight extraction -- Isolating specific component weights (e.g., the mm_projector) into separate files for reuse. This is the key bridge between Stage 1 and Stage 2: the projector weights pretrained in Stage 1 (mm_projector.bin) are loaded into Stage 2 via --pretrain_mm_mlp_adapter.
- Model validation and loading -- After full training (Stage 2), the complete model checkpoint can be loaded and validated using load_pretrained_model(), which handles various checkpoint formats (full model, LoRA adapters, projector-only).
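Saved projector keys keep their full-model prefix (e.g., model.mm_projector.0.weight), whereas a standalone projector module expects keys relative to itself (0.weight). The sketch below shows that renaming step, assuming the two-Linear projector layout listed under Theoretical Basis; strip_component_prefix is a hypothetical helper name, not necessarily how LLaVA performs this internally.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_component_prefix(weights: Dict[str, V], keyword: str) -> Dict[str, V]:
    """Keep only keys containing `keyword.` and make them relative to that component.

    Hypothetical helper: "model.mm_projector.0.weight" -> "0.weight".
    """
    marker = keyword + "."
    return {
        key.split(marker, 1)[1]: value
        for key, value in weights.items()
        if marker in key
    }


# Toy example with string placeholders standing in for tensors.
full = {
    "model.embed_tokens.weight": "emb",
    "model.mm_projector.0.weight": "w0",
    "model.mm_projector.0.bias": "b0",
    "model.mm_projector.2.weight": "w2",
    "model.mm_projector.2.bias": "b2",
}
projector_only = strip_component_prefix(full, "mm_projector")
print(sorted(projector_only))  # ['0.bias', '0.weight', '2.bias', '2.weight']
```

With PyTorch available, the stripped dictionary could be passed to load_state_dict on a module such as nn.Sequential(nn.Linear(1024, 5120), nn.GELU(), nn.Linear(5120, 5120)). The GELU activation is an assumption here; any parameter-free layer at index 1 keeps the two Linear layers at indices 0 and 2, matching the saved keys.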
The extraction process works with two checkpoint formats:
- Single-file checkpoints -- pytorch_model.bin contains all weights in one file. Directly filterable by key name.
- Sharded checkpoints -- Multiple pytorch_model-XXXXX-of-YYYYY.bin shard files, plus an index file pytorch_model.bin.index.json that maps weight names to shard files. The extraction script reads the index to identify which shards contain projector weights, then loads only those shards.
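Telling the two formats apart only requires checking which files are present. A minimal sketch, assuming the checkpoint directory uses exactly the .bin naming schemes described above (it recognizes nothing else); detect_checkpoint_format and list_shards are hypothetical helpers.

```python
import os
import re
import tempfile
from typing import List

INDEX_FILE = "pytorch_model.bin.index.json"
SINGLE_FILE = "pytorch_model.bin"
# Matches shard names such as pytorch_model-00001-of-00003.bin (5-digit, zero-padded).
SHARD_PATTERN = re.compile(r"pytorch_model-\d{5}-of-\d{5}\.bin")


def detect_checkpoint_format(model_dir: str) -> str:
    """Return 'sharded' if an index file is present, 'single' if only pytorch_model.bin is."""
    if os.path.isfile(os.path.join(model_dir, INDEX_FILE)):
        return "sharded"
    if os.path.isfile(os.path.join(model_dir, SINGLE_FILE)):
        return "single"
    raise FileNotFoundError(f"neither {INDEX_FILE} nor {SINGLE_FILE} found in {model_dir}")


def list_shards(model_dir: str) -> List[str]:
    """Return shard file names in order; zero padding makes lexical order numeric."""
    return sorted(n for n in os.listdir(model_dir) if SHARD_PATTERN.fullmatch(n))


# Demo on a throwaway directory filled with empty placeholder files.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in [INDEX_FILE, "config.json",
                 "pytorch_model-00002-of-00003.bin",
                 "pytorch_model-00001-of-00003.bin",
                 "pytorch_model-00003-of-00003.bin"]:
        open(os.path.join(demo_dir, name), "w").close()
    demo_format = detect_checkpoint_format(demo_dir)
    demo_shards = list_shards(demo_dir)

print(demo_format)  # sharded
print(demo_shards)
```

Checking for the index file first matters: a sharded directory has no pytorch_model.bin, so testing the single-file name first would simply fall through, but the index is the authoritative signal that the weights are split.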
Usage
Use checkpoint extraction in these scenarios:
- After Stage 1 pretraining -- Extract mm_projector.bin for use in Stage 2 finetuning via the --pretrain_mm_mlp_adapter argument. Note: when tune_mm_mlp_adapter=True, LLaVA's training pipeline automatically saves only the projector weights, so manual extraction is typically not needed for standard training. The extraction script is primarily useful for extracting projector weights from quantized or non-standard checkpoints.
- After full training (Stage 2) -- Validate that the model loads correctly using load_pretrained_model().
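- For model variants -- Extract the projector from one fully trained model to use with a different LLM backbone, enabling mix-and-match experimentation. This works only when the projector's input width matches the vision encoder's features and its output width matches the new backbone's hidden size.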
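Whether a transplanted projector fits a new backbone comes down to tensor shapes: its first Linear layer must accept the vision encoder's feature width and its last must emit the LLM's hidden size. A minimal sketch of that check, assuming keys end in "<layer_index>.weight" and weights are stored as (out_features, in_features), as torch.nn.Linear stores them; projector_fits is a hypothetical helper, and 4096 is used only as an example of a backbone narrower than the 5120-wide one in the key listing.

```python
from typing import Dict, Tuple


def projector_fits(shapes: Dict[str, Tuple[int, ...]], vision_dim: int, llm_hidden_size: int) -> bool:
    """Compare the first/last Linear weights against the encoder and LLM widths.

    Assumes weight keys end in '<layer_index>.weight' and shapes are
    (out_features, in_features), the torch.nn.Linear convention.
    """
    weight_keys = sorted(
        (k for k in shapes if k.endswith(".weight")),
        key=lambda k: int(k.split(".")[-2]),  # numeric layer index, e.g. '0', '2'
    )
    first, last = shapes[weight_keys[0]], shapes[weight_keys[-1]]
    return first[1] == vision_dim and last[0] == llm_hidden_size


# Shapes implied by Linear(1024, 5120) followed by Linear(5120, 5120).
projector_shapes = {
    "model.mm_projector.0.weight": (5120, 1024),
    "model.mm_projector.0.bias": (5120,),
    "model.mm_projector.2.weight": (5120, 5120),
    "model.mm_projector.2.bias": (5120,),
}
print(projector_fits(projector_shapes, vision_dim=1024, llm_hidden_size=5120))  # True
print(projector_fits(projector_shapes, vision_dim=1024, llm_hidden_size=4096))  # False
```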
Theoretical Basis
The extraction operates via key-name filtering on the model's state dictionary. Weight keys in a LLaVA model follow a hierarchical naming convention:
Full Model State Dict Keys:
```text
model.embed_tokens.weight                 # LLM embedding
model.layers.0.self_attn.q_proj.weight    # LLM attention
model.layers.0.self_attn.k_proj.weight    # LLM attention
...
model.layers.39.mlp.gate_proj.weight      # LLM MLP (layer 39)
model.mm_projector.0.weight               # Projector: Linear(1024, 5120)
model.mm_projector.0.bias                 # Projector: Linear bias
model.mm_projector.2.weight               # Projector: Linear(5120, 5120)
model.mm_projector.2.bias                 # Projector: Linear bias
lm_head.weight                            # LLM output head
```
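The ~30M projector figure from the Description can be checked against the two Linear layers listed above. A short sketch, assuming PyTorch's (out_features, in_features) weight layout and only the four projector tensors shown:

```python
from math import prod
from typing import Dict, Tuple

# Projector tensor shapes implied by Linear(1024, 5120) and Linear(5120, 5120).
projector_shapes: Dict[str, Tuple[int, ...]] = {
    "model.mm_projector.0.weight": (5120, 1024),
    "model.mm_projector.0.bias": (5120,),
    "model.mm_projector.2.weight": (5120, 5120),
    "model.mm_projector.2.bias": (5120,),
}


def count_params(shapes: Dict[str, Tuple[int, ...]]) -> int:
    """Sum the element counts of every tensor shape."""
    return sum(prod(shape) for shape in shapes.values())


total = count_params(projector_shapes)
print(f"{total:,}")  # 31,467,520
```

That is 5,248,000 parameters for the first layer plus 26,219,520 for the second, roughly 31.5M, which is why the projector file is tiny next to the 13B backbone.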
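Extraction Filter:
```python
import torch
state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # single-file checkpoint
keys_to_match = ['mm_projector']
extracted = {k: v for k, v in state_dict.items()
             if any(match in k for match in keys_to_match)}  # keep projector tensors only
torch.save(extracted, "mm_projector.bin")  # write the projector-only weights
```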
For sharded checkpoints, the pytorch_model.bin.index.json provides an efficient lookup:
```json
{
  "weight_map": {
    "model.embed_tokens.weight": "pytorch_model-00001-of-00003.bin",
    "model.mm_projector.0.weight": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.0.bias": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.2.weight": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.2.bias": "pytorch_model-00003-of-00003.bin"
  }
}
```
By scanning the weight map first, the extraction script identifies that only shard 3 needs to be loaded, avoiding the memory cost of loading the full 13B model checkpoint.
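The index-first strategy can be sketched as two steps: group the matching weight names by shard, then load each needed shard once. This is a minimal sketch, not the LLaVA script itself; plan_shard_loads and extract_from_shards are hypothetical names, and the loader is passed in as a function so the demo can run with a stand-in instead of torch.load.

```python
import json
import os
import tempfile
from collections import defaultdict
from typing import Any, Callable, Dict, List


def plan_shard_loads(model_dir: str, keys_to_match: List[str]) -> Dict[str, List[str]]:
    """Group matching weight names by the shard file that stores them."""
    with open(os.path.join(model_dir, "pytorch_model.bin.index.json")) as f:
        weight_map: Dict[str, str] = json.load(f)["weight_map"]
    plan: Dict[str, List[str]] = defaultdict(list)
    for name, shard in weight_map.items():
        if any(match in name for match in keys_to_match):
            plan[shard].append(name)
    return dict(plan)


def extract_from_shards(model_dir: str, keys_to_match: List[str],
                        load_fn: Callable[[str], Dict[str, Any]]) -> Dict[str, Any]:
    """Load only the planned shards and copy out the matching weights."""
    extracted: Dict[str, Any] = {}
    for shard, names in plan_shard_loads(model_dir, keys_to_match).items():
        shard_state = load_fn(os.path.join(model_dir, shard))
        for name in names:
            extracted[name] = shard_state[name]
    return extracted


# Demo using the index from the example above.
index = {"weight_map": {
    "model.embed_tokens.weight": "pytorch_model-00001-of-00003.bin",
    "model.mm_projector.0.weight": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.0.bias": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.2.weight": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.2.bias": "pytorch_model-00003-of-00003.bin",
}}
loaded_paths: List[str] = []


def fake_load(path: str) -> Dict[str, Any]:
    """Stand-in for torch.load: returns string placeholders for the shard's weights."""
    shard = os.path.basename(path)
    loaded_paths.append(shard)
    return {n: f"tensor:{n}" for n, s in index["weight_map"].items() if s == shard}


with tempfile.TemporaryDirectory() as demo_dir:
    with open(os.path.join(demo_dir, "pytorch_model.bin.index.json"), "w") as f:
        json.dump(index, f)
    demo_plan = plan_shard_loads(demo_dir, ["mm_projector"])
    demo_weights = extract_from_shards(demo_dir, ["mm_projector"], fake_load)

print(loaded_paths)  # ['pytorch_model-00003-of-00003.bin']
```

In a real run, load_fn could be lambda p: torch.load(p, map_location="cpu"), with the result written out via torch.save(extracted, "mm_projector.bin"). Shard 1 is never opened, which is where the memory saving comes from.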