Implementation:OpenGVLab InternVL LLaVATrainer

From Leeroopedia
Revision as of 16:14, 16 February 2026 by Admin (auto-imported from implementations/OpenGVLab_InternVL_LLaVATrainer.md)


Knowledge Sources
Domains: Training, Multimodal Models, LLaVA
Last Updated: 2026-02-07 14:00 GMT

Overview

Custom HuggingFace Trainer subclass for LLaVA that provides multimodal-aware data sampling and specialized checkpointing for vision-language training.

Description

LLaVATrainer extends the HuggingFace Trainer class with two key customizations:

Multimodal-aware data sampling: When group_by_modality_length is enabled, the _get_train_sampler() override returns a custom LengthGroupedSampler. Through get_modality_length_grouped_indices(), the sampler separates multimodal (image+text) samples from language-only samples, groups each set by length, and interleaves them in megabatches of world_size × batch_size samples, so that samples batched together have similar lengths and padding is reduced. The split_to_even_chunks() helper distributes indices across chunks while balancing the total length of each chunk.
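The grouping idea can be sketched in plain Python. This is a minimal illustration, not the code in llava_trainer.py: the helper names split_by_modality and balance_into_chunks are hypothetical, and the convention that a positive length marks a multimodal sample and a negative length a language-only one is an assumption that this page does not state.

```python
from typing import List, Sequence, Tuple


def split_by_modality(lengths: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Split sample indices by modality.

    Assumed convention (not confirmed by this page): a positive length
    marks a multimodal sample, a negative length a language-only sample.
    """
    mm = [i for i, n in enumerate(lengths) if n > 0]
    lang = [i for i, n in enumerate(lengths) if n < 0]
    return mm, lang


def balance_into_chunks(
    indices: Sequence[int], lengths: Sequence[int], num_chunks: int
) -> List[List[int]]:
    """Greedy balancing: visit samples longest-first and place each one
    into the chunk whose running total length is currently smallest."""
    chunks: List[List[int]] = [[] for _ in range(num_chunks)]
    totals = [0] * num_chunks
    for idx in sorted(indices, key=lambda i: abs(lengths[i]), reverse=True):
        target = totals.index(min(totals))
        chunks[target].append(idx)
        totals[target] += abs(lengths[idx])
    return chunks


# Toy data: three multimodal samples and three language-only samples.
lengths = [10, -3, 7, -8, 2, -1]
mm_indices, lang_indices = split_by_modality(lengths)
mm_chunks = balance_into_chunks(mm_indices, lengths, num_chunks=2)
```

With the toy lengths above, mm_indices is [0, 2, 4] and mm_chunks is [[0], [2, 4]], giving chunk totals of 10 and 9. Unlike a production sampler, this sketch does not force every chunk to hold the same number of samples and does no shuffling.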

DeepSpeed ZeRO-3 compatible checkpointing: The _save_checkpoint() override handles the case where only the mm_projector adapter is tuned (tune_mm_mlp_adapter mode). It calls get_mm_adapter_state_maybe_zero_3(), which gathers each matching parameter from its ZeRO-3 partitions through maybe_zero_3(); that helper in turn uses DeepSpeed's GatheredParameters context manager. Only the mm_projector, vision_resampler, and embed_tokens weights are saved, plus the vision tower position embeddings when tune_vit_pos_embedding is set. The _save() method is likewise overridden to be a no-op in adapter-only mode.
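The name-based selection of adapter weights can be sketched as follows. This is a simplified stand-in, not the trainer's implementation: select_adapter_state and the parameter names are hypothetical, plain strings stand in for tensors, and the gather argument defaults to an identity function in place of the ZeRO-3 gathering that maybe_zero_3() performs in the real code.

```python
from typing import Any, Callable, Dict, Iterable, Sequence, Tuple


def select_adapter_state(
    named_params: Iterable[Tuple[str, Any]],
    keys_to_match: Sequence[str],
    gather: Callable[[Any], Any] = lambda p: p,
) -> Dict[str, Any]:
    """Keep only parameters whose name contains one of keys_to_match.

    `gather` is a placeholder for the step that would collect a
    partitioned parameter onto one process before it is saved.
    """
    return {
        name: gather(param)
        for name, param in named_params
        if any(key in name for key in keys_to_match)
    }


# Hypothetical parameter names; strings stand in for weight tensors.
named_params = [
    ("model.mm_projector.0.weight", "proj_w"),
    ("model.layers.0.self_attn.q_proj.weight", "attn_w"),
    ("model.embed_tokens.weight", "embed_w"),
]
keys = ["mm_projector", "vision_resampler", "embed_tokens"]
adapter_state = select_adapter_state(named_params, keys)
```

Here adapter_state keeps the mm_projector and embed_tokens entries and drops the attention weight, since matching is a substring test on the parameter name.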

Usage

Use this trainer when training LLaVA models with multimodal datasets, especially when using DeepSpeed ZeRO-3 and modality-aware batching.

Code Reference

Source Location

Signature

def maybe_zero_3(param, ignore_status=False, name=None): ...
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): ...
def split_to_even_chunks(indices, lengths, num_chunks): ...
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): ...
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): ...

class LengthGroupedSampler(Sampler):
    def __init__(self, batch_size, world_size, lengths=None, generator=None,
                 group_by_modality=False): ...

class LLaVATrainer(Trainer):
    def _get_train_sampler(self): ...
    def _save_checkpoint(self, model, trial, metrics=None): ...
    def _save(self, output_dir=None, state_dict=None): ...

Import

from llava.train.llava_trainer import LLaVATrainer

I/O Contract

Inputs

Name                           Type     Required  Description
train_dataset                  Dataset  Yes       Training dataset with modality_lengths attribute for grouped sampling
args.group_by_modality_length  bool     No        Whether to group samples by modality and length
args.tune_mm_mlp_adapter       bool     No        Whether only the mm_projector is being fine-tuned
args.tune_vit_pos_embedding    bool     No        Whether to also save ViT position embeddings
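This page does not specify how modality_lengths is encoded. The sketch below shows one way a dataset might expose it, under loudly stated assumptions: the class ToyMultimodalDataset and its "image" and "text" fields are hypothetical, the word count is used as a rough length proxy, and the sign convention (positive for multimodal, negative for language-only) is assumed rather than confirmed here.

```python
from typing import Dict, List


class ToyMultimodalDataset:
    """Hypothetical dataset exposing a modality_lengths attribute."""

    def __init__(self, samples: List[Dict[str, str]]):
        self.samples = samples

    def __len__(self) -> int:
        return len(self.samples)

    @property
    def modality_lengths(self) -> List[int]:
        lengths = []
        for sample in self.samples:
            # Word count as a rough proxy for tokenized length (assumption).
            n_words = len(sample["text"].split())
            # Assumed convention: positive = has an image, negative = text only.
            lengths.append(n_words if "image" in sample else -n_words)
        return lengths


dataset = ToyMultimodalDataset([
    {"image": "cat.jpg", "text": "describe the picture in detail"},
    {"text": "what is two plus two"},
])
```

For these two samples, dataset.modality_lengths evaluates to [5, -5]: both texts have five words, and only the first sample carries an image.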

Outputs

Name         Type   Description
checkpoints  files  Model checkpoints or mm_projector.bin adapter weights

Usage Examples

Basic Usage

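from llava.train.llava_trainer import LLaVATrainer

# model, training_args, train_dataset, and data_collator are assumed to be
# constructed beforehand, as for a standard HuggingFace Trainer.
trainer = LLaVATrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # needs modality_lengths for grouped sampling
    data_collator=data_collator,
)
trainer.train()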
