Overview
CLI utility script for extracting multimodal projector weights from a full LLaVA checkpoint. extract_mm_projector.py scans a checkpoint directory for weight keys matching 'mm_projector', loads them from the appropriate shard files, and saves the extracted weights to a separate output file.
Description
extract_mm_projector.py is a standalone utility that extracts the multimodal projector weights from a full LLaVA model checkpoint. It handles two checkpoint formats:
- Sharded checkpoints: reads pytorch_model.bin.index.json to map each projector weight key to the shard file that contains it. Only shards holding at least one projector key are loaded, which avoids loading the full model.
- Single-file checkpoints: if the index file is missing (FileNotFoundError), the script falls back to loading pytorch_model.bin directly and filtering by key name. This path covers smaller models and checkpoints saved by DeepSpeed.
The script filters weight keys using substring matching against ['mm_projector'], collecting all matching key-value pairs across shards into a single output dictionary.
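The sharded lookup can be sketched as a pure function over the parsed index file. This is a minimal sketch, assuming the index follows the usual Hugging Face layout in which weight_map maps each parameter name to the shard file that stores it. group_keys_by_shard is a hypothetical helper name, and the shard file names below are made up for illustration.

```python
import json
from collections import defaultdict


def group_keys_by_shard(index: dict, keys_to_match: list[str]) -> dict[str, list[str]]:
    """Map each shard file to the matching parameter names it contains.

    Hypothetical helper mirroring the script's sharded branch; `index` is the
    parsed contents of pytorch_model.bin.index.json.
    """
    ckpt_to_key: dict[str, list[str]] = defaultdict(list)
    for param_name, shard_file in index["weight_map"].items():
        # Substring match, as in the script: any key containing a pattern is kept.
        if any(pattern in param_name for pattern in keys_to_match):
            ckpt_to_key[shard_file].append(param_name)
    return dict(ckpt_to_key)


# Illustrative index; a real one lists every parameter of the model.
example_index = json.loads("""
{
  "weight_map": {
    "model.embed_tokens.weight": "pytorch_model-00001-of-00003.bin",
    "model.mm_projector.0.weight": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.0.bias": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.2.weight": "pytorch_model-00003-of-00003.bin",
    "model.mm_projector.2.bias": "pytorch_model-00003-of-00003.bin"
  }
}
""")

shards = group_keys_by_shard(example_index, ["mm_projector"])
print(shards)
```

In this example only the third shard contains projector keys, so it is the only file the extraction loop would need to load.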
Additionally, this page documents load_pretrained_model() from llava/model/builder.py, the complementary entry point for inference: it loads a full LLaVA checkpoint, or a projector-only checkpoint together with its base LLM. A successful load is a practical check that extracted weights are usable.
Usage
python scripts/extract_mm_projector.py \
--model-path /path/to/full/checkpoint \
--output mm_projector.bin
Code Reference
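Source Location
The script lives at scripts/extract_mm_projector.py; load_pretrained_model() is defined in llava/model/builder.py.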
import os
import argparse
import torch
import json
from collections import defaultdict


def parse_args():
    parser = argparse.ArgumentParser(description='Extract MMProjector weights')
    parser.add_argument('--model-path', type=str, help='model folder')
    parser.add_argument('--output', type=str, help='output file')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    keys_to_match = ['mm_projector']
    # Maps each checkpoint file name to the projector keys stored in it.
    ckpt_to_key = defaultdict(list)
    try:
        # Sharded checkpoint: the index maps every parameter name to its shard file.
        model_indices = json.load(
            open(os.path.join(args.model_path, 'pytorch_model.bin.index.json'))
        )
        for k, v in model_indices['weight_map'].items():
            if any(key_match in k for key_match in keys_to_match):
                ckpt_to_key[v].append(k)
    except FileNotFoundError:
        # Smaller models or model checkpoints saved by DeepSpeed.
        v = 'pytorch_model.bin'
        for k in torch.load(
            os.path.join(args.model_path, v), map_location='cpu'
        ).keys():
            if any(key_match in k for key_match in keys_to_match):
                ckpt_to_key[v].append(k)

    # Load each needed checkpoint file and copy out only the projector tensors.
    loaded_weights = {}
    for ckpt_name, weight_keys in ckpt_to_key.items():
        ckpt = torch.load(
            os.path.join(args.model_path, ckpt_name), map_location='cpu'
        )
        for k in weight_keys:
            loaded_weights[k] = ckpt[k]

    torch.save(loaded_weights, args.output)
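In the single-file fallback, pytorch_model.bin is loaded once to list its keys and a second time inside the extraction loop. Since the whole state dict is already in memory after the first load, the filtering could be done in one pass. This is a minimal sketch, assuming a plain dict stands in for the loaded state dict (only keys are inspected, so tensor values would behave the same); filter_state_dict is a hypothetical helper, not part of the script.

```python
from typing import Any, Mapping


def filter_state_dict(state_dict: Mapping[str, Any], keys_to_match: list[str]) -> dict[str, Any]:
    """Return only the entries whose key contains one of the given substrings.

    Hypothetical single-pass alternative to the script's fallback branch,
    which calls torch.load on pytorch_model.bin twice.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if any(pattern in key for pattern in keys_to_match)
    }


# Stand-in for torch.load(<model_path>/pytorch_model.bin, map_location='cpu');
# strings replace tensors so the example runs without PyTorch.
fake_state_dict = {
    "model.embed_tokens.weight": "tensor-A",
    "model.mm_projector.0.weight": "tensor-B",
    "model.mm_projector.0.bias": "tensor-C",
    "lm_head.weight": "tensor-D",
}

projector_only = filter_state_dict(fake_state_dict, ["mm_projector"])
print(sorted(projector_only))
```

With real weights, the filtered dict could be passed to torch.save in the same way the script saves loaded_weights.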
load_pretrained_model() Signature
from typing import Optional, Tuple
from transformers import CLIPImageProcessor, PreTrainedModel, PreTrainedTokenizer

def load_pretrained_model(
    model_path: str,
    model_base: Optional[str],
    model_name: str,
    load_8bit: bool = False,
    load_4bit: bool = False,
    device_map: str = "auto",
    device: str = "cuda",
    use_flash_attn: bool = False,
    **kwargs
) -> Tuple[PreTrainedTokenizer, PreTrainedModel, CLIPImageProcessor, int]:
    """
    Load a pretrained LLaVA model for inference.

    Returns:
        tokenizer, model, image_processor, context_len
    """
Import
# extract_mm_projector.py is a CLI script, not typically imported
# For model loading/validation:
from llava.model.builder import load_pretrained_model
I/O Contract
Input Contract
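| Name | Type | Required | Description |
|---|---|---|---|
| --model-path | str (CLI) | Yes | Path to the full model checkpoint directory. Must contain either pytorch_model.bin.index.json (sharded) or pytorch_model.bin (single file). argparse does not mark it as required, but the script fails without it. |
| --output | str (CLI) | Yes | Output file path for the extracted projector weights (e.g., mm_projector.bin). argparse does not mark it as required, but the final torch.save fails without it. |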
Output Contract
| Name | Type | Description |
|---|---|---|
| mm_projector.bin | File | PyTorch state dict containing only keys that contain 'mm_projector'. For an mlp2x_gelu projector this is typically 4 entries: model.mm_projector.{0,2}.{weight,bias}. |
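A quick sanity check on an extracted file is to compare its key set with the layout in the table above. This is a minimal sketch, assuming an mlp2x_gelu projector whose Linear layers sit at indices 0 and 2 under the model.mm_projector prefix; with a real file the keys would come from torch.load(...).keys(). expected_mlp2x_gelu_keys and missing_and_unexpected are hypothetical helpers.

```python
def expected_mlp2x_gelu_keys(prefix: str = "model.mm_projector") -> set[str]:
    """Keys expected for a Linear-GELU-Linear projector (parameters at indices 0 and 2)."""
    return {
        f"{prefix}.{layer}.{param}"
        for layer in (0, 2)
        for param in ("weight", "bias")
    }


def missing_and_unexpected(keys, prefix: str = "model.mm_projector"):
    """Return (missing, unexpected) key sets relative to the mlp2x_gelu layout."""
    expected = expected_mlp2x_gelu_keys(prefix)
    found = set(keys)
    return expected - found, found - expected


# Keys as they might appear in an extracted mm_projector.bin (assumed example).
extracted_keys = [
    "model.mm_projector.0.weight",
    "model.mm_projector.0.bias",
    "model.mm_projector.2.weight",
    "model.mm_projector.2.bias",
]
missing, unexpected = missing_and_unexpected(extracted_keys)
print(missing, unexpected)
```

Both sets being empty indicates the file matches the expected layout; other projector types would need their own expected key set.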
load_pretrained_model()
Input Contract
| Name | Type | Required | Description |
|---|---|---|---|
| model_path | str | Yes | Path to a model checkpoint directory or a Hugging Face model ID. |
| model_base | str or None | Conditional | Path to the base LLM. Required for LoRA models and projector-only checkpoints; pass None for full checkpoints. |
| model_name | str | Yes | Model name used to select the loading strategy; checked case-insensitively for the substrings 'llava', 'lora', 'mpt', and 'mistral'. |
| load_8bit | bool | No | Enable 8-bit quantization. Default: False. |
| load_4bit | bool | No | Enable 4-bit quantization with NF4. Default: False. |
| device_map | str | No | Device map passed through to model loading. Default: "auto". |
| device | str | No | Target device. Default: "cuda". |
| use_flash_attn | bool | No | Enable Flash Attention 2. Default: False. |
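How model_name and model_base interact can be summarized as a decision over substrings. The following is a simplified sketch of that decision, not the builder's actual code. It assumes case-insensitive matching, treats a missing model_base for a LoRA checkpoint as an error (following the table's "Required"), lumps all names without 'llava' together as a plain language model, and omits the 'mpt' and 'mistral' checks, which choose an architecture rather than a loading mode. loading_mode and its return labels are hypothetical.

```python
from typing import Optional


def loading_mode(model_name: str, model_base: Optional[str]) -> str:
    """Classify how a checkpoint would be loaded (simplified, hypothetical sketch)."""
    name = model_name.lower()
    if "llava" not in name:
        # Assumption: names without 'llava' are loaded as a plain language model.
        return "language-model-only"
    if "lora" in name:
        # LoRA weights are applied on top of a base LLM, so a base path is needed.
        if model_base is None:
            raise ValueError("LoRA checkpoints require model_base")
        return "llava-lora"
    if model_base is not None:
        # Projector-only checkpoint: base LLM weights plus mm_projector.bin.
        return "llava-projector-only"
    return "llava-full"


print(loading_mode("llava-v1.5-13b", None))
print(loading_mode("llava-v1.5-13b-pretrain", "lmsys/vicuna-13b-v1.5"))
```

This is also why the model_name passed in the examples matters: renaming a checkpoint directory so that it no longer contains 'llava' would change which path is taken.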
Output Contract
| Name | Type | Description |
|---|---|---|
| tokenizer | PreTrainedTokenizer | Loaded tokenizer for the model. |
| model | PreTrainedModel | Loaded LLaVA model, ready for inference. |
| image_processor | CLIPImageProcessor | Image preprocessing pipeline from the vision tower; None when model_name does not contain 'llava'. |
| context_len | int | Maximum context length: model.config.max_sequence_length if the config defines it, otherwise 2048. |
Usage Examples
Example 1: Extract the Projector After Stage 1 Pretraining
# After Stage 1 pretraining produces a sharded checkpoint:
python scripts/extract_mm_projector.py \
--model-path ./checkpoints/llava-v1.5-13b-pretrain \
--output ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin
Example 2: Use the Extracted Projector for Stage 2 Finetuning
# The extracted projector feeds into Stage 2 finetuning:
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path lmsys/vicuna-13b-v1.5 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
...
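When the extracted file is passed via --pretrain_mm_mlp_adapter, its keys carry the full model.mm_projector prefix, while the projector module itself expects local names such as 0.weight. The sketch below shows that remapping, modeled on the keyword-splitting approach the LLaVA training code is understood to use; strip_to_submodule is a hypothetical name, and the exact upstream behavior should be confirmed against the source.

```python
from typing import Any, Mapping


def strip_to_submodule(weights: Mapping[str, Any], keyword: str = "mm_projector") -> dict[str, Any]:
    """Drop everything up to and including '<keyword>.' from each matching key.

    Keys without '<keyword>.' are discarded, so the result matches the
    parameter names of the projector submodule itself.
    """
    marker = keyword + "."
    return {
        key.split(marker, 1)[1]: value
        for key, value in weights.items()
        if marker in key
    }


# Strings stand in for tensors loaded from mm_projector.bin.
extracted = {
    "model.mm_projector.0.weight": "w0",
    "model.mm_projector.0.bias": "b0",
    "model.mm_projector.2.weight": "w2",
    "model.mm_projector.2.bias": "b2",
}
print(sorted(strip_to_submodule(extracted)))
```

Matching on 'mm_projector.' with the trailing dot, rather than the bare keyword, avoids an IndexError on a key that contains the keyword but has nothing after it.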
Example 3: Validate Full Model with load_pretrained_model()
from llava.model.builder import load_pretrained_model
# Load a fully trained LLaVA model for inference validation
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-v1.5-13b",
    model_base=None,
    model_name="llava-v1.5-13b",
    use_flash_attn=True
)
# For projector-only checkpoints, provide the base model:
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-v1.5-13b-pretrain",
    model_base="lmsys/vicuna-13b-v1.5",
    model_name="llava-v1.5-13b-pretrain"
)
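For a projector-only load like the one above, the builder is expected to read the model configuration and mm_projector.bin from model_path while taking the LLM weights and tokenizer from model_base. A pre-flight check can catch a missing file before any large load starts. This is a minimal sketch, assuming the configuration is stored as config.json; check_projector_only_dir is a hypothetical helper, not part of llava.

```python
import os
import tempfile


def check_projector_only_dir(model_path: str) -> list[str]:
    """Return the required files missing from a projector-only checkpoint directory.

    Assumes (hypothetically) that config.json and mm_projector.bin must both
    be present in model_path when model_base is given.
    """
    required = ["config.json", "mm_projector.bin"]
    return [
        name for name in required
        if not os.path.isfile(os.path.join(model_path, name))
    ]


# Demonstrate on a temporary directory that only contains a config file.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "config.json"), "w") as f:
        f.write("{}")
    missing_files = check_projector_only_dir(tmp)

print(missing_files)
```

An empty list means both files are in place; it does not verify that their contents are valid.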
Related Pages