
Implementation:Haotian Liu LLaVA Train With LoRA

From Leeroopedia

Overview

Concrete tool for running LoRA/QLoRA fine-tuning of a LLaVA model.

Type

API Doc

Description

The train() function with lora_enable=True applies LoRA adapters and trains them. The key code paths are:

  • L860-876 -- Creates LoraConfig with auto-detected target_modules from find_all_linear_names() and calls get_peft_model() to wrap the model with LoRA adapters.
  • L797-814 -- When bits=4 or bits=8, configures BitsAndBytesConfig for model quantization (QLoRA).
  • L847-850 -- For quantized models, calls prepare_model_for_kbit_training() to handle gradient checkpointing compatibility.
  • L959-964 -- Instantiates LLaVATrainer and starts training.
  • L974-984 -- At save time, extracts and saves LoRA weights and non_lora_trainables.bin separately.
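
The adapter update that get_peft_model installs on each target module can be sketched in plain Python (a minimal illustration, not LLaVA code; the 2.0 scale follows from the v1.5 defaults lora_alpha=256, lora_r=128):

```python
def lora_forward(W, A, B, x, lora_r=128, lora_alpha=256):
    """y = W @ x + (alpha / r) * B @ (A @ x), with W frozen and A, B trained.

    W: d_out x d_in, A: r x d_in, B: d_out x r -- nested lists for illustration.
    """
    def matvec(M, v):
        return [sum(m * u for m, u in zip(row, v)) for row in M]

    scale = lora_alpha / lora_r          # 2.0 for the v1.5 defaults
    base = matvec(W, x)                  # frozen base projection
    delta = matvec(B, matvec(A, x))      # low-rank update
    return [b + scale * d for b, d in zip(base, delta)]
```

Since B is zero-initialized, the adapted module reproduces the frozen base model exactly at the start of training.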

Source

  • llava/train/train.py:L788-991 (train function)
  • llava/train/train.py:L860-876 (LoRA setup)
  • llava/train/train.py:L169-182 (find_all_linear_names)
  • llava/train/train.py:L130-160 (LoRA state dict extraction)

Signature

def train(attn_implementation=None) -> None:
    """Main training entry point.

    LoRA-specific CLI args:
        --lora_enable True
        --lora_r 128          (v1.5 default)
        --lora_alpha 256      (v1.5 default)
        --lora_dropout 0.05
        --lora_bias "none"
        --bits 16             (or 4 for QLoRA)
        --model_name_or_path  pre-trained LLaVA checkpoint
        --deepspeed           scripts/zero3.json
        --mm_projector_lr     2e-5 (separate projector LR)
    """

LoRA Setup Code Path (L860-876)

if training_args.lora_enable:
    from peft import LoraConfig, get_peft_model
    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    if training_args.bits == 16:
        if training_args.bf16:
            model.to(torch.bfloat16)
        if training_args.fp16:
            model.to(torch.float16)
    rank0_print("Adding LoRA adapters...")
    model = get_peft_model(model, lora_config)
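
find_all_linear_names (train.py L169-182) chooses the target_modules by collecting the leaf names of every nn.Linear in the language model while skipping the vision side. A close paraphrase (keyword list taken from the source; details may drift across commits):

```python
import torch.nn as nn

def find_all_linear_names(model):
    # Skip the vision side -- only the language model gets LoRA adapters
    multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
    lora_module_names = set()
    for name, module in model.named_modules():
        if any(kw in name for kw in multimodal_keywords):
            continue
        if isinstance(module, nn.Linear):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
```

On a Vicuna-based LLaVA model this typically yields names like q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj.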

LoRA Save Code Path (L974-984)

if training_args.lora_enable:
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), training_args.lora_bias
    )
    non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
        model.named_parameters()
    )
    if training_args.local_rank == 0 or training_args.local_rank == -1:
        model.config.save_pretrained(training_args.output_dir)
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
        torch.save(non_lora_state_dict,
                   os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
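
The two helpers in the save path partition the model's parameters by name. With --lora_bias none the split reduces to roughly the following (a schematic stand-in for get_peft_state_maybe_zero_3 / get_peft_state_non_lora_maybe_zero_3, which additionally gather ZeRO-3 sharded tensors and handle the other lora_bias modes):

```python
def split_peft_state(named_parameters):
    """Partition (name, param) pairs the way the save path does for lora_bias='none'.

    Params are any objects with a .requires_grad attribute
    (torch.nn.Parameter in practice).
    """
    named_parameters = list(named_parameters)  # may be a generator; iterate twice
    lora = {k: v for k, v in named_parameters if "lora_" in k}
    non_lora = {k: v for k, v in named_parameters
                if "lora_" not in k and v.requires_grad}  # e.g. the mm_projector
    return lora, non_lora
```

The first dict becomes the adapter checkpoint; the second is what lands in non_lora_trainables.bin.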

Import

from llava.train.train import train

Inputs

  • Pre-trained LLaVA model -- Base model checkpoint or HuggingFace model ID (--model_name_or_path)
  • Training data JSON -- Conversation-format JSON file (--data_path)
  • Image folder -- Directory containing training images (--image_folder)
  • Vision tower -- CLIP model identifier (--vision_tower)
  • DeepSpeed config -- ZeRO stage configuration (--deepspeed)

Outputs

LoRA adapter checkpoint directory containing:

  • adapter_config.json -- LoRA configuration (rank, alpha, target modules)
  • adapter_model.bin -- LoRA adapter weights (A and B matrices)
  • non_lora_trainables.bin -- Non-LoRA trainable weights (mm_projector)
  • config.json -- Model configuration

Usage Examples

LoRA Training (v1.5, 13B)

deepspeed llava/train/train_mem.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path lmsys/vicuna-13b-v1.5 \
    --version v1 \
    --data_path ./playground/data/llava_v1_5_mix665k.json \
    --image_folder ./playground/data \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-13b-lora \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --dataloader_num_workers 4 \
    --report_to wandb
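
The recipe above targets an effective global batch size of 128. Assuming the original 8-GPU setup (an assumption -- the GPU count is not part of the command), the arithmetic is:

```python
per_device_train_batch_size = 16   # from the command above
gradient_accumulation_steps = 1
num_gpus = 8                       # assumed; set by how deepspeed is launched

global_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(global_batch)  # 128
```

With fewer GPUs, raise gradient_accumulation_steps proportionally to keep the global batch at 128.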

QLoRA Training (4-bit)

deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero2.json \
    --lora_enable True \
    --bits 4 \
    --model_name_or_path ./checkpoints/vicuna-v1-3-7b \
    --version v1 \
    --data_path ./playground/data/llava_instruct_80k.json \
    --image_folder /path/to/coco/train2017 \
    --vision_tower openai/clip-vit-large-patch14 \
    --pretrain_mm_mlp_adapter ./checkpoints/llava-vicuna-v1-3-7b-pretrain/mm_projector.bin \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/llava-vicuna-v1-3-7b-finetune_lora \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --learning_rate 2e-5 \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True
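
A rough sense of why --bits 4 helps: weight memory for the 7B base model shrinks by about 4x versus fp16. Back-of-the-envelope only -- activations, the adapters' optimizer state, and quantization overhead (scales, double quantization) are ignored:

```python
params = 7e9                      # ~7B base model parameters
fp16_gb = params * 2 / 1024**3    # 2 bytes per weight
nf4_gb = params * 0.5 / 1024**3   # 4 bits per weight

print(round(fp16_gb, 1), round(nf4_gb, 1))  # 13.0 3.3
```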

Task-Specific LoRA (from pre-trained LLaVA)

deepspeed llava/train/train_mem.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path liuhaotian/llava-v1.5-13b \
    --version v1 \
    --data_path /path/to/custom_data.json \
    --image_folder /path/to/custom_images \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-13b-task-lora \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --learning_rate 2e-4 \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True

Metadata

Field Value
last_updated 2026-02-13 14:00 GMT
source_repo Haotian_liu_LLaVA
commit 799f5f207c89
type Implementation (API Doc)
