

Implementation:Haotian liu LLaVA LLaVATrainer Train

From Leeroopedia
Metadata
Knowledge Sources
Domains
Last Updated 2026-02-13 00:00 GMT

Overview

LLaVATrainer is the custom Trainer class LLaVA uses to run Stage 2 visual instruction tuning. It extends HuggingFace's Trainer with modality-length-grouped sampling, a separate optimizer learning rate for the projector, and custom checkpoint-saving logic.

Description

LLaVATrainer extends HuggingFace's Trainer with three key modifications tailored for multimodal vision-language training:

  1. Modality-length-grouped sampling (_get_train_sampler()) -- When group_by_modality_length=True, returns LLaVA's own LengthGroupedSampler (not the HuggingFace class of the same name), driven by the dataset's modality_lengths property. Image-containing samples carry a positive length and text-only samples a negative one, so the sampler can keep the two modalities in separate mega-batches while grouping samples of similar length, which reduces padding.
  2. Separate projector learning rate (create_optimizer()) -- When mm_projector_lr is set, creates four optimizer parameter groups: LLM parameters with and without weight decay at the base learning rate, and projector parameters with and without weight decay at mm_projector_lr. This lets the projector train at a different rate from the LLM backbone. When mm_projector_lr is unset, the usual two groups (with and without weight decay) are used.
  3. Custom checkpoint saving (_save_checkpoint() and _save()) -- When tune_mm_mlp_adapter=True (Stage 1 mode), only the projector weights are saved, via get_mm_adapter_state_maybe_zero_3(). In Stage 2 the default Trainer save path is used; under DeepSpeed ZeRO-3, the stage3_gather_16bit_weights_on_model_save option in the DeepSpeed config lets it reconstruct the full model from the sharded parameters.
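To make the signed-length convention concrete, here is a minimal sketch of a modality_lengths-style computation over LLaVA-format records (an image key plus a conversations list of from/value turns). It is not the repository's code: the function name is hypothetical, the two records are invented, and approximating length by a whitespace word count is an assumption.

```python
from typing import Dict, List


def modality_lengths(records: List[Dict]) -> List[int]:
    """Hypothetical helper: signed length per record (+ with image, - text-only)."""
    lengths = []
    for sample in records:
        # Assumption: a whitespace word count over all turns approximates token count
        cur_len = sum(len(turn["value"].split()) for turn in sample["conversations"])
        lengths.append(cur_len if "image" in sample else -cur_len)
    return lengths


# Invented example records in the LLaVA conversation layout
records = [
    {"image": "example.jpg",
     "conversations": [{"from": "human", "value": "<image>\nWhat is on the plate?"},
                       {"from": "gpt", "value": "Broccoli and bread."}]},
    {"conversations": [{"from": "human", "value": "Name a prime number."},
                       {"from": "gpt", "value": "Seven."}]},
]
print(modality_lengths(records))  # [9, -5]
```

Only the sign carries the modality; the magnitude is what the sampler uses to group samples of similar length.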

Usage

Run the Stage 2 finetuning script:

deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path lmsys/vicuna-13b-v1.5 \
    --version v1 \
    --data_path ./playground/data/llava_v1_5_mix665k.json \
    --image_folder ./playground/data \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-13b \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb

Source: scripts/v1_5/finetune.sh
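The script's batch settings imply a global batch size and schedule length that can be estimated by hand. The sketch below assumes 8 GPUs and roughly 665,000 samples (inferred only from the data file name); stage2_schedule is a hypothetical helper, and the real step count may differ slightly depending on how samples are sharded across ranks.

```python
import math
from typing import Tuple


def stage2_schedule(num_samples: int, per_device_batch: int, grad_accum: int,
                    num_gpus: int, epochs: int, warmup_ratio: float) -> Tuple[int, int, int]:
    """Rough estimate of (global batch size, optimizer steps, warmup steps)."""
    global_batch = per_device_batch * grad_accum * num_gpus
    steps = math.ceil(num_samples / global_batch) * epochs
    # HuggingFace converts warmup_ratio to a step count by rounding up
    warmup = math.ceil(steps * warmup_ratio)
    return global_batch, steps, warmup


# Assumptions: 8 GPUs and ~665,000 samples; the other values come from finetune.sh
global_batch, steps, warmup = stage2_schedule(665_000, 16, 1, 8, 1, 0.03)
print(global_batch, steps, warmup)  # 128 5196 156
print("save_steps reached:", steps >= 50_000)  # save_steps reached: False
```

Under these assumptions, --save_steps 50000 is never reached within one epoch, so no intermediate step checkpoint is written and the final weights come from the save at the end of training.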

Code Reference

Source Location

  • Repository: https://github.com/haotian-liu/LLaVA
  • File: llava/train/llava_trainer.py, lines 133-255 (LLaVATrainer class)
  • File: llava/train/train.py, lines 788-991 (train() function)
  • Script: scripts/v1_5/finetune.sh (launcher script)

Signature

class LLaVATrainer(Trainer):

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Returns LengthGroupedSampler when group_by_modality_length=True."""
        ...

    def create_optimizer(self):
        """Creates optimizer with optional separate mm_projector_lr."""
        ...

    def _save_checkpoint(self, model, trial, metrics=None):
        """Saves only projector weights when tune_mm_mlp_adapter=True."""
        ...

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """No-op when tune_mm_mlp_adapter=True; full save otherwise."""
        ...

Import

from llava.train.llava_trainer import LLaVATrainer

I/O Contract

Inputs

Input Contract (key CLI arguments for Stage 2; "required" means the Stage 2 recipe sets it)

  • --pretrain_mm_mlp_adapter (str, required): Path to the Stage 1 pretrained projector weights (mm_projector.bin).
  • --group_by_modality_length (bool, required): Enables modality-length-grouped sampling. Set to True for Stage 2.
  • --mm_projector_lr (float, optional): Separate learning rate for the projector. If unset, the projector uses the same learning rate as the LLM.
  • --version (str, required): Conversation template. v1 for Stage 2 finetuning (multi-turn Vicuna format).
  • --image_aspect_ratio (str, required): pad for Stage 2 (pads each image to a square, preserving its aspect ratio).
  • --data_path (str, required): Path to the 665K instruction-tuning data (llava_v1_5_mix665k.json).
  • --num_train_epochs (int, required): Number of training epochs. 1 for Stage 2.
  • --per_device_train_batch_size (int, required): Per-GPU batch size. 16 for Stage 2.
  • --learning_rate (float, required): Base learning rate. 2e-5 for Stage 2.
  • --deepspeed (str, required): Path to the DeepSpeed config. scripts/zero3.json for Stage 2.

Outputs

Output Contract

  • Full model checkpoint (directory): Complete LLaVA model weights (LLM + projector + vision tower config) saved to --output_dir.
  • config.json (file): Model configuration, including vision tower settings, projector type, and other multimodal settings.
  • Training logs (WandB / stdout): Loss curves, learning rate schedule, and training metrics.
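A quick way to inspect the multimodal part of the saved configuration is to filter config.json by key prefix. This sketch assumes the multimodal settings are stored under keys starting with mm_ (for example mm_projector_type, mirroring the --mm_projector_type flag); other relevant keys, possibly image_aspect_ratio, would need to be added by name. The helper name and the demo config are hypothetical.

```python
import json
import os
import tempfile
from typing import Any, Dict


def multimodal_config(output_dir: str) -> Dict[str, Any]:
    """Hypothetical helper: entries of <output_dir>/config.json whose keys start with 'mm_'."""
    with open(os.path.join(output_dir, "config.json")) as f:
        config = json.load(f)
    return {k: v for k, v in config.items() if k.startswith("mm_")}


# Self-contained demo on an invented config; in practice, pass the --output_dir path
with tempfile.TemporaryDirectory() as demo_dir:
    with open(os.path.join(demo_dir, "config.json"), "w") as f:
        json.dump({"mm_projector_type": "mlp2x_gelu", "hidden_size": 5120}, f)
    mm_keys = multimodal_config(demo_dir)
print(mm_keys)  # {'mm_projector_type': 'mlp2x_gelu'}
```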

Key Implementation Details

Modality-Length-Grouped Sampling (llava_trainer.py, lines 135-148)

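from typing import Optional

import torch
from transformers import Trainer
from transformers.trainer import has_length
# LengthGroupedSampler below is LLaVA's own class defined in llava_trainer.py,
# not transformers.trainer_pt_utils.LengthGroupedSampler.

class LLaVATrainer(Trainer):

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            # Signed lengths: > 0 for image samples, < 0 for text-only samples
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                # Mega-batches span all GPUs and all gradient-accumulation micro-steps
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()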
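The ordering that group_by_modality=True aims for can be sketched in plain Python. This is a simplified stand-in, not LLaVA's LengthGroupedSampler: modality_grouped_indices is a hypothetical name, it uses random.Random instead of a torch generator, and it skips balancing lengths across ranks inside a mega-batch. It keeps the core idea: split indices by the sign of the length, cut each modality into mega-batches of batch_size * world_size, sort each mega-batch by length, shuffle the mega-batch order, and merge the two leftover tails into one final mixed mega-batch.

```python
import random
from typing import List, Optional


def modality_grouped_indices(lengths: List[int], batch_size: int, world_size: int,
                             seed: Optional[int] = None) -> List[int]:
    """Simplified sketch: every full mega-batch holds a single modality."""
    assert all(l != 0 for l in lengths), "a zero length has no modality sign"
    rng = random.Random(seed)
    megabatch_size = batch_size * world_size
    mm = [i for i, l in enumerate(lengths) if l > 0]    # image samples
    lang = [i for i, l in enumerate(lengths) if l < 0]  # text-only samples

    def to_megabatches(indices: List[int]) -> List[List[int]]:
        rng.shuffle(indices)
        chunks = [indices[i:i + megabatch_size] for i in range(0, len(indices), megabatch_size)]
        # Longest first inside each mega-batch, so per-device batches have similar lengths
        return [sorted(c, key=lambda i: abs(lengths[i]), reverse=True) for c in chunks]

    mm_batches, lang_batches = to_megabatches(mm), to_megabatches(lang)
    # The last (possibly partial) mega-batch of each modality is merged into one mixed tail
    tail = (mm_batches.pop() if mm_batches else []) + (lang_batches.pop() if lang_batches else [])
    megabatches = mm_batches + lang_batches
    rng.shuffle(megabatches)  # interleave image and text mega-batches randomly
    if tail:
        megabatches.append(tail)
    return [i for mb in megabatches for i in mb]


lengths = [5, -3, 7, -2, 9, -4, 6, -1, 8, -5, 4, -6]
order = modality_grouped_indices(lengths, batch_size=2, world_size=1, seed=0)
print(order)
```

Every mega-batch except the final tail contains only image samples or only text samples, which matches the description above; only the tail can mix modalities.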

Separate Projector Learning Rate (llava_trainer.py, lines 150-228)

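from transformers import Trainer
from transformers.trainer import ALL_LAYERNORM_LAYERS, get_parameter_names

def create_optimizer(self):
    opt_model = self.model
    if self.optimizer is None:
        # Weight decay applies to every parameter except LayerNorm weights and biases
        decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        decay_parameters = [name for name in decay_parameters if "bias" not in name]
        if self.args.mm_projector_lr is not None:
            projector_parameters = [name for name, _ in opt_model.named_parameters()
                                    if "mm_projector" in name]
            optimizer_grouped_parameters = [
                {   # LLM params WITH weight decay
                    "params": [p for n, p in opt_model.named_parameters()
                               if n in decay_parameters and n not in projector_parameters
                               and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                },
                {   # LLM params WITHOUT weight decay
                    "params": [p for n, p in opt_model.named_parameters()
                               if n not in decay_parameters and n not in projector_parameters
                               and p.requires_grad],
                    "weight_decay": 0.0,
                },
                {   # Projector params WITH weight decay (separate LR)
                    "params": [p for n, p in opt_model.named_parameters()
                               if n in decay_parameters and n in projector_parameters
                               and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.mm_projector_lr,
                },
                {   # Projector params WITHOUT weight decay (separate LR)
                    "params": [p for n, p in opt_model.named_parameters()
                               if n not in decay_parameters and n in projector_parameters
                               and p.requires_grad],
                    "weight_decay": 0.0,
                    "lr": self.args.mm_projector_lr,
                },
            ]
        else:
            # No separate projector LR: the usual two groups (with / without weight decay)
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in opt_model.named_parameters()
                               if n in decay_parameters and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in opt_model.named_parameters()
                               if n not in decay_parameters and p.requires_grad],
                    "weight_decay": 0.0,
                },
            ]
        # Groups without an "lr" key use the base learning rate from optimizer_kwargs
        # (backend-specific branches of the repository version are omitted here)
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    return self.optimizer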
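The effect of the four groups can be checked on a toy model with plain PyTorch. ToyLLaVA and build_param_groups are hypothetical, and the no-decay rule used here (biases and LayerNorm parameters) only approximates get_parameter_names(model, ALL_LAYERNORM_LAYERS) filtered for "bias". A per-group "lr" key overrides the optimizer's default lr, which is what lets mm_projector_lr coexist with the base rate.

```python
import torch
from torch import nn


class ToyLLaVA(nn.Module):
    """Hypothetical stand-in: a tiny 'LLM' block plus an 'mm_projector'."""

    def __init__(self):
        super().__init__()
        self.llm = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
        self.mm_projector = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 8))


def build_param_groups(model: nn.Module, weight_decay: float, projector_lr: float):
    # Assumed no-decay rule: biases and all LayerNorm parameters
    no_decay = set()
    for mod_name, module in model.named_modules():
        for p_name, _ in module.named_parameters(recurse=False):
            full_name = f"{mod_name}.{p_name}" if mod_name else p_name
            if isinstance(module, nn.LayerNorm) or p_name == "bias":
                no_decay.add(full_name)

    buckets = {"llm_decay": [], "llm_no_decay": [], "proj_decay": [], "proj_no_decay": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        part = "proj" if "mm_projector" in name else "llm"
        kind = "no_decay" if name in no_decay else "decay"
        buckets[f"{part}_{kind}"].append(param)

    return [
        {"params": buckets["llm_decay"], "weight_decay": weight_decay},
        {"params": buckets["llm_no_decay"], "weight_decay": 0.0},
        {"params": buckets["proj_decay"], "weight_decay": weight_decay, "lr": projector_lr},
        {"params": buckets["proj_no_decay"], "weight_decay": 0.0, "lr": projector_lr},
    ]


model = ToyLLaVA()
# Base LR 2e-5 as in the Stage 2 script; projector LR and weight decay are illustrative
optimizer = torch.optim.AdamW(build_param_groups(model, 0.01, 2e-4), lr=2e-5)
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"], group["weight_decay"])
```

Because HuggingFace's schedulers (such as the cosine schedule with warmup) are LambdaLR instances that multiply each group's initial rate by the same factor, the projector-to-LLM learning-rate ratio is preserved throughout warmup and decay.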

Custom Checkpoint Saving (llava_trainer.py, lines 230-255)

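import os
from typing import Optional

import torch
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
# get_mm_adapter_state_maybe_zero_3() is a helper defined in llava_trainer.py itself.

class LLaVATrainer(Trainer):
    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # Stage 1: save only the projector (and resampler) weights
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            keys_to_match = ['mm_projector', 'vision_resampler']
            # Gathers ZeRO-3 partitions before copying the matched tensors to CPU
            weight_to_save = get_mm_adapter_state_maybe_zero_3(
                self.model.named_parameters(), keys_to_match
            )
            # Only the main process writes files
            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        else:
            # Stage 2: full model save via the parent Trainer
            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # Stage 1 skips the default save; Stage 2 falls through to Trainer._save
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            pass
        else:
            super(LLaVATrainer, self)._save(output_dir, state_dict)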