Implementation:OpenGVLab InternVL LLaVA Training Script
| Knowledge Sources | |
|---|---|
| Domains | Model Training, Multimodal Learning, LLaVA |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
This module is the main training script for LLaVA models, handling both pretraining and supervised fine-tuning stages with support for LoRA, quantization, and DeepSpeed.
Description
The train.py file is the primary training entry point for all standard LLaVA configurations in the InternVL repository. It defines three core dataclass configurations:
- ModelArguments: Controls model path, version, vision tower, projector type, and freezing options (freeze_backbone, tune_mm_mlp_adapter, tune_vit_pos_embedding)
- DataArguments: Specifies data path, image folder, aspect ratio handling, and grid pinpoints
- TrainingArguments: Extends HuggingFace TrainingArguments with LoRA parameters (lora_r, lora_alpha, lora_dropout), quantization settings (bits, quant_type, double_quant), and modality grouping
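As a rough illustration of how three argument dataclasses like these can be filled from one command line, here is a stdlib-only sketch. Scripts of this kind typically rely on `transformers.HfArgumentParser` for this step; the helper `parse_into_dataclasses` below is a hypothetical stand-in, not code from train.py. The dataclasses are trimmed to a few fields, and `TrainingArguments` is simplified so that it does not extend the HuggingFace class.

```python
import argparse
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    vision_tower: Optional[str] = field(default=None)
    tune_mm_mlp_adapter: bool = field(default=False)


@dataclass
class DataArguments:
    data_path: Optional[str] = field(default=None)
    image_folder: Optional[str] = field(default=None)


@dataclass
class TrainingArguments:
    # Simplified for the sketch: the real class extends transformers.TrainingArguments.
    output_dir: Optional[str] = field(default=None)
    lora_enable: bool = False
    bits: int = 16


def _to_bool(text: str) -> bool:
    """Interpret flag values such as 'True' or 'false' as booleans."""
    return text.lower() in ("1", "true", "yes")


def parse_into_dataclasses(argv, *classes):
    """Hypothetical helper: register every dataclass field as a --flag, then rebuild the dataclasses."""
    parser = argparse.ArgumentParser()
    for cls in classes:
        for f in fields(cls):
            # bool must be checked before int, because bool is a subclass of int.
            if isinstance(f.default, bool):
                conv = _to_bool
            elif isinstance(f.default, int):
                conv = int
            else:
                conv = str
            parser.add_argument(f"--{f.name}", type=conv, default=f.default)
    ns = vars(parser.parse_args(argv))
    return tuple(cls(**{f.name: ns[f.name] for f in fields(cls)}) for cls in classes)
```

Calling `parse_into_dataclasses(["--lora_enable", "True", "--bits", "4"], ModelArguments, DataArguments, TrainingArguments)` returns one instance per class, with unspecified fields left at their defaults.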
The LazySupervisedDataset loads conversation-format JSON data lazily: images are processed on the fly, optionally padded to a square, and conversations are tokenized with system and user turns masked by IGNORE_INDEX so that only assistant responses contribute to the loss. Separate preprocessing functions handle the different conversation formats: preprocess_v1 (Vicuna-style, two separators), preprocess_llama_2 (LLaMA-2, with the [/INST] separator), preprocess_mpt (MPT-style), and preprocess_plain (simple image-caption pairs).
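The label masking described above can be sketched in a few lines. This is a minimal illustration, not the script's preprocessing code: the function name `build_inputs_and_labels`, the role strings, and the pre-tokenized turns are assumptions made for the example. The value -100 is assumed for IGNORE_INDEX because it is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
from typing import List, Tuple

# Assumed value: matches the default ignore_index of PyTorch's cross-entropy loss.
IGNORE_INDEX = -100


def build_inputs_and_labels(
    turns: List[Tuple[str, List[int]]]
) -> Tuple[List[int], List[int]]:
    """Concatenate tokenized turns and mask every non-assistant token in the labels.

    `turns` is a list of (role, token_ids) pairs; the role names are hypothetical.
    """
    input_ids: List[int] = []
    labels: List[int] = []
    for role, token_ids in turns:
        input_ids.extend(token_ids)
        if role == "assistant":
            # Assistant tokens are supervised: the label equals the input token.
            labels.extend(token_ids)
        else:
            # System and user tokens are fed to the model but excluded from the loss.
            labels.extend([IGNORE_INDEX] * len(token_ids))
    return input_ids, labels
```

The two returned lists always have the same length, so they can be converted to tensors and passed to a causal language model as `input_ids` and `labels`.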
The train() function orchestrates the full pipeline: it parses arguments, loads the model with optional BitsAndBytes quantization (4-bit or 8-bit), injects LoRA adapters via PEFT, initializes the vision modules, applies the configured parameter-freezing strategy, and launches training via LLaVATrainer. Saving goes through safe_save_model_for_hf_trainer, which gathers DeepSpeed ZeRO-3 partitioned parameters before writing the checkpoint.
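To make the parameter-freezing step more concrete, the sketch below decides which parameters would be trainable from the two flags. It is a simplified reading of the flag descriptions (tune_mm_mlp_adapter tunes only the projector, freeze_backbone freezes the language model), not the logic of train.py. The function `plan_trainable` and the parameter name prefixes `language_model.` and `mm_projector.` are hypothetical, and the vision tower is deliberately left out.

```python
from typing import Dict, Iterable


def plan_trainable(
    param_names: Iterable[str],
    freeze_backbone: bool = False,
    tune_mm_mlp_adapter: bool = False,
) -> Dict[str, bool]:
    """Map each parameter name to True (trainable) or False (frozen).

    Name prefixes are assumptions for illustration only.
    """
    plan: Dict[str, bool] = {}
    for name in param_names:
        is_projector = name.startswith("mm_projector.")
        is_backbone = name.startswith("language_model.")
        if tune_mm_mlp_adapter:
            # Adapter-only mode: everything except the projector is frozen.
            trainable = is_projector
        elif freeze_backbone:
            # Backbone frozen; the remaining parameters keep training.
            trainable = not is_backbone
        else:
            # Full fine-tuning of the parameters modeled here.
            trainable = True
        plan[name] = trainable
    return plan
```

With a real `torch.nn.Module`, a plan like this would be applied by setting `requires_grad` on each entry of `model.named_parameters()`.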
Usage
Use this script to run either LLaVA training stage, pretraining or supervised fine-tuning, on standard or custom datasets. It supports full training, LoRA fine-tuning, adapter-only pretraining, and quantized training.
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat_llava/llava/train/train.py
- Lines: 1-993
Signature
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=False)
    tune_mm_mlp_adapter: bool = field(default=False)
    vision_tower: Optional[str] = field(default=None)

@dataclass
class DataArguments:
    data_path: str = field(default=None)
    lazy_preprocess: bool = False
    image_folder: Optional[str] = field(default=None)

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    lora_enable: bool = False
    bits: int = field(default=16)

class LazySupervisedDataset(Dataset):
    def __init__(self, data_path, tokenizer, data_args): ...
    def __getitem__(self, i) -> Dict[str, torch.Tensor]: ...

def train(attn_implementation=None): ...
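The lazy-loading pattern behind LazySupervisedDataset can be illustrated without PyTorch. In the sketch below the JSON index is read once in `__init__`, while per-sample work is deferred to `__getitem__`. The class name `LazyConversationDataset`, the record keys (`id`, `image`, `conversations`, `from`, `value`), and the whitespace split standing in for tokenization are all assumptions for the example, not the script's actual implementation.

```python
import json
from typing import Dict, List


class LazyConversationDataset:
    """Reads the JSON record list up front, but processes each sample only on access."""

    def __init__(self, data_path: str):
        with open(data_path, "r", encoding="utf-8") as f:
            # Only the lightweight record list is held in memory.
            self.records: List[dict] = json.load(f)

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, i: int) -> Dict[str, object]:
        record = self.records[i]
        # Stand-in for tokenization: split each turn on whitespace.
        turns = [(t["from"], t["value"].split()) for t in record["conversations"]]
        return {
            "id": record.get("id"),
            # Image decoding would also happen here, only for the requested sample.
            "has_image": "image" in record,
            "turns": turns,
        }
```

Deferring the work this way keeps startup fast and memory use low on large datasets, at the cost of repeating the processing every epoch.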
Import
from llava.train.train import train, ModelArguments, DataArguments, TrainingArguments
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_name_or_path | str | Yes | HuggingFace model path or local checkpoint |
| data_path | str | Yes | Path to training data JSON file |
| vision_tower | str | No | Vision encoder model name/path |
| output_dir | str | Yes | Directory to save trained model |
| lora_enable | bool | No | Enable LoRA adapter training (default: False) |
| bits | int | No | Quantization bits: 4, 8, or 16 (default: 16) |
| tune_mm_mlp_adapter | bool | No | Only tune the multimodal projector (default: False) |
| freeze_backbone | bool | No | Freeze the language model backbone (default: False) |
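As a hedged sketch of how the `bits` input could translate into quantization settings, the helper below returns a plain dictionary. Its keys mirror keyword arguments of `transformers.BitsAndBytesConfig`, but the function `quantization_kwargs` itself, the `"nf4"` and `True` defaults, and the error on unsupported values are assumptions for illustration rather than behavior confirmed for train.py.

```python
from typing import Any, Dict


def quantization_kwargs(
    bits: int, quant_type: str = "nf4", double_quant: bool = True
) -> Dict[str, Any]:
    """Return BitsAndBytesConfig-style keyword arguments for a given bit width."""
    if bits not in (4, 8, 16):
        # Assumption: only the three values listed in the inputs table are accepted.
        raise ValueError(f"bits must be 4, 8, or 16, got {bits}")
    if bits == 16:
        # No quantization requested: nothing to configure.
        return {}
    return {
        "load_in_4bit": bits == 4,
        "load_in_8bit": bits == 8,
        # The two 4-bit options below only take effect when load_in_4bit is True.
        "bnb_4bit_quant_type": quant_type,
        "bnb_4bit_use_double_quant": double_quant,
    }
```

With transformers and bitsandbytes installed, a non-empty result could be passed as `BitsAndBytesConfig(**quantization_kwargs(4))`.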
Outputs
| Name | Type | Description |
|---|---|---|
| Model checkpoint | Directory | Saved model weights, config, and tokenizer |
| LoRA weights | .bin files | LoRA adapter state dict and non-LoRA trainables (when lora_enable=True) |
| mm_projector.bin | .bin file | Multimodal projector weights (when tune_mm_mlp_adapter=True) |
Usage Examples
Basic Usage
# Full fine-tuning
# deepspeed llava/train/train.py \
# --model_name_or_path lmsys/vicuna-7b-v1.5 \
# --vision_tower openai/clip-vit-large-patch14-336 \
# --data_path ./data/llava_instruct_150k.json \
# --output_dir ./checkpoints/llava-v1.5-7b \
# --bf16 True --deepspeed ./scripts/zero3.json
# LoRA fine-tuning
# deepspeed llava/train/train.py \
# --lora_enable True --lora_r 128 --lora_alpha 256 \
# --model_name_or_path lmsys/vicuna-7b-v1.5 \
# --data_path ./data/llava_instruct_150k.json \
# --output_dir ./checkpoints/llava-v1.5-7b-lora
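When launches are driven from Python rather than a shell, the same command can be assembled as an argument list. The sketch below rebuilds the LoRA example above; `build_command` is a hypothetical helper, and it only constructs and prints the command without running it.

```python
import shlex
from typing import Dict, List


def build_command(script: str, flags: Dict[str, object]) -> List[str]:
    """Turn a flag dictionary into a deepspeed argument list (hypothetical helper)."""
    cmd = ["deepspeed", script]
    for name, value in flags.items():
        # Every flag is emitted as "--name value", matching the shell examples.
        cmd += [f"--{name}", str(value)]
    return cmd


cmd = build_command(
    "llava/train/train.py",
    {
        "lora_enable": True,
        "lora_r": 128,
        "lora_alpha": 256,
        "model_name_or_path": "lmsys/vicuna-7b-v1.5",
        "data_path": "./data/llava_instruct_150k.json",
        "output_dir": "./checkpoints/llava-v1.5-7b-lora",
    },
)
# Print a shell-safe rendering of the command for inspection.
print(shlex.join(cmd))
```

To actually start training, the list could be handed to `subprocess.run(cmd, check=True)` from the repository directory that contains the script.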