
Implementation:Haotian Liu LLaVA Train Stage1 Pretrain

From Leeroopedia
Last Updated 2026-02-13 00:00 GMT

Overview

A concrete tool for running Stage 1 (feature-alignment) pretraining of LLaVA's multimodal projector. The train() function in LLaVA's training pipeline, when invoked with --tune_mm_mlp_adapter True, loads the base LLM, initializes the vision modules (CLIP encoder + MLP projector), freezes all parameters except the projector, and trains with DeepSpeed ZeRO-2.

Description

Stage 1 pretraining is orchestrated by the train() function in llava/train/train.py. When tune_mm_mlp_adapter=True is set, the function:

  1. Loads the base LLM -- Instantiates LlavaLlamaForCausalLM from the pretrained Vicuna-13B weights
  2. Initializes vision modules -- Calls model.get_model().initialize_vision_modules() which:
    • Builds the CLIP ViT-L/14-336 vision tower
    • Constructs the MLP projector via build_vision_projector()
    • If pretrain_mm_mlp_adapter is set, loads pretrained projector weights (not used in Stage 1)
  3. Freezes all parameters -- model.requires_grad_(False)
  4. Unfreezes projector only -- Iterates over model.get_model().mm_projector.parameters() and sets requires_grad = True
  5. Trains with LLaVATrainer -- Uses HuggingFace Trainer (extended by LLaVATrainer) with DeepSpeed ZeRO-2
  6. Saves projector weights -- safe_save_model_for_hf_trainer() detects tune_mm_mlp_adapter=True and saves only mm_projector.bin
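Because only the projector trains, the trainable parameter count is tiny relative to the 13B base model. A back-of-envelope calculation, assuming the v1.5-13B widths (CLIP ViT-L/14-336 hidden size 1024, Vicuna-13B hidden size 5120) and the two-layer mlp2x_gelu projector:

```python
# Trainable parameters for the mlp2x_gelu projector, assuming the v1.5-13B
# widths: CLIP ViT-L/14-336 hidden size 1024, Vicuna-13B hidden size 5120.
mm_hidden, hidden = 1024, 5120
layer1 = mm_hidden * hidden + hidden   # Linear(1024 -> 5120): weights + bias
layer2 = hidden * hidden + hidden      # Linear(5120 -> 5120): weights + bias
total = layer1 + layer2
print(f"{total:,} trainable parameters")  # ~31.5M, vs. ~13B frozen
```

Roughly 31.5M trainable parameters, which is why Stage 1 is cheap enough to run at a 1e-3 learning rate with a large batch.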

Usage

Run the Stage 1 pretraining script:

deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero2.json \
    --model_name_or_path lmsys/vicuna-13b-v1.5 \
    --version plain \
    --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
    --image_folder ./playground/data/LLaVA-Pretrain/images \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --tune_mm_mlp_adapter True \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-13b-pretrain \
    --num_train_epochs 1 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 24000 \
    --save_total_limit 1 \
    --learning_rate 1e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb

Source: scripts/v1_5/pretrain.sh
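The launcher fixes the per-device batch size and accumulation steps, so the global batch size depends on GPU count. The released recipe targets a global batch of 256 for Stage 1, which these flags reach on an 8-GPU node (assuming one DeepSpeed process per GPU):

```python
# Effective global batch size implied by the launcher flags above,
# assuming 8 GPUs (one deepspeed process per GPU).
per_device_train_batch_size = 32
gradient_accumulation_steps = 1
num_gpus = 8
global_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(global_batch)  # -> 256
```

On fewer GPUs, raise --gradient_accumulation_steps to keep the product at 256.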

Code Reference

Source Location

  • Repository: https://github.com/haotian-liu/LLaVA
  • File: llava/train/train.py, lines 788--991 (train() function)
  • File: llava/model/llava_arch.py, lines 49--97 (initialize_vision_modules())
  • File: llava/model/multimodal_projector/builder.py, lines 33--51 (build_vision_projector())
  • Script: scripts/v1_5/pretrain.sh (launcher script)

Signature

def train(attn_implementation=None) -> None:
    # Configured via CLI args:
    # --model_name_or_path: str (base LLM, e.g. 'lmsys/vicuna-13b-v1.5')
    # --tune_mm_mlp_adapter True
    # --vision_tower: str (e.g. 'openai/clip-vit-large-patch14-336')
    # --mm_projector_type: str = 'mlp2x_gelu'
    # --version: str = 'plain'  (conversation format for pretraining)
    # --bf16 True
    # --num_train_epochs 1
    # --per_device_train_batch_size 32
    # --learning_rate 1e-3
    # --deepspeed scripts/zero2.json
    ...

Import

from llava.train.train import train

I/O Contract

Inputs

Input Contract (all of the following CLI arguments are required for Stage 1):

  • --model_name_or_path (str) -- HuggingFace model ID or local path for the base LLM, e.g. lmsys/vicuna-13b-v1.5.
  • --vision_tower (str) -- HuggingFace model ID for the vision encoder, e.g. openai/clip-vit-large-patch14-336.
  • --tune_mm_mlp_adapter (bool) -- Must be True for Stage 1. Freezes all params except the projector.
  • --mm_projector_type (str) -- Projector architecture; mlp2x_gelu for LLaVA v1.5.
  • --data_path (str) -- Path to JSON training data, e.g. blip_laion_cc_sbu_558k.json.
  • --image_folder (str) -- Base directory containing training images.
  • --version (str) -- Conversation template; plain for Stage 1 pretraining.
  • --deepspeed (str) -- Path to DeepSpeed config; scripts/zero2.json for Stage 1.

Outputs

Output Contract:

  • mm_projector.bin (file) -- Checkpoint containing only the trained projector weights; written to --output_dir.
  • config.json (file) -- Model configuration saved alongside the projector weights.
  • Training logs (WandB / stdout) -- Loss curves, learning-rate schedule, and other training metrics.
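The mm_projector.bin file is a plain torch checkpoint keyed by parameter names that contain the mm_projector prefix; Stage 2 consumes it via --pretrain_mm_mlp_adapter. A minimal sketch of the save/load round trip, using a stand-in projector (the key names mirror the real checkpoint's model.mm_projector.* layout, but the module shapes here are illustrative):

```python
import os
import tempfile
import torch
import torch.nn as nn

# Stand-in for the projector whose weights Stage 1 writes out.
proj = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 16))

# Keys in the real checkpoint look like 'model.mm_projector.0.weight'.
state = {f"model.mm_projector.{k}": v for k, v in proj.state_dict().items()}

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "mm_projector.bin")
    torch.save(state, path)

    # How initialize_vision_modules() consumes the file in Stage 2:
    # strip everything up to 'mm_projector.' and load the remainder.
    weights = torch.load(path, map_location="cpu")
    cleaned = {k.split("mm_projector.")[1]: v for k, v in weights.items()
               if "mm_projector" in k}
    proj.load_state_dict(cleaned)
    print(sorted(cleaned.keys()))
```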

Key Implementation Details

Parameter Freezing (train.py, lines 927-930)

model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
if model_args.tune_mm_mlp_adapter:
    model.requires_grad_(False)
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = True
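The freeze/unfreeze pattern above can be verified on a toy module; only the projector's parameters should come back as trainable (ToyLlava and its layer sizes are illustrative stand-ins, not the real model):

```python
import torch.nn as nn

# Toy stand-in for LlavaLlamaForCausalLM: only mm_projector should train.
class ToyLlava(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 32)     # stands in for the frozen LLM
        self.mm_projector = nn.Linear(8, 32)  # stands in for the MLP adapter

model = ToyLlava()
model.requires_grad_(False)                   # step 3: freeze everything
for p in model.mm_projector.parameters():     # step 4: unfreeze projector
    p.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # projector params only vs. all params
```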

Checkpoint Saving (train.py, lines 185-207)

When tune_mm_mlp_adapter is True, the custom save function extracts only projector weights:

def safe_save_model_for_hf_trainer(trainer, output_dir):
    if getattr(trainer.args, "tune_mm_mlp_adapter", False):
        keys_to_match = ['mm_projector']
        weight_to_save = get_mm_adapter_state_maybe_zero_3(
            trainer.model.named_parameters(), keys_to_match
        )
        trainer.model.config.save_pretrained(output_dir)
        if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
            torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        return

Vision Module Initialization (llava_arch.py, lines 49-97)

def initialize_vision_modules(self, model_args, fsdp=None):
    vision_tower = model_args.vision_tower
    pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

    self.config.mm_vision_tower = vision_tower

    if self.get_vision_tower() is None:
        vision_tower = build_vision_tower(model_args)
        self.vision_tower = vision_tower

    self.config.use_mm_proj = True
    self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
    self.config.mm_hidden_size = vision_tower.hidden_size

    if getattr(self, 'mm_projector', None) is None:
        self.mm_projector = build_vision_projector(self.config)

    if pretrain_mm_mlp_adapter is not None:
        mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
        self.mm_projector.load_state_dict(
            {k.split('mm_projector.')[1]: v for k, v in mm_projector_weights.items()
             if 'mm_projector' in k}
        )
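For mm_projector_type='mlp2x_gelu', build_vision_projector() produces a Linear/GELU/Linear stack mapping the vision hidden size to the LLM hidden size. A hedged sketch of that construction (function name and exact argument list simplified from the real builder.py, which parses the 'mlpNx_gelu' pattern from the config):

```python
import re
import torch
import torch.nn as nn

# Sketch of the 'mlpNx_gelu' branch of build_vision_projector():
# N Linear layers with GELU between them, first mapping mm_hidden -> hidden.
def build_mlp_projector(mm_hidden_size, hidden_size, projector_type="mlp2x_gelu"):
    match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    depth = int(match.group(1))
    modules = [nn.Linear(mm_hidden_size, hidden_size)]
    for _ in range(1, depth):
        modules.append(nn.GELU())
        modules.append(nn.Linear(hidden_size, hidden_size))
    return nn.Sequential(*modules)

proj = build_mlp_projector(1024, 5120)  # CLIP ViT-L/14 width -> Vicuna-13B width
x = torch.randn(1, 576, 1024)           # 576 patch tokens from a 336px image
print(proj(x).shape)                     # -> torch.Size([1, 576, 5120])
```

The 576 tokens come from the (336/14)^2 = 24x24 patch grid of CLIP ViT-L/14-336.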
