Implementation:Haotian_liu_LLaVA_Train_With_LoRA
Overview
Concrete tool for running LoRA/QLoRA fine-tuning of a LLaVA model.
Type
API Doc
Description
Called with lora_enable=True, the train() function wraps the base model in LoRA adapters and trains them. The key code paths are:
- L860-876 -- Creates a LoraConfig with target_modules auto-detected by find_all_linear_names() (sketched after this list) and calls get_peft_model() to wrap the model with LoRA adapters.
- L797-814 -- When bits=4 or bits=8, configures a BitsAndBytesConfig for model quantization (QLoRA).
- L847-850 -- For quantized models, calls prepare_model_for_kbit_training() to handle gradient-checkpointing compatibility.
- L959-964 -- Instantiates LLaVATrainer and starts training.
- L974-984 -- At save time, extracts the LoRA weights and the non-LoRA trainables and writes them to separate files (adapter_model.bin and non_lora_trainables.bin).
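For reference, the auto-detection at L169-182 amounts to the following (a paraphrased sketch; consult the source for the exact version). It selects every torch.nn.Linear in the language model while skipping multimodal modules, so the vision tower and mm_projector never receive adapters:

import torch

def find_all_linear_names(model):
    # Collect leaf names of every nn.Linear, skipping multimodal components
    # so LoRA targets only the language model (paraphrase of train.py:L169-182).
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    lora_module_names = set()
    for name, module in model.named_modules():
        if any(kw in name for kw in multimodal_keywords):
            continue
        if isinstance(module, torch.nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    lora_module_names.discard('lm_head')  # output head is excluded from LoRA
    return list(lora_module_names)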
Source
- llava/train/train.py:L788-991 (train function)
- llava/train/train.py:L860-876 (LoRA setup)
- llava/train/train.py:L169-182 (find_all_linear_names)
- llava/train/train.py:L130-160 (LoRA state dict extraction)
Signature
def train(attn_implementation=None) -> None:
    """Main training entry point.

    LoRA-specific CLI args:
        --lora_enable True
        --lora_r 128                (v1.5 default)
        --lora_alpha 256            (v1.5 default)
        --lora_dropout 0.05
        --lora_bias "none"
        --bits 16                   (or 4 for QLoRA)
        --model_name_or_path        pre-trained LLaVA checkpoint
        --deepspeed scripts/zero3.json
        --mm_projector_lr 2e-5      (separate projector learning rate)
    """
LoRA Setup Code Path (L860-876)
if training_args.lora_enable:
    from peft import LoraConfig, get_peft_model
    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    if training_args.bits == 16:
        if training_args.bf16:
            model.to(torch.bfloat16)
        if training_args.fp16:
            model.to(torch.float16)
    rank0_print("Adding LoRA adapters...")
    model = get_peft_model(model, lora_config)
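For completeness, a sketch of the quantization setup performed at L797-814 (paraphrased; the helper name make_qlora_config is hypothetical, and the exact field values live in the source). llm_int8_skip_modules keeps the mm_projector in full precision, per the related Quantization_MM_Projector_Exclusion heuristic:

import torch
from transformers import BitsAndBytesConfig

def make_qlora_config(bits: int, compute_dtype: torch.dtype) -> BitsAndBytesConfig:
    # Paraphrased sketch of train.py:L797-814; consult the source for exact values.
    return BitsAndBytesConfig(
        load_in_4bit=(bits == 4),
        load_in_8bit=(bits == 8),
        llm_int8_skip_modules=["mm_projector"],  # projector stays unquantized
        bnb_4bit_compute_dtype=compute_dtype,    # follows --bf16 / --fp16
        bnb_4bit_use_double_quant=True,          # --double_quant default
        bnb_4bit_quant_type="nf4",               # --quant_type default
    )

# The quantized model is then passed through peft.prepare_model_for_kbit_training()
# before get_peft_model(), which handles gradient checkpointing (L847-850).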
LoRA Save Code Path (L974-984)
if training_args.lora_enable:
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), training_args.lora_bias
    )
    non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
        model.named_parameters()
    )
    if training_args.local_rank == 0 or training_args.local_rank == -1:
        model.config.save_pretrained(training_args.output_dir)
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
        torch.save(non_lora_state_dict,
                   os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
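To consume these artifacts outside the repo's own loader, a hedged sketch with plain PEFT (LLaVA's load_pretrained_model() in llava/model/builder.py performs the equivalent steps, including config and key-prefix fixups this sketch simplifies; paths are illustrative):

import os
import torch
from transformers import AutoConfig
from peft import PeftModel
from llava.model import LlavaLlamaForCausalLM  # registers the LLaVA config type

lora_dir = "./checkpoints/llava-v1.5-13b-lora"  # illustrative path
cfg = AutoConfig.from_pretrained(lora_dir)      # the LoRA dir stores config.json
base = LlavaLlamaForCausalLM.from_pretrained(
    "lmsys/vicuna-13b-v1.5", config=cfg, torch_dtype=torch.float16)

# 1. Restore non-LoRA trainables (the mm_projector); the PEFT wrapper prefixes
#    keys with "base_model.model.", which builder.py strips before loading.
non_lora = torch.load(os.path.join(lora_dir, "non_lora_trainables.bin"),
                      map_location="cpu")
non_lora = {k.replace("base_model.model.", "", 1): v for k, v in non_lora.items()}
base.load_state_dict(non_lora, strict=False)

# 2. Attach the LoRA adapters; merge_and_unload() folds (alpha/r)*BA into W.
model = PeftModel.from_pretrained(base, lora_dir)
model = model.merge_and_unload()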
Import
from llava.train.train import train
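train() is normally launched via deepspeed llava/train/train_mem.py, which simply calls train(attn_implementation="flash_attention_2"). Because arguments are parsed from sys.argv with HfArgumentParser, a programmatic call must inject the CLI flags first -- a minimal, untested sketch with illustrative values:

import sys
from llava.train.train import train

sys.argv = [
    "train.py",
    "--lora_enable", "True", "--lora_r", "128", "--lora_alpha", "256",
    "--model_name_or_path", "lmsys/vicuna-13b-v1.5",
    "--version", "v1",
    "--data_path", "./playground/data/llava_v1_5_mix665k.json",
    "--image_folder", "./playground/data",
    "--vision_tower", "openai/clip-vit-large-patch14-336",
    "--bf16", "True",
    "--output_dir", "./checkpoints/llava-lora-sketch",
]
train(attn_implementation="flash_attention_2")  # same call train_mem.py makes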
Inputs
- Pre-trained LLaVA model -- Base model checkpoint or HuggingFace model ID (--model_name_or_path)
- Training data JSON -- Conversation-format JSON file (--data_path)
- Image folder -- Directory containing training images (--image_folder)
- Vision tower -- CLIP model identifier (--vision_tower)
- DeepSpeed config -- ZeRO stage configuration (--deepspeed)
Outputs
LoRA adapter checkpoint directory containing:
- adapter_config.json -- LoRA configuration (rank, alpha, target modules)
- adapter_model.bin -- LoRA adapter weights (A and B matrices)
- non_lora_trainables.bin -- Non-LoRA trainable weights (mm_projector)
- config.json -- Model configuration
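With the v1.5 defaults the adapters are a small fraction of the full model. A rough count for a 13B LLaMA backbone (hidden 5120, intermediate 13824, 40 layers), assuming all seven linear projections are targeted as find_all_linear_names() would select:

r = 128
hidden, inter, layers = 5120, 13824, 40
per_layer = (
    4 * r * (hidden + hidden)   # q_proj, k_proj, v_proj, o_proj
    + 2 * r * (hidden + inter)  # gate_proj, up_proj
    + 1 * r * (inter + hidden)  # down_proj
)
total = per_layer * layers
print(f"{total / 1e6:.0f}M LoRA params (~{total / 13e9:.1%} of 13B)")  # ~501M, ~3.9%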
Usage Examples
LoRA Training (v1.5, 13B)
deepspeed llava/train/train_mem.py \
--lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
--deepspeed ./scripts/zero3.json \
--model_name_or_path lmsys/vicuna-13b-v1.5 \
--version v1 \
--data_path ./playground/data/llava_v1_5_mix665k.json \
--image_folder ./playground/data \
--vision_tower openai/clip-vit-large-patch14-336 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir ./checkpoints/llava-v1.5-13b-lora \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 1 \
--learning_rate 2e-4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--model_max_length 2048 \
--gradient_checkpointing True \
--lazy_preprocess True \
--dataloader_num_workers 4 \
--report_to wandb
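A quick check of the effective batch size these flags imply, assuming the 8-GPU node the reference scripts target (the v1.5 recipe fine-tunes with a global batch of 128):

gpus, per_device, grad_accum = 8, 16, 1  # assumed node size; flags from above
print(gpus * per_device * grad_accum)    # 128 effective global batch size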
QLoRA Training (4-bit)
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--lora_enable True \
--bits 4 \
--model_name_or_path ./checkpoints/vicuna-v1-3-7b \
--version v1 \
--data_path ./playground/data/llava_instruct_80k.json \
--image_folder /path/to/coco/train2017 \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-vicuna-v1-3-7b-pretrain/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava-vicuna-v1-3-7b-finetune_lora \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--learning_rate 2e-5 \
--model_max_length 2048 \
--gradient_checkpointing True \
--lazy_preprocess True
Task-Specific LoRA (from pre-trained LLaVA)
deepspeed llava/train/train_mem.py \
--lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
--deepspeed ./scripts/zero3.json \
--model_name_or_path liuhaotian/llava-v1.5-13b \
--version v1 \
--data_path /path/to/custom_data.json \
--image_folder /path/to/custom_images \
--bf16 True \
--output_dir ./checkpoints/llava-v1.5-13b-task-lora \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--learning_rate 2e-4 \
--model_max_length 2048 \
--gradient_checkpointing True \
--lazy_preprocess True
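This variant omits --vision_tower and --pretrain_mm_mlp_adapter because a full LLaVA checkpoint already records the vision tower and projector in its config and weights. A hedged way to confirm what a checkpoint embeds:

import json
from huggingface_hub import hf_hub_download

path = hf_hub_download("liuhaotian/llava-v1.5-13b", "config.json")
cfg = json.load(open(path))
print(cfg["mm_vision_tower"])    # openai/clip-vit-large-patch14-336
print(cfg["mm_projector_type"])  # mlp2x_gelu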
Metadata
| Field | Value |
|---|---|
| last_updated | 2026-02-13 14:00 GMT |
| source_repo | Haotian_liu_LLaVA |
| commit | 799f5f207c89 |
| type | Implementation (API Doc) |
Related Pages
- implements Principle:Haotian_liu_LLaVA_LoRA_Training
- Environment:Haotian_liu_LLaVA_Python_CUDA_Training_Environment
- Heuristic:Haotian_liu_LLaVA_Flash_Attention_GPU_Requirement
- Heuristic:Haotian_liu_LLaVA_Gradient_Checkpointing_Memory_Optimization
- Heuristic:Haotian_liu_LLaVA_Quantization_MM_Projector_Exclusion