Implementation: Haotian Liu LLaVA Training Arguments LoRA
Overview
Concrete tool for configuring LoRA and QLoRA training parameters in LLaVA's TrainingArguments dataclass.
Type
API Doc
Description
LLaVA extends HuggingFace's transformers.TrainingArguments with LoRA-specific fields. These fields control whether LoRA is enabled, the rank, alpha scaling factor, dropout rate, and quantization bits. The find_all_linear_names() helper auto-detects target modules by scanning the model's named modules for nn.Linear layers while excluding multimodal components.
When lora_enable=True, the training script creates a LoraConfig using these parameters and applies it via get_peft_model(). When bits=4 or bits=8, quantization is applied to the base model before LoRA adapter injection.
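A minimal sketch of that flow, assuming peft is installed; model and training_args stand in for the already-loaded LLaVA model and the parsed arguments, and the exact surrounding code in train.py may differ:
from peft import LoraConfig, get_peft_model
from llava.train.train import find_all_linear_names

# Sketch: build a LoraConfig from the dataclass fields and wrap the model.
if training_args.lora_enable:
    lora_config = LoraConfig(
        r=training_args.lora_r,                       # rank of the low-rank update
        lora_alpha=training_args.lora_alpha,          # scaling factor
        target_modules=find_all_linear_names(model),  # auto-detected nn.Linear names
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)        # injects LoRA adapters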
Source
llava/train/train.py:L79-113 (TrainingArguments dataclass)
llava/train/train.py:L169-182 (find_all_linear_names function)
Signature
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(default=512)
    double_quant: bool = field(default=True)
    quant_type: str = field(default="nf4")
    bits: int = field(default=16)
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)

def find_all_linear_names(model) -> List[str]:
    """Returns names of all nn.Linear modules except mm_projector,
    vision_tower, vision_resampler, and lm_head."""
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
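For illustration, a toy module (hypothetical, not from the repo) makes the exclusion logic concrete; on a real LLaMA-based LLaVA checkpoint the result is typically the attention and MLP projections such as q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, and down_proj:
import torch.nn as nn
from llava.train.train import find_all_linear_names

# Toy module for illustration only: mimics the naming pattern of a LLaVA
# model so the exclusion rules are visible.
class ToyLLaVA(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)        # kept: plain language-model linear
        self.mm_projector = nn.Linear(8, 8)  # skipped: multimodal keyword
        self.vision_tower = nn.Linear(8, 8)  # skipped: multimodal keyword
        self.lm_head = nn.Linear(8, 8)       # removed at the end

print(find_all_linear_names(ToyLLaVA()))     # -> ['q_proj']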
Import
from llava.train.train import TrainingArguments
from llava.train.train import find_all_linear_names
Parameter Reference
| Parameter | Type | Default | Description |
|---|---|---|---|
| lora_enable | bool | False | Enable LoRA adapter injection |
| lora_r | int | 64 | LoRA rank (v1.5 uses 128) |
| lora_alpha | int | 16 | LoRA alpha scaling (v1.5 uses 256) |
| lora_dropout | float | 0.05 | Dropout applied to LoRA layers |
| lora_weight_path | str | "" | Path to pre-trained LoRA weights (for resume) |
| lora_bias | str | "none" | Bias training strategy: "none", "all", or "lora_only" |
| bits | int | 16 | Quantization bits: 4 (QLoRA), 8 (8-bit), 16 (full precision) |
| mm_projector_lr | Optional[float] | None | Separate learning rate for mm_projector |
| double_quant | bool | True | Enable double quantization (QLoRA) |
| quant_type | str | "nf4" | Quantization type: "nf4" or "fp4" |
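The quantization fields (bits, double_quant, quant_type) feed a bitsandbytes configuration applied when the base model is loaded. A minimal sketch, assuming transformers with bitsandbytes support; the compute dtype and skip-module choices here are assumptions, and the exact options in train.py may differ:
import torch
from transformers import BitsAndBytesConfig

# Sketch: derive a quantization config from the parsed training_args.
quant_config = None
if training_args.bits in (4, 8):
    quant_config = BitsAndBytesConfig(
        load_in_4bit=training_args.bits == 4,
        load_in_8bit=training_args.bits == 8,
        llm_int8_skip_modules=["mm_projector"],                # assumption: keep projector unquantized
        bnb_4bit_compute_dtype=torch.bfloat16,                 # assumption: bf16 compute
        bnb_4bit_use_double_quant=training_args.double_quant,  # QLoRA double quantization
        bnb_4bit_quant_type=training_args.quant_type,          # "nf4" or "fp4"
    )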
Inputs
CLI arguments or dataclass field values passed via HuggingFace's HfArgumentParser.
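A simplified sketch of that parsing step; ModelArguments and DataArguments are the sibling dataclasses defined in the same file, and the variable names are illustrative:
import transformers
from llava.train.train import ModelArguments, DataArguments, TrainingArguments

# Parse --lora_enable, --lora_r, --bits, ... from the command line into dataclasses.
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()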
Outputs
Configured TrainingArguments namespace consumed by the training loop.
Usage Examples
LoRA Fine-tuning (v1.5 defaults)
From scripts/v1_5/finetune_lora.sh:
deepspeed llava/train/train_mem.py \
--lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
--deepspeed ./scripts/zero3.json \
--model_name_or_path lmsys/vicuna-13b-v1.5 \
--version v1 \
--data_path ./playground/data/llava_v1_5_mix665k.json \
--image_folder ./playground/data \
--vision_tower openai/clip-vit-large-patch14-336 \
--bf16 True \
--output_dir ./checkpoints/llava-v1.5-13b-lora \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--learning_rate 2e-4 \
--model_max_length 2048 \
--gradient_checkpointing True
QLoRA Fine-tuning (4-bit)
From scripts/finetune_qlora.sh:
deepspeed llava/train/train_mem.py \
--lora_enable True \
--bits 4 \
--deepspeed ./scripts/zero2.json \
--model_name_or_path ./checkpoints/vicuna-v1-3-7b \
--version v1 \
--data_path ./playground/data/llava_instruct_80k.json \
--image_folder /path/to/coco/train2017 \
--vision_tower openai/clip-vit-large-patch14 \
--bf16 True \
--output_dir ./checkpoints/llava-vicuna-v1-3-7b-finetune_lora \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--learning_rate 2e-5
Metadata
| Field | Value |
|---|---|
| last_updated | 2026-02-13 14:00 GMT |
| source_repo | Haotian_liu_LLaVA |
| commit | 799f5f207c89 |
| type | Implementation (API Doc) |