
Implementation:Haotian liu LLaVA Extract MM Projector

Last Updated 2026-02-13 00:00 GMT

Overview

CLI utility script for extracting multimodal projector weights from a full LLaVA checkpoint. extract_mm_projector.py scans a checkpoint directory for weight keys matching 'mm_projector', loads them from the appropriate shard files, and saves the extracted weights to a separate output file.

Description

extract_mm_projector.py is a standalone utility that extracts the multimodal projector weights from a full LLaVA model checkpoint. It handles two checkpoint formats:

  • Sharded checkpoints -- Reads pytorch_model.bin.index.json to map each projector weight key to the shard file that contains it. Only shards holding projector keys are loaded (each one in full), so the remaining shards are never read.
  • Single-file checkpoints -- If the index file is missing, the script catches FileNotFoundError, loads pytorch_model.bin directly, and filters by key name. This path covers smaller models and checkpoints saved by DeepSpeed. It loads the file twice: once to collect the matching keys and once more in the extraction loop.

The script filters weight keys using substring matching against ['mm_projector'], collecting all matching key-value pairs across shards into a single output dictionary.
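The index-to-shard grouping described above can be sketched as a small pure-Python function. This is a minimal illustration that assumes only the standard "weight_map" entry of the index (parameter name to shard file); the helper name group_keys_by_shard and the shard file names in the example index are hypothetical.

```python
import json
from collections import defaultdict


def group_keys_by_shard(index: dict, keys_to_match=("mm_projector",)) -> dict:
    """Map each shard file name to the projector keys it contains.

    `index` is the parsed pytorch_model.bin.index.json; only its
    "weight_map" entry (parameter name -> shard file) is used.
    """
    shard_to_keys = defaultdict(list)
    for name, shard in index["weight_map"].items():
        # Substring match: any key containing a pattern is kept.
        if any(pattern in name for pattern in keys_to_match):
            shard_to_keys[shard].append(name)
    return dict(shard_to_keys)


# Hypothetical two-shard index; a real index lists every parameter.
example_index = json.loads("""
{
  "weight_map": {
    "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
    "model.layers.39.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
    "model.mm_projector.0.weight": "pytorch_model-00002-of-00002.bin",
    "model.mm_projector.0.bias": "pytorch_model-00002-of-00002.bin",
    "model.mm_projector.2.weight": "pytorch_model-00002-of-00002.bin",
    "model.mm_projector.2.bias": "pytorch_model-00002-of-00002.bin"
  }
}
""")

print(group_keys_by_shard(example_index))
```

In this example only the second shard would need to be opened, which is exactly the saving the sharded path relies on.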

Additionally, this page documents load_pretrained_model() from llava/model/builder.py, which loads full or projector-only LLaVA checkpoints for inference and therefore doubles as a quick check that a checkpoint loads correctly.

Usage

python scripts/extract_mm_projector.py \
    --model-path /path/to/full/checkpoint \
    --output mm_projector.bin

Code Reference

Source Location

extract_mm_projector.py (Full Source)

import os
import argparse
import torch
import json
from collections import defaultdict


def parse_args():
    parser = argparse.ArgumentParser(description='Extract MMProjector weights')
    parser.add_argument('--model-path', type=str, help='model folder')
    parser.add_argument('--output', type=str, help='output file')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    keys_to_match = ['mm_projector']
    ckpt_to_key = defaultdict(list)
    try:
        model_indices = json.load(
            open(os.path.join(args.model_path, 'pytorch_model.bin.index.json'))
        )
        for k, v in model_indices['weight_map'].items():
            if any(key_match in k for key_match in keys_to_match):
                ckpt_to_key[v].append(k)
    except FileNotFoundError:
        # Smaller models or model checkpoints saved by DeepSpeed.
        v = 'pytorch_model.bin'
        for k in torch.load(
            os.path.join(args.model_path, v), map_location='cpu'
        ).keys():
            if any(key_match in k for key_match in keys_to_match):
                ckpt_to_key[v].append(k)

    loaded_weights = {}

    for ckpt_name, weight_keys in ckpt_to_key.items():
        ckpt = torch.load(
            os.path.join(args.model_path, ckpt_name), map_location='cpu'
        )
        for k in weight_keys:
            loaded_weights[k] = ckpt[k]

    torch.save(loaded_weights, args.output)
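For reuse outside the CLI, the same logic can be wrapped in a function. The sketch below is a hypothetical refactor, not part of the LLaVA repository; extract_projector_weights and its load_fn parameter are illustrative names. It closes the index file with a context manager, checks for the index with os.path.exists instead of catching FileNotFoundError, and reads a single-file checkpoint once rather than twice. The loader is injectable so the logic can be exercised without real checkpoints.

```python
import functools
import json
import os
from collections import defaultdict


def extract_projector_weights(model_path, keys_to_match=("mm_projector",), load_fn=None):
    """Return {key: tensor} for every key containing one of keys_to_match.

    load_fn(path) must return a dict-like state dict. It defaults to
    torch.load(path, map_location="cpu"); torch is imported only in that case.
    """
    if load_fn is None:
        import torch  # deferred so the helper can be tested without torch
        load_fn = functools.partial(torch.load, map_location="cpu")

    def matches(name):
        return any(pattern in name for pattern in keys_to_match)

    index_path = os.path.join(model_path, "pytorch_model.bin.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint: read the index, then open only the needed shards.
        with open(index_path) as f:
            weight_map = json.load(f)["weight_map"]
        shard_to_keys = defaultdict(list)
        for name, shard in weight_map.items():
            if matches(name):
                shard_to_keys[shard].append(name)
        weights = {}
        for shard, names in shard_to_keys.items():
            state = load_fn(os.path.join(model_path, shard))
            for name in names:
                weights[name] = state[name]
        return weights

    # Single-file checkpoint: load pytorch_model.bin once and filter in memory.
    state = load_fn(os.path.join(model_path, "pytorch_model.bin"))
    return {name: value for name, value in state.items() if matches(name)}
```

With torch installed, weights = extract_projector_weights("/path/to/full/checkpoint") followed by torch.save(weights, "mm_projector.bin") should produce the same file as the CLI (the paths are placeholders).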

load_pretrained_model() Signature

from typing import Optional, Tuple

from transformers import CLIPImageProcessor, PreTrainedModel, PreTrainedTokenizer


# Type hints are shown for documentation purposes.
def load_pretrained_model(
    model_path: str,
    model_base: Optional[str],
    model_name: str,
    load_8bit: bool = False,
    load_4bit: bool = False,
    device_map: str = "auto",
    device: str = "cuda",
    use_flash_attn: bool = False,
    **kwargs
) -> Tuple[PreTrainedTokenizer, PreTrainedModel, Optional[CLIPImageProcessor], int]:
    """
    Load a pretrained LLaVA model for inference.

    Returns:
        tokenizer, model, image_processor, context_len
    """

Import

# extract_mm_projector.py is a CLI script, not typically imported
# For model loading/validation:
from llava.model.builder import load_pretrained_model

I/O Contract

extract_mm_projector.py

Input Contract
  • --model-path (str, CLI; required): Path to the full model checkpoint directory. It must contain either pytorch_model.bin.index.json (sharded) or pytorch_model.bin (single file).
  • --output (str, CLI; required): Output file path for the extracted projector weights (e.g., mm_projector.bin).
Neither flag is declared with required=True in argparse, so omitting one makes the script fail at runtime instead of printing a usage message.
Output Contract
  • mm_projector.bin (file at the --output path): PyTorch state dict containing only keys that match 'mm_projector'. For the mlp2x_gelu projector this is typically 4 entries: model.mm_projector.{0,2}.{weight,bias}.
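The four-entry figure follows from the shape of the mlp2x_gelu projector: Linear(mm_hidden_size, hidden_size), then GELU, then Linear(hidden_size, hidden_size). GELU has no parameters, so only indices 0 and 2 appear. The sketch below lists the expected keys and counts the parameters. The sizes are assumptions for a LLaVA-v1.5-13B setup (1024-dimensional CLIP ViT-L/14-336 features, 5120-dimensional Vicuna-13B hidden states), and the helper names are hypothetical.

```python
def mlp2x_gelu_keys(prefix="model.mm_projector"):
    # Sequential(Linear, GELU, Linear): GELU holds no parameters,
    # so only indices 0 and 2 contribute weight and bias tensors.
    return {f"{prefix}.{i}.{p}" for i in (0, 2) for p in ("weight", "bias")}


def mlp2x_gelu_param_count(mm_hidden_size, hidden_size):
    first = mm_hidden_size * hidden_size + hidden_size  # Linear(mm_hidden_size, hidden_size)
    second = hidden_size * hidden_size + hidden_size    # Linear(hidden_size, hidden_size)
    return first + second


# Assumed sizes: 1024-d vision features, 5120-d language model hidden state.
print(sorted(mlp2x_gelu_keys()))
print(mlp2x_gelu_param_count(1024, 5120))  # 31467520
```

Assuming 2 bytes per parameter (fp16 or bf16), that is roughly 63 MB of tensor data, a small fraction of the full 13B checkpoint.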

load_pretrained_model()

Input Contract
  • model_path (str; required): Path to a model checkpoint or a Hugging Face model ID.
  • model_base (str or None; conditional): Path to the base LLM. Required for LoRA models and projector-only checkpoints; pass None for full checkpoints.
  • model_name (str; required): Model name used to select the loading strategy (checks for the substrings 'llava', 'lora', 'mpt', and 'mistral').
  • load_8bit (bool; optional): Enable 8-bit quantization. Default: False.
  • load_4bit (bool; optional): Enable 4-bit quantization with NF4. Default: False.
  • device_map (str; optional): Device map used when loading the weights. Default: "auto".
  • device (str; optional): Target device. Default: "cuda".
  • use_flash_attn (bool; optional): Enable Flash Attention 2. Default: False.
Output Contract
  • tokenizer (PreTrainedTokenizer): Loaded tokenizer for the model.
  • model (PreTrainedModel): Loaded LLaVA model ready for inference.
  • image_processor (CLIPImageProcessor or None): Image preprocessing pipeline from the vision tower; None when model_name does not contain 'llava'.
  • context_len (int): Maximum context length, taken from the model config's max_sequence_length when present, otherwise 2048.
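Because the loading strategy depends on substrings of model_name together with whether model_base is set, the name given to a checkpoint matters. The function below is a rough, hypothetical approximation of that selection, written only from the rules listed in this section; it is not the code in llava/model/builder.py, and it ignores the 'mpt' and 'mistral' architecture checks.

```python
from typing import Optional


def describe_loading_branch(model_name: str, model_base: Optional[str]) -> str:
    """Hypothetical approximation of the branch load_pretrained_model() takes.

    Uses only the 'llava' and 'lora' substrings of model_name plus whether
    model_base is set; architecture checks such as 'mpt'/'mistral' are ignored.
    """
    name = model_name.lower()  # substring checks treated as case-insensitive (assumption)
    if "llava" not in name:
        return "language model only (image_processor is None)"
    if "lora" in name:
        if model_base is None:
            return "LoRA checkpoint without model_base (misconfigured)"
        return "LoRA checkpoint on top of model_base"
    if model_base is not None:
        return "projector-only checkpoint on top of model_base"
    return "full LLaVA checkpoint"


# The two calls used in Example 3:
print(describe_loading_branch("llava-v1.5-13b", None))
print(describe_loading_branch("llava-v1.5-13b-pretrain", "lmsys/vicuna-13b-v1.5"))
```

Under this approximation, a projector-only checkpoint whose name lacks "llava" would not take the LLaVA path at all, which is why the example names keep that substring.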

Usage Examples

Example 1: Extract Projector from Sharded Checkpoint

# After Stage 1 pretraining produces a sharded checkpoint:
python scripts/extract_mm_projector.py \
    --model-path ./checkpoints/llava-v1.5-13b-pretrain \
    --output ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin

Example 2: Use Extracted Projector in Stage 2

# The extracted projector feeds into Stage 2 finetuning:
deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path lmsys/vicuna-13b-v1.5 \
    --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
    ...
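The keys in mm_projector.bin keep their full prefix (model.mm_projector.0.weight and so on), while the projector module's own state dict uses local names such as 0.weight, so Stage 2 has to strip the prefix before loading. A minimal sketch of that remapping, assuming everything up to and including "mm_projector." is dropped; strip_projector_prefix is a hypothetical helper, not a LLaVA function.

```python
def strip_projector_prefix(weights: dict, keyword: str = "mm_projector") -> dict:
    """Map 'model.mm_projector.0.weight' -> '0.weight'; drop unrelated keys."""
    marker = keyword + "."
    return {
        name.split(marker, 1)[1]: value
        for name, value in weights.items()
        if marker in name
    }


extracted = {
    "model.mm_projector.0.weight": "W0",  # placeholders standing in for tensors
    "model.mm_projector.0.bias": "b0",
    "model.mm_projector.2.weight": "W2",
    "model.mm_projector.2.bias": "b2",
}
print(strip_projector_prefix(extracted))
# {'0.weight': 'W0', '0.bias': 'b0', '2.weight': 'W2', '2.bias': 'b2'}
```

With torch available, projector.load_state_dict(strip_projector_prefix(torch.load("mm_projector.bin", map_location="cpu"))) would load the weights into a hypothetical nn.Sequential named projector.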

Example 3: Validate Full Model with load_pretrained_model()

from llava.model.builder import load_pretrained_model

# Load a fully trained LLaVA model for inference validation
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-v1.5-13b",
    model_base=None,
    model_name="llava-v1.5-13b",
    use_flash_attn=True
)

# For projector-only checkpoints, provide the base model:
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-v1.5-13b-pretrain",
    model_base="lmsys/vicuna-13b-v1.5",
    model_name="llava-v1.5-13b-pretrain"
)
