Implementation:OpenGVLab InternVL LLaVATrainer
| Knowledge Sources | |
|---|---|
| Domains | Training, Multimodal Models, LLaVA |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Custom HuggingFace Trainer subclass for LLaVA that provides multimodal-aware data sampling and specialized checkpointing for vision-language training.
Description
LLaVATrainer extends the HuggingFace Trainer class with two key customizations:
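Multimodal-aware data sampling: When group_by_modality_length is enabled, the _get_train_sampler() override returns a custom LengthGroupedSampler built from the dataset's modality_lengths. Otherwise it falls back to the parent Trainer's sampler. The sampler calls get_modality_length_grouped_indices(), which splits samples by the sign of their length: positive for multimodal (image+text) samples, negative for language-only ones. It then length-groups each subset with get_length_grouped_indices(). That function shuffles the indices, cuts them into megabatches of world_size x batch_size, and sorts each megabatch by length in descending order. The per-modality megabatches are shuffled together, so every full megabatch contains a single modality. The two trailing megabatches are merged into one final mixed batch. The split_to_even_chunks() helper divides each megabatch into world_size chunks with roughly equal total lengths.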
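The sampling pipeline described above can be sketched in pure Python. This is a minimal sketch that mirrors the described logic, not the repository's code. It makes three assumptions:

- It uses Python's random.Random in place of the generator-based shuffling of the real helpers.
- It drops the merge flag.
- In the single-modality case it groups on absolute lengths.

```python
import random
from typing import List, Sequence


def split_to_even_chunks(indices: List[int], lengths: Sequence[int], num_chunks: int) -> List[List[int]]:
    """Greedily place each index in the chunk with the smallest running length total."""
    if len(indices) % num_chunks != 0:
        # Chunk sizes cannot be equal: fall back to simple striding.
        return [indices[i::num_chunks] for i in range(num_chunks)]
    per_chunk = len(indices) // num_chunks
    chunks: List[List[int]] = [[] for _ in range(num_chunks)]
    totals = [0.0] * num_chunks
    for idx in indices:
        shortest = totals.index(min(totals))
        chunks[shortest].append(idx)
        totals[shortest] += lengths[idx]
        if len(chunks[shortest]) == per_chunk:
            totals[shortest] = float("inf")  # chunk is full; never pick it again
    return chunks


def get_length_grouped_indices(lengths: Sequence[int], batch_size: int, world_size: int,
                               rng: random.Random) -> List[int]:
    """Shuffle, cut into megabatches, sort each by length (descending), split per rank."""
    indices = list(range(len(lengths)))
    rng.shuffle(indices)
    megabatch_size = world_size * batch_size
    result: List[int] = []
    for start in range(0, len(indices), megabatch_size):
        megabatch = sorted(indices[start:start + megabatch_size],
                           key=lambda i: lengths[i], reverse=True)
        for chunk in split_to_even_chunks(megabatch, lengths, world_size):
            result.extend(chunk)
    return result


def get_modality_length_grouped_indices(lengths: Sequence[int], batch_size: int,
                                        world_size: int, rng: random.Random) -> List[int]:
    """Signed lengths: > 0 marks a multimodal sample, < 0 a language-only one."""
    assert all(l != 0 for l in lengths), "zero lengths are not allowed"
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # Single modality: plain length grouping on absolute lengths.
        return get_length_grouped_indices([abs(l) for l in lengths], batch_size, world_size, rng)
    megabatch_size = world_size * batch_size
    per_modality = []
    for keep in (lambda l: l > 0, lambda l: l < 0):
        ids = [i for i, l in enumerate(lengths) if keep(l)]
        sub_lengths = [abs(lengths[i]) for i in ids]
        # Map local positions back to global dataset indices.
        order = [ids[j] for j in get_length_grouped_indices(sub_lengths, batch_size, world_size, rng)]
        per_modality.append([order[k:k + megabatch_size] for k in range(0, len(order), megabatch_size)])
    mm_megabatches, lang_megabatches = per_modality
    # All but the last megabatch of each modality are shuffled together, so each
    # of them holds a single modality; the two leftovers form one mixed batch.
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    rng.shuffle(megabatches)
    megabatches.append(sorted(mm_megabatches[-1] + lang_megabatches[-1]))
    return [i for megabatch in megabatches for i in megabatch]


print(split_to_even_chunks([0, 1, 2, 3], [10, 9, 2, 1], 2))  # [[0, 3], [1, 2]]
```

In the example call, the greedy assignment gives both chunks a total length of 11. Keeping each megabatch single-modality means image and text-only samples are not padded against each other within a step.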
DeepSpeed ZeRO-3-compatible checkpointing: When only the mm_projector adapter is being tuned (tune_mm_mlp_adapter mode), the _save_checkpoint() override saves just the adapter weights instead of a full checkpoint. It collects them with get_mm_adapter_state_maybe_zero_3(), which selects the parameters whose names match the target keys and passes each one through maybe_zero_3(). For ZeRO-3-partitioned parameters, maybe_zero_3() gathers the full tensor inside DeepSpeed's GatheredParameters context manager before copying it. Only mm_projector, vision_resampler, embed_tokens, and optionally the vision tower's position embeddings are saved, as mm_projector.bin. In the same adapter-only mode, the _save() override is a no-op, so the default full-model save is skipped.
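The parameter selection performed by get_mm_adapter_state_maybe_zero_3() can be sketched without DeepSpeed. This minimal sketch assumes that a parameter is kept when its name contains any of the keys (substring matching). The function name select_adapter_state is hypothetical. So is the gather callback, which stands in for maybe_zero_3(). The parameter names in the example are also hypothetical.

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple


def select_adapter_state(
    named_params: Iterable[Tuple[str, Any]],
    keys_to_match: List[str],
    gather: Callable[[Any, str], Any] = lambda param, name: param,
) -> Dict[str, Any]:
    """Keep parameters whose name contains any of keys_to_match, then gather each.

    Hypothetical helper. `gather` stands in for maybe_zero_3(): under ZeRO-3 it
    would collect the full tensor from its partitions; here the default simply
    returns the value unchanged.
    """
    selected = {name: p for name, p in named_params if any(key in name for key in keys_to_match)}
    return {name: gather(p, name) for name, p in selected.items()}


# Hypothetical parameter names; strings stand in for tensors.
params = [
    ("model.mm_projector.0.weight", "W0"),
    ("model.layers.0.self_attn.q_proj.weight", "Wq"),
    ("model.embed_tokens.weight", "E"),
    ("model.vision_tower.encoder.layer.0.weight", "V"),
]
state = select_adapter_state(params, ["mm_projector", "vision_resampler", "embed_tokens"])
print(sorted(state))  # ['model.embed_tokens.weight', 'model.mm_projector.0.weight']
```

Because matching is by substring, a key such as embed_tokens selects every parameter whose name contains it. The exact key list the trainer builds is not reproduced here.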
Usage
Use this trainer when training LLaVA models on datasets that mix multimodal and language-only samples. It is most useful when group_by_modality_length is enabled, or when tuning only the mm_projector under DeepSpeed ZeRO-3.
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat_llava/llava/train/llava_trainer.py
- Lines: 1-180
Signature
def maybe_zero_3(param, ignore_status=False, name=None): ...
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): ...
def split_to_even_chunks(indices, lengths, num_chunks): ...
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): ...
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): ...
class LengthGroupedSampler(Sampler):
    def __init__(self, batch_size, world_size, lengths=None, generator=None,
                 group_by_modality=False): ...
class LLaVATrainer(Trainer):
    def _get_train_sampler(self): ...
    def _save_checkpoint(self, model, trial, metrics=None): ...
    def _save(self, output_dir=None, state_dict=None): ...
Import
from llava.train.llava_trainer import LLaVATrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
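| train_dataset | Dataset | Yes | Training dataset; must expose a modality_lengths attribute when group_by_modality_length is enabled (positive lengths for multimodal samples, negative for language-only) |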
| args.group_by_modality_length | bool | No | Whether to group samples by modality and length |
| args.tune_mm_mlp_adapter | bool | No | Whether only the mm_projector is being fine-tuned |
| args.tune_vit_pos_embedding | bool | No | Whether to also save ViT position embeddings |
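The modality_lengths contract above can be illustrated with a minimal sketch of how signed lengths might be computed. It follows the convention used by upstream LLaVA's training dataset: a whitespace word count over all conversation turns, negated for samples without an image. The sketch assumes LLaVA-style records with image and conversations keys. The helper name compute_modality_lengths and the image path are hypothetical, and the InternVL copy may differ.

```python
from typing import Any, Dict, List


def compute_modality_lengths(samples: List[Dict[str, Any]]) -> List[int]:
    """Signed length per sample: positive if it has an image, negative otherwise.

    Hypothetical helper modeled on upstream LLaVA's modality_lengths property.
    """
    lengths = []
    for sample in samples:
        # Whitespace word count over every turn: a cheap proxy for token length.
        cur_len = sum(len(turn["value"].split()) for turn in sample["conversations"])
        lengths.append(cur_len if "image" in sample else -cur_len)
    return lengths


samples = [
    {"image": "images/0001.jpg",  # hypothetical path
     "conversations": [{"from": "human", "value": "<image>\nWhat is shown?"},
                       {"from": "gpt", "value": "A red bus."}]},
    {"conversations": [{"from": "human", "value": "Say hi."},
                       {"from": "gpt", "value": "Hi there!"}]},
]
print(compute_modality_lengths(samples))  # [7, -4]
```

A sample whose turns are all empty would get length 0. The upstream sampler rejects zero lengths with an assertion.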
Outputs
| Name | Type | Description |
|---|---|---|
| checkpoints | files | Full model checkpoints, or only mm_projector.bin adapter weights when tune_mm_mlp_adapter is enabled |
Usage Examples
Basic Usage
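from llava.train.llava_trainer import LLaVATrainer
# model, training_args, train_dataset and data_collator are assumed to be built
# earlier in the training script; training_args is expected to define the custom
# fields the trainer reads (group_by_modality_length, tune_mm_mlp_adapter, ...).
trainer = LLaVATrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.train()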