Implementation:Haotian Liu LLaVA LLaVATrainer Train
| Knowledge Sources | |
|---|---|
| Domains | |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete tool for running Stage 2 visual instruction tuning using LLaVA's custom Trainer class. LLaVATrainer extends HuggingFace's Trainer with modality-length-grouped sampling, separate optimizer parameter groups for the projector, and custom checkpoint saving logic.
Description
LLaVATrainer extends HuggingFace's Trainer with three key modifications tailored for multimodal vision-language training:
- Modality-length-grouped sampling (`_get_train_sampler()`) -- When `group_by_modality_length=True`, returns a custom `LengthGroupedSampler` that uses the dataset's `modality_lengths` property. This separates image-containing (positive length) and text-only (negative length) samples into distinct mega-batches, reducing padding waste.
- Separate projector learning rate (`create_optimizer()`) -- When `mm_projector_lr` is set, creates four optimizer parameter groups: LLM parameters with/without weight decay at the base learning rate, and projector parameters with/without weight decay at `mm_projector_lr`. This enables training the projector at a different rate than the LLM backbone.
- Custom checkpoint saving (`_save_checkpoint()` and `_save()`) -- When `tune_mm_mlp_adapter=True` (Stage 1 mode), only projector weights are saved via `get_mm_adapter_state_maybe_zero_3()`. For Stage 2, the default Trainer save behavior is used, which leverages DeepSpeed ZeRO-3's `stage3_gather_16bit_weights_on_model_save` to reconstruct the full model.
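The sign convention behind `modality_lengths` can be sketched as follows. This is a simplified stand-in for the dataset property, not the repo code: the real implementation measures tokenized conversation length, while this sketch counts words, and the sample dicts are illustrative.

```python
def modality_lengths(samples):
    """Positive length for samples with an image, negative for text-only
    (sketch of the sign convention; real lengths are token counts)."""
    lengths = []
    for sample in samples:
        n_words = sum(len(turn["value"].split()) for turn in sample["conversations"])
        lengths.append(n_words if "image" in sample else -n_words)
    return lengths

samples = [
    {"image": "cat.jpg", "conversations": [{"value": "describe the image"}]},
    {"conversations": [{"value": "what is two plus two"}]},
]
print(modality_lengths(samples))  # [3, -5]
```

The sampler only needs the sign to tell modalities apart and the magnitude to group similar lengths together.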
Usage
Run the Stage 2 finetuning script:
```shell
deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path lmsys/vicuna-13b-v1.5 \
    --version v1 \
    --data_path ./playground/data/llava_v1_5_mix665k.json \
    --image_folder ./playground/data \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-13b \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb
```
Source: scripts/v1_5/finetune.sh
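With these settings, the effective global batch size is per-device batch size × gradient accumulation steps × GPU count. LLaVA v1.5 Stage 2 targets a global batch size of 128, which the script reaches on an 8-GPU node (the GPU count below is an assumption about the launch environment, not a script flag):

```python
per_device_train_batch_size = 16  # --per_device_train_batch_size
gradient_accumulation_steps = 1   # --gradient_accumulation_steps
num_gpus = 8                      # assumption: 8-GPU node

global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(global_batch_size)  # 128
```

On fewer GPUs, raise `--gradient_accumulation_steps` proportionally to keep the same global batch size.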
Code Reference
Source Location
- Repository: https://github.com/haotian-liu/LLaVA
- File: `llava/train/llava_trainer.py`, lines 133--255 (`LLaVATrainer` class)
- File: `llava/train/train.py`, lines 788--991 (`train()` function)
- Script: `scripts/v1_5/finetune.sh` (launcher script)
Signature
```python
class LLaVATrainer(Trainer):
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Returns LengthGroupedSampler when group_by_modality_length=True."""
        ...

    def create_optimizer(self):
        """Creates optimizer with optional separate mm_projector_lr."""
        ...

    def _save_checkpoint(self, model, trial, metrics=None):
        """Saves only projector weights when tune_mm_mlp_adapter=True."""
        ...

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """No-op when tune_mm_mlp_adapter=True; full save otherwise."""
        ...
```
Import
```python
from llava.train.llava_trainer import LLaVATrainer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| `--pretrain_mm_mlp_adapter` | str (CLI) | Yes | Path to Stage 1 pretrained projector weights (`mm_projector.bin`). |
| `--group_by_modality_length` | bool (CLI) | Yes | Enables modality-length-grouped sampling. Set to `True` for Stage 2. |
| `--mm_projector_lr` | float (CLI) | No | Optional separate learning rate for the projector. If not set, the projector uses the same LR as the LLM. |
| `--version` | str (CLI) | Yes | Conversation template. `v1` for Stage 2 finetuning (multi-turn Vicuna format). |
| `--image_aspect_ratio` | str (CLI) | Yes | `pad` for Stage 2 (preserves aspect ratio via padding). |
| `--data_path` | str (CLI) | Yes | Path to 665K instruction tuning data (`llava_v1_5_mix665k.json`). |
| `--num_train_epochs` | int (CLI) | Yes | Number of training epochs. `1` for Stage 2. |
| `--per_device_train_batch_size` | int (CLI) | Yes | Per-GPU batch size. `16` for Stage 2. |
| `--learning_rate` | float (CLI) | Yes | Base learning rate. `2e-5` for Stage 2. |
| `--deepspeed` | str (CLI) | Yes | Path to DeepSpeed config. `scripts/zero3.json` for Stage 2. |
Outputs
| Name | Type | Description |
|---|---|---|
| Full model checkpoint | Directory | Complete LLaVA model weights (LLM + projector + vision tower config) saved to `--output_dir`. |
| `config.json` | File | Model configuration including vision tower settings, projector type, and other multimodal config. |
| Training logs | WandB / stdout | Loss curves, learning rate schedule, and training metrics. |
Key Implementation Details
Modality-Length-Grouped Sampling (llava_trainer.py, lines 135-148)
```python
class LLaVATrainer(Trainer):
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()
```
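The grouping step inside the sampler can be sketched as below. This is a simplification: the real `LengthGroupedSampler` also shuffles and length-sorts within megabatches, and `megabatch_size` stands in for `world_size * batch_size`.

```python
def group_by_modality(lengths, megabatch_size):
    """Split indices into image (length > 0) and text-only (length < 0)
    streams, then chunk each stream so every megabatch is single-modality
    (sketch of LengthGroupedSampler's modality grouping)."""
    mm_indices = [i for i, l in enumerate(lengths) if l > 0]
    txt_indices = [i for i, l in enumerate(lengths) if l < 0]

    def chunk(indices):
        return [indices[i:i + megabatch_size]
                for i in range(0, len(indices), megabatch_size)]

    return chunk(mm_indices) + chunk(txt_indices)

lengths = [120, -45, 98, -60, 77, -30]
print(group_by_modality(lengths, megabatch_size=2))
# [[0, 2], [4], [1, 3], [5]]
```

Because each megabatch is single-modality, no batch mixes image and text-only samples, which keeps sequence lengths within a batch similar and reduces padding.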
Separate Projector Learning Rate (llava_trainer.py, lines 150-228)
```python
def create_optimizer(self):
    if self.optimizer is None:
        decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        decay_parameters = [name for name in decay_parameters if "bias" not in name]
        if self.args.mm_projector_lr is not None:
            projector_parameters = [name for name, _ in opt_model.named_parameters()
                                    if "mm_projector" in name]
            optimizer_grouped_parameters = [
                {  # LLM params WITH weight decay
                    "params": [p for n, p in opt_model.named_parameters()
                               if n in decay_parameters and n not in projector_parameters
                               and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                },
                {  # LLM params WITHOUT weight decay
                    "params": [p for n, p in opt_model.named_parameters()
                               if n not in decay_parameters and n not in projector_parameters
                               and p.requires_grad],
                    "weight_decay": 0.0,
                },
                {  # Projector params WITH weight decay (separate LR)
                    "params": [p for n, p in opt_model.named_parameters()
                               if n in decay_parameters and n in projector_parameters
                               and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.mm_projector_lr,
                },
                {  # Projector params WITHOUT weight decay (separate LR)
                    "params": [p for n, p in opt_model.named_parameters()
                               if n not in decay_parameters and n in projector_parameters
                               and p.requires_grad],
                    "weight_decay": 0.0,
                    "lr": self.args.mm_projector_lr,
                },
            ]
```
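The four-way split can be reproduced with plain name matching, as in the minimal sketch below. The parameter names and learning rates are illustrative; note that in the real code the base-LR groups omit an explicit `lr` key and fall back to the optimizer default, whereas this sketch sets it on every group for easy inspection.

```python
def build_param_groups(param_names, decay_names, base_lr, projector_lr, weight_decay=0.0):
    """Partition parameter names into the four groups LLaVATrainer creates:
    (LLM decay, LLM no-decay, projector decay, projector no-decay)."""
    proj = {n for n in param_names if "mm_projector" in n}
    decay = set(decay_names)
    return [
        {"params": [n for n in param_names if n in decay and n not in proj],
         "weight_decay": weight_decay, "lr": base_lr},
        {"params": [n for n in param_names if n not in decay and n not in proj],
         "weight_decay": 0.0, "lr": base_lr},
        {"params": [n for n in param_names if n in decay and n in proj],
         "weight_decay": weight_decay, "lr": projector_lr},
        {"params": [n for n in param_names if n not in decay and n in proj],
         "weight_decay": 0.0, "lr": projector_lr},
    ]

names = ["model.layers.0.weight", "model.layers.0.bias",
         "model.mm_projector.0.weight", "model.mm_projector.0.bias"]
decay_names = [n for n in names if "bias" not in n]  # biases get no weight decay
groups = build_param_groups(names, decay_names, base_lr=2e-5, projector_lr=2e-6)
print([g["lr"] for g in groups])  # [2e-05, 2e-05, 2e-06, 2e-06]
```

Each parameter lands in exactly one group, so the projector LR applies to projector weights only and weight decay is skipped for biases in both halves.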
Custom Checkpoint Saving (llava_trainer.py, lines 230-255)
```python
def _save_checkpoint(self, model, trial, metrics=None):
    if getattr(self.args, 'tune_mm_mlp_adapter', False):
        # Stage 1: Save only projector weights
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)

        keys_to_match = ['mm_projector', 'vision_resampler']
        weight_to_save = get_mm_adapter_state_maybe_zero_3(
            self.model.named_parameters(), keys_to_match
        )

        if self.args.local_rank == 0 or self.args.local_rank == -1:
            self.model.config.save_pretrained(output_dir)
            torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
    else:
        # Stage 2: Full model save via parent Trainer
        super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
```
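Outside of ZeRO-3 sharding, the adapter extraction reduces to a key-substring filter over the named parameters, as sketched below. The real `get_mm_adapter_state_maybe_zero_3` additionally gathers parameters that ZeRO-3 has sharded across ranks; the names and values here are illustrative.

```python
def extract_adapter_weights(named_parameters,
                            keys_to_match=("mm_projector", "vision_resampler")):
    """Keep only parameters whose name contains one of the match keys
    (sketch of the non-sharded path of the adapter-saving helper)."""
    return {name: param for name, param in named_parameters
            if any(key in name for key in keys_to_match)}

named = [("model.layers.0.weight", "w0"),
         ("model.mm_projector.0.weight", "p0"),
         ("model.mm_projector.2.bias", "p1")]
print(sorted(extract_adapter_weights(named)))
# ['model.mm_projector.0.weight', 'model.mm_projector.2.bias']
```

The resulting dict is what Stage 1 writes to `mm_projector.bin`, which Stage 2 then loads via `--pretrain_mm_mlp_adapter`.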
Related Pages
- Principle:Haotian_liu_LLaVA_Visual_Instruction_Tuning
- Environment:Haotian_liu_LLaVA_Python_CUDA_Training_Environment
- Heuristic:Haotian_liu_LLaVA_Flash_Attention_GPU_Requirement
- Heuristic:Haotian_liu_LLaVA_Gradient_Checkpointing_Memory_Optimization
- Heuristic:Haotian_liu_LLaVA_Use_Cache_Training_Inference_Toggle
- Heuristic:Haotian_liu_LLaVA_Image_Aspect_Ratio_Padding_Strategy