

Principle:Haotian liu LLaVA Checkpoint Extraction

From Leeroopedia
Revision as of 17:21, 16 February 2026 by Admin (talk | contribs) (Auto-imported from principles/Haotian_liu_LLaVA_Checkpoint_Extraction.md)
Metadata
Last Updated 2026-02-13 00:00 GMT

Overview

Technique for extracting specific weight subsets from large model checkpoints for modular reuse and validation. In LLaVA's two-stage training pipeline, checkpoint extraction isolates the multimodal projector weights from full model checkpoints, enabling them to be loaded in subsequent training stages or used to validate model integrity.

Description

After training, a full model checkpoint contains billions of parameters spread across several components: the LLM backbone (~13B parameters), the vision encoder (~300M parameters), and the multimodal projector (~30M parameters). Checkpoint extraction provides two essential capabilities:

  1. Modular weight extraction -- Isolating specific component weights (e.g., the mm_projector) into separate files for reuse. This is the key bridge between Stage 1 and Stage 2: the pretrained projector weights (mm_projector.bin) from Stage 1 are extracted and loaded into Stage 2 via --pretrain_mm_mlp_adapter.
  2. Model validation and loading -- After full training (Stage 2), the complete model checkpoint can be loaded and validated using load_pretrained_model(), which handles various checkpoint formats (full model, LoRA adapters, projector-only).

The extraction process works with two checkpoint formats:

  • Single-file checkpoints -- pytorch_model.bin contains all weights in one file. Directly filterable by key name.
  • Sharded checkpoints -- Multiple pytorch_model-XXXXX-of-YYYYY.bin shard files with an index file pytorch_model.bin.index.json that maps weight names to shard files. The extraction script reads the index to identify which shards contain projector weights, then loads only those shards.
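
A minimal sketch of how the two layouts might be told apart on disk, assuming only the file names listed above. The helper detect_checkpoint_format is hypothetical and is not part of LLaVA.

```python
import os

# File names taken from the checkpoint formats described above.
INDEX_NAME = "pytorch_model.bin.index.json"
SINGLE_NAME = "pytorch_model.bin"


def detect_checkpoint_format(model_dir: str) -> str:
    """Return 'sharded', 'single', or 'unknown' for a checkpoint directory.

    Hypothetical helper for illustration only.
    """
    # Check for the index first: its presence is what marks a sharded layout.
    if os.path.isfile(os.path.join(model_dir, INDEX_NAME)):
        return "sharded"
    if os.path.isfile(os.path.join(model_dir, SINGLE_NAME)):
        return "single"
    return "unknown"
```

Checking the index before the single file is a design choice in this sketch: a directory that holds both is treated as sharded, so that the weight map can be consulted instead of loading one large file.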

Usage

Use checkpoint extraction in these scenarios:

  • After Stage 1 pretraining -- Extract mm_projector.bin for use in Stage 2 finetuning via the --pretrain_mm_mlp_adapter argument. Note: when tune_mm_mlp_adapter=True, LLaVA's training pipeline automatically saves only the projector weights, so manual extraction is typically not needed for standard training. The extraction script is primarily useful for extracting projector weights from quantized or non-standard checkpoints.
  • After full training (Stage 2) -- Validate the model loads correctly using load_pretrained_model().
  • For model variants -- Extract the projector from one fully-trained model to use with a different LLM backbone, enabling mix-and-match experimentation.
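
For the mix-and-match scenario, one precondition is that the projector's input width matches the vision encoder and its output width matches the hidden size of the target LLM. The sketch below checks this from weight shapes alone. It assumes the two-layer projector key names used on this page and the PyTorch convention that a Linear weight is stored as (out_features, in_features). The function name projector_is_compatible is hypothetical, and the source does not describe such a check.

```python
def projector_is_compatible(projector_shapes, vision_width, llm_hidden_size):
    """Check a two-layer MLP projector against target encoder and LLM widths.

    projector_shapes maps weight names to shape tuples. Hypothetical helper.
    """
    first = projector_shapes["model.mm_projector.0.weight"]  # (out, in)
    last = projector_shapes["model.mm_projector.2.weight"]   # (out, in)
    # Input side must accept vision features; output side must match the LLM.
    return first[1] == vision_width and last[0] == llm_hidden_size


# Shapes for a projector built as Linear(1024, 5120) followed by Linear(5120, 5120).
shapes_13b = {
    "model.mm_projector.0.weight": (5120, 1024),
    "model.mm_projector.2.weight": (5120, 5120),
}

fits_same_width = projector_is_compatible(shapes_13b, 1024, 5120)
# 4096 stands in for a hypothetical backbone with a different hidden size.
fits_other_width = projector_is_compatible(shapes_13b, 1024, 4096)
```

Matching shapes are necessary but not sufficient: a projector trained against one backbone may still need further finetuning to work well with another.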

Theoretical Basis

The extraction operates via key-name filtering on the model's state dictionary. Weight keys in a LLaVA model follow a hierarchical naming convention:

Full Model State Dict Keys:
    model.embed_tokens.weight                    # LLM embedding
    model.layers.0.self_attn.q_proj.weight       # LLM attention
    model.layers.0.self_attn.k_proj.weight       # LLM attention
    ...
    model.layers.39.mlp.gate_proj.weight         # LLM MLP (layer 39)
    model.mm_projector.0.weight                  # Projector: Linear(1024, 5120)
    model.mm_projector.0.bias                    # Projector: Linear bias
    model.mm_projector.2.weight                  # Projector: Linear(5120, 5120)
    model.mm_projector.2.bias                    # Projector: Linear bias
    lm_head.weight                               # LLM output head

Extraction Filter:
    keys_to_match = ['mm_projector']
    extracted = {k: v for k, v in state_dict.items()
                 if any(match in k for match in keys_to_match)}
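
The filter can be tried without loading any real weights. The sketch below is a toy stand-in, not the extraction script itself: the state dict maps names to tensor shapes rather than tensors, and the helper name filter_by_key is hypothetical. Shapes assume the PyTorch convention that a Linear weight is stored as (out_features, in_features).

```python
from math import prod

# Toy state dict: weight names map to shapes instead of tensors.
state_dict_shapes = {
    "model.layers.0.self_attn.q_proj.weight": (5120, 5120),
    "model.layers.0.self_attn.k_proj.weight": (5120, 5120),
    "model.mm_projector.0.weight": (5120, 1024),  # Linear(1024, 5120)
    "model.mm_projector.0.bias": (5120,),
    "model.mm_projector.2.weight": (5120, 5120),  # Linear(5120, 5120)
    "model.mm_projector.2.bias": (5120,),
}


def filter_by_key(state_dict, keys_to_match):
    """Keep entries whose name contains any of the given substrings."""
    return {
        k: v
        for k, v in state_dict.items()
        if any(match in k for match in keys_to_match)
    }


extracted = filter_by_key(state_dict_shapes, ["mm_projector"])

# Count the parameters in the extracted subset from the shapes alone.
n_params = sum(prod(shape) for shape in extracted.values())
print(sorted(extracted))
print(n_params)  # 31467520
```

The count of roughly 31.5M parameters is consistent with the ~30M projector size quoted in the Description. Because the match is a substring test, any key that happens to contain mm_projector is kept, whatever prefix it carries.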

For sharded checkpoints, the pytorch_model.bin.index.json provides an efficient lookup:

{
    "weight_map": {
        "model.embed_tokens.weight": "pytorch_model-00001-of-00003.bin",
        "model.mm_projector.0.weight": "pytorch_model-00003-of-00003.bin",
        "model.mm_projector.0.bias": "pytorch_model-00003-of-00003.bin",
        "model.mm_projector.2.weight": "pytorch_model-00003-of-00003.bin",
        "model.mm_projector.2.bias": "pytorch_model-00003-of-00003.bin"
    }
}

By scanning the weight map first, the extraction script identifies that only shard 3 needs to be loaded, avoiding the memory cost of loading the full 13B model checkpoint.
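
The lookup step can be sketched as follows, using the example index above as input. This is an illustration under stated assumptions rather than the extraction script's actual code: the function name shards_for_keys is hypothetical, and the index is passed as a string so the example runs without any checkpoint files.

```python
import json
from collections import defaultdict


def shards_for_keys(index_json: str, keys_to_match):
    """Map each shard file to the matching weight names it contains."""
    weight_map = json.loads(index_json)["weight_map"]
    shard_to_keys = defaultdict(list)
    for name, shard in weight_map.items():
        if any(match in name for match in keys_to_match):
            shard_to_keys[shard].append(name)
    return dict(shard_to_keys)


# Same content as the example index shown above.
index_json = """
{
    "weight_map": {
        "model.embed_tokens.weight": "pytorch_model-00001-of-00003.bin",
        "model.mm_projector.0.weight": "pytorch_model-00003-of-00003.bin",
        "model.mm_projector.0.bias": "pytorch_model-00003-of-00003.bin",
        "model.mm_projector.2.weight": "pytorch_model-00003-of-00003.bin",
        "model.mm_projector.2.bias": "pytorch_model-00003-of-00003.bin"
    }
}
"""

needed = shards_for_keys(index_json, ["mm_projector"])
print(list(needed))  # ['pytorch_model-00003-of-00003.bin']
```

With a real checkpoint, each shard named in needed would then be loaded (for example with torch.load(path, map_location="cpu")) and reduced to the listed keys, while the remaining shards are never opened.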
