Implementation:Haotian liu LLaVA Train Stage1 Pretrain
| Knowledge Sources | |
|---|---|
| Domains | |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete tool for running Stage 1 feature alignment pretraining of LLaVA's multimodal projector. The train() function in LLaVA's training pipeline, when invoked with --tune_mm_mlp_adapter True, loads a base LLM, initializes vision modules (CLIP encoder + MLP projector), freezes all parameters except the projector, and trains using DeepSpeed ZeRO-2.
Description
Stage 1 pretraining is orchestrated by the train() function in llava/train/train.py. When tune_mm_mlp_adapter=True is set, the function:
- Loads the base LLM -- Instantiates `LlavaLlamaForCausalLM` from the pretrained Vicuna-13B weights
- Initializes vision modules -- Calls `model.get_model().initialize_vision_modules()`, which:
  - Builds the CLIP ViT-L/14-336 vision tower
  - Constructs the MLP projector via `build_vision_projector()`
  - If `pretrain_mm_mlp_adapter` is set, loads pretrained projector weights (not used in Stage 1)
- Freezes all parameters -- `model.requires_grad_(False)`
- Unfreezes projector only -- Iterates over `model.get_model().mm_projector.parameters()` and sets `requires_grad = True` (a verification sketch follows this list)
- Trains with LLaVATrainer -- Uses HuggingFace Trainer (extended by `LLaVATrainer`) with DeepSpeed ZeRO-2
- Saves projector weights -- `safe_save_model_for_hf_trainer()` detects `tune_mm_mlp_adapter=True` and saves only `mm_projector.bin`
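The freeze/unfreeze step is easy to verify empirically. Below is a minimal sketch (the helper is illustrative, not part of the LLaVA codebase) that confirms only projector parameters will receive gradients after this setup:

import torch.nn as nn

def summarize_trainable(model: nn.Module) -> None:
    # Collect every parameter that still requires gradients after freezing.
    trainable = [(name, p.numel()) for name, p in model.named_parameters() if p.requires_grad]
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {sum(n for _, n in trainable):,} / {total:,} parameters")
    # In Stage 1, every trainable parameter should belong to the projector.
    unexpected = [name for name, _ in trainable if 'mm_projector' not in name]
    assert not unexpected, f"unexpectedly trainable: {unexpected}"

With Vicuna-13B and the mlp2x_gelu projector, this should report roughly 31M trainable parameters out of about 13B.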
Usage
Run the Stage 1 pretraining script:
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path lmsys/vicuna-13b-v1.5 \
--version plain \
--data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
--image_folder ./playground/data/LLaVA-Pretrain/images \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--tune_mm_mlp_adapter True \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava-v1.5-13b-pretrain \
--num_train_epochs 1 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 24000 \
--save_total_limit 1 \
--learning_rate 1e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
Source: scripts/v1_5/pretrain.sh
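With these settings the global batch size is per_device_train_batch_size × gradient_accumulation_steps × number of GPUs, i.e. 256 when launched on 8 GPUs as in the reference setup. When reproducing on fewer GPUs, the usual practice is to raise gradient_accumulation_steps so the global batch size stays at 256; a small illustrative helper (not part of the repository):

def grad_accum_steps(global_batch: int, per_device: int, num_gpus: int) -> int:
    # Keep the effective batch size (per_device x num_gpus x accumulation) fixed
    # when the GPU count changes.
    assert global_batch % (per_device * num_gpus) == 0, "lower per-device or global batch"
    return global_batch // (per_device * num_gpus)

print(grad_accum_steps(256, 32, 8))  # 1 -> matches the script above
print(grad_accum_steps(256, 32, 4))  # 2 -> halve the GPUs, double the accumulation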
Code Reference
Source Location
- Repository: https://github.com/haotian-liu/LLaVA
- File: `llava/train/train.py`, lines 788--991 (`train()` function)
- File: `llava/model/llava_arch.py`, lines 49--97 (`initialize_vision_modules()`)
- File: `llava/model/multimodal_projector/builder.py`, lines 33--51 (`build_vision_projector()`)
- Script: `scripts/v1_5/pretrain.sh` (launcher script)
Signature
def train(attn_implementation=None) -> None:
# Configured via CLI args:
# --model_name_or_path: str (base LLM, e.g. 'lmsys/vicuna-13b-v1.5')
# --tune_mm_mlp_adapter True
# --vision_tower: str (e.g. 'openai/clip-vit-large-patch14-336')
# --mm_projector_type: str = 'mlp2x_gelu'
# --version: str = 'plain' (conversation format for pretraining)
# --bf16 True
# --num_train_epochs 1
# --per_device_train_batch_size 32
# --learning_rate 1e-3
# --deepspeed scripts/zero2.json
...
Import
from llava.train.train import train
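train() takes no configuration through its signature; everything is read from the command line via transformers.HfArgumentParser. A programmatic call therefore has to populate sys.argv first. A minimal single-process sketch for smoke tests only (the supported path is the deepspeed launcher shown under Usage, and only a subset of the flags is listed here):

import sys
from llava.train.train import train

# train() parses ModelArguments, DataArguments and TrainingArguments from sys.argv.
# Flags mirror the Usage section; a 13B model will not fit in a single plain process,
# so substitute a small model when using this pattern for debugging.
sys.argv = [
    "train.py",
    "--model_name_or_path", "lmsys/vicuna-13b-v1.5",
    "--version", "plain",
    "--data_path", "./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json",
    "--image_folder", "./playground/data/LLaVA-Pretrain/images",
    "--vision_tower", "openai/clip-vit-large-patch14-336",
    "--mm_projector_type", "mlp2x_gelu",
    "--tune_mm_mlp_adapter", "True",
    "--bf16", "True",
    "--learning_rate", "1e-3",
    "--output_dir", "./checkpoints/llava-v1.5-13b-pretrain",
]
train()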
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| `--model_name_or_path` | str (CLI) | Yes | HuggingFace model ID or local path for the base LLM. E.g., `lmsys/vicuna-13b-v1.5`. |
| `--vision_tower` | str (CLI) | Yes | HuggingFace model ID for the vision encoder. E.g., `openai/clip-vit-large-patch14-336`. |
| `--tune_mm_mlp_adapter` | bool (CLI) | Yes | Must be `True` for Stage 1. Freezes all params except the projector. |
| `--mm_projector_type` | str (CLI) | Yes | Projector architecture. `mlp2x_gelu` for LLaVA v1.5. |
| `--data_path` | str (CLI) | Yes | Path to JSON training data. E.g., `blip_laion_cc_sbu_558k.json` (an example record follows this table). |
| `--image_folder` | str (CLI) | Yes | Base directory containing training images. |
| `--version` | str (CLI) | Yes | Conversation template. `plain` for Stage 1 pretraining. |
| `--deepspeed` | str (CLI) | Yes | Path to DeepSpeed config. `scripts/zero2.json` for Stage 1. |
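The file behind `--data_path` is a list of single-turn records in LLaVA's conversation format: each record names an image (resolved against `--image_folder`) and pairs a human turn containing the `<image>` placeholder with a gpt turn containing the caption. An illustrative record, shown as a Python literal with invented values (only the schema reflects the real JSON):

# Shape of one entry in blip_laion_cc_sbu_558k.json (values invented for illustration).
example_record = {
    "id": "000000001",
    "image": "00000/000000001.jpg",  # relative to --image_folder
    "conversations": [
        {"from": "human", "value": "Give a brief description of the image.\n<image>"},
        {"from": "gpt", "value": "a wooden desk with a laptop and a cup of coffee"},
    ],
}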
Outputs
| Name | Type | Description |
|---|---|---|
| `mm_projector.bin` | File | Checkpoint containing only the trained projector weights. Saved to `--output_dir` (an inspection sketch follows this table). |
| `config.json` | File | Model configuration file saved alongside the projector weights. |
| Training logs | WandB / stdout | Loss curves, learning rate schedule, and other training metrics. |
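mm_projector.bin is a plain torch checkpoint (a state dict keyed by parameter name) and is the artifact that Stage 2 fine-tuning loads via --pretrain_mm_mlp_adapter. A quick inspection sketch (path as in the Usage section; exact key names depend on the module hierarchy):

import torch

# Load the Stage 1 artifact on CPU and list the tensors it contains; expect the
# two linear layers of the mlp2x_gelu projector under keys containing 'mm_projector'.
state = torch.load("./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin", map_location="cpu")
for key, tensor in state.items():
    print(key, tuple(tensor.shape), tensor.dtype)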
Key Implementation Details
Parameter Freezing (train.py, lines 927-930)
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
if model_args.tune_mm_mlp_adapter:
model.requires_grad_(False)
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
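Because the entire LLM, including its input embeddings, is frozen, --gradient_checkpointing True needs an extra step so that gradients still reach the projector through the checkpointed transformer blocks. train.py uses the standard HuggingFace pattern for this; roughly (a sketch, not a verbatim excerpt):

# Force the (frozen) embedding layer's outputs to require grad so that activation
# checkpointing still propagates gradients back to the trainable projector.
if training_args.gradient_checkpointing:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)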
Checkpoint Saving (train.py, lines 185-207)
When tune_mm_mlp_adapter is True, the custom save function extracts only projector weights:
def safe_save_model_for_hf_trainer(trainer, output_dir):
if getattr(trainer.args, "tune_mm_mlp_adapter", False):
keys_to_match = ['mm_projector']
weight_to_save = get_mm_adapter_state_maybe_zero_3(
trainer.model.named_parameters(), keys_to_match
)
trainer.model.config.save_pretrained(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
return
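get_mm_adapter_state_maybe_zero_3() filters named_parameters() for entries whose name matches one of keys_to_match and moves them to CPU; under ZeRO-3 it additionally gathers parameter shards first. With ZeRO-2 (used here), where full parameters live on every rank, the behaviour reduces to roughly this simplified sketch (not the verbatim helper):

def gather_adapter_state(named_params, keys_to_match):
    # Keep only projector tensors and detach them to CPU for torch.save().
    return {
        name: param.detach().cpu().clone()
        for name, param in named_params
        if any(key in name for key in keys_to_match)
    }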
Vision Module Initialization (llava_arch.py, lines 49-97)
def initialize_vision_modules(self, model_args, fsdp=None):
vision_tower = model_args.vision_tower
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
self.config.mm_vision_tower = vision_tower
if self.get_vision_tower() is None:
vision_tower = build_vision_tower(model_args)
self.vision_tower = vision_tower
self.config.use_mm_proj = True
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
self.config.mm_hidden_size = vision_tower.hidden_size
if getattr(self, 'mm_projector', None) is None:
self.mm_projector = build_vision_projector(self.config)
if pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
self.mm_projector.load_state_dict(
{k.split('mm_projector.')[1]: v for k, v in mm_projector_weights.items()
if 'mm_projector' in k}
)
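For mm_projector_type='mlp2x_gelu', build_vision_projector() returns a two-layer MLP with a GELU in between, mapping config.mm_hidden_size (1024 for CLIP ViT-L/14-336) into the LLM hidden size (5120 for Vicuna-13B). A sketch of the equivalent module (illustrative, mirroring builder.py rather than quoting it):

import torch.nn as nn

def mlp2x_gelu_projector(mm_hidden_size: int = 1024, hidden_size: int = 5120) -> nn.Sequential:
    # Equivalent of the 'mlp2x_gelu' branch in build_vision_projector():
    # vision features (CLIP hidden size) -> LLM embedding space, ~31M parameters.
    return nn.Sequential(
        nn.Linear(mm_hidden_size, hidden_size),
        nn.GELU(),
        nn.Linear(hidden_size, hidden_size),
    )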
Related Pages
- Principle:Haotian_liu_LLaVA_Feature_Alignment_Pretraining
- Environment:Haotian_liu_LLaVA_Python_CUDA_Training_Environment
- Heuristic:Haotian_liu_LLaVA_Flash_Attention_GPU_Requirement
- Heuristic:Haotian_liu_LLaVA_Gradient_Checkpointing_Memory_Optimization
- Heuristic:Haotian_liu_LLaVA_Use_Cache_Training_Inference_Toggle