Implementation:OpenGVLab InternVL MultimodalDPOTrainer

From Leeroopedia


Knowledge Sources
Domains: Alignment, Training, Vision_Language
Last Updated: 2026-02-07 00:00 GMT

Overview

A concrete tool for preference-optimization training of multimodal vision-language models, provided by the InternVL training framework.

Description

The MultimodalDPOTrainer extends TRL's DPOTrainer to handle multimodal inputs (images + text). Key modifications:

  • concatenated_inputs: adds pixel_values and image_flags handling, repeating them 2x to match the chosen+rejected concatenation
  • concatenated_forward: passes pixel_values/image_flags through the model forward pass and computes an NLL cross-entropy loss on the chosen tokens
  • get_batch_loss_metrics: supports composite loss types via a comma-split loss_type (e.g., 'sigmoid,bco_pair') with per-component weights, and mixes in the NLL loss via rpo_alpha (see the sketch after this list)
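
The composite loss can be pictured with a minimal, self-contained sketch. This is illustrative only: the real trainer reuses TRL's per-loss-type code, and the loss_weights, default values, and simplified bco_pair term (which omits the running reward mean) are assumptions made here for clarity.

import torch
import torch.nn.functional as F

def composite_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                       ref_chosen_logps, ref_rejected_logps,
                       nll_loss, loss_type='sigmoid,bco_pair',
                       loss_weights=(0.8, 0.2), beta=0.1, rpo_alpha=1.0):
    # Log-ratio advantage of chosen over rejected, relative to the reference.
    logits = ((policy_chosen_logps - policy_rejected_logps)
              - (ref_chosen_logps - ref_rejected_logps))
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    losses = torch.zeros_like(logits)
    # Comma-split loss_type: each component is weighted and summed.
    for lt, w in zip(loss_type.split(','), loss_weights):
        if lt == 'sigmoid':     # standard DPO sigmoid loss
            losses = losses + w * -F.logsigmoid(beta * logits)
        elif lt == 'bco_pair':  # BCO pair loss (running reward mean omitted)
            losses = losses + w * (-F.logsigmoid(chosen_rewards)
                                   - F.logsigmoid(-rejected_rewards))
    # Mix in the NLL (cross-entropy on chosen tokens) via rpo_alpha.
    return losses.mean() + rpo_alpha * nll_loss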

Usage

Used by the MPO training script (internvl_chat_mpo.py). It requires a policy model and a frozen reference model, both loaded from the same checkpoint; a minimal loading sketch follows.
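
A minimal loading sketch, assuming the repository's internvl.model.internvl_chat module layout; the checkpoint name is illustrative, and real scripts pass additional from_pretrained arguments (dtype, flash-attention flags) that are omitted here.

from internvl.model.internvl_chat import InternVLChatModel

checkpoint = 'OpenGVLab/InternVL2-8B'  # illustrative checkpoint name

# Policy model: trainable.
model = InternVLChatModel.from_pretrained(checkpoint)

# Reference model: same weights, frozen for the duration of training.
ref_model = InternVLChatModel.from_pretrained(checkpoint)
ref_model.requires_grad_(False)
ref_model.eval()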

Code Reference

Source Location

  • Repository: InternVL
  • File: internvl_chat/internvl/train/trainer_dpo.py
  • Lines: L25-302

Signature

from typing import Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
from trl import DPOTrainer

class MultimodalDPOTrainer(DPOTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        is_vision_model: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Concatenate chosen and rejected inputs with multimodal data."""

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Forward pass computing log-probs for both chosen and rejected."""

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal['train', 'eval'] = 'train',
    ):
        """Compute composite loss with multiple loss types and NLL mixing."""

Import

from internvl.train.trainer_dpo import MultimodalDPOTrainer

I/O Contract

Inputs

Name           Type               Required  Description
model          InternVLChatModel  Yes       Policy model (trainable)
ref_model      InternVLChatModel  Yes       Reference model (frozen)
args           DPOConfig          Yes       Training config with loss_type, beta, rpo_alpha
train_dataset  Dataset            Yes       DPO dataset with chosen/rejected pairs
data_collator  Callable           Yes       dpo_concat_pad_data_collator

Outputs

Name         Type          Description
loss         torch.Tensor  Composite loss: w_sigmoid * DPO + w_bco * BCO + rpo_alpha * NLL
metrics      Dict          chosen_rewards, rejected_rewards, per-component losses
checkpoints  Files         Model checkpoints at configured intervals

Usage Examples

MPO Training Setup

from internvl.train.trainer_dpo import MultimodalDPOTrainer
from internvl.patch.pad_data_collator import dpo_concat_pad_data_collator

trainer = MultimodalDPOTrainer(
    model=model,                # policy model (trainable)
    ref_model=ref_model,        # frozen reference model
    args=dpo_config,            # DPOConfig with loss_type, beta, rpo_alpha
    train_dataset=dpo_dataset,  # preference dataset of chosen/rejected pairs
    tokenizer=tokenizer,
    data_collator=dpo_concat_pad_data_collator,  # pads and concatenates multimodal batches
)

trainer.train()
trainer.save_model()
