Implementation: OpenGVLab InternVL MultimodalDPOTrainer
| Field | Value |
|---|---|
| Knowledge Sources | |
| Domains | Alignment, Training, Vision_Language |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
A concrete tool for preference-optimization training of multimodal vision-language models, provided by the InternVL training framework.
Description
The MultimodalDPOTrainer extends TRL's DPOTrainer to handle multimodal inputs (images + text). Key modifications:
- concatenated_inputs: adds pixel_values and image_flags handling, repeating each tensor 2x so it stays aligned with the concatenated chosen+rejected batch (see the sketch after this list)
- concatenated_forward: passes pixel_values/image_flags through the model forward pass and computes an NLL cross-entropy loss on the chosen tokens
- get_batch_loss_metrics: supports composite loss types via a comma-split loss_type string (e.g., 'sigmoid,bco_pair') with per-component weights, and mixes in the NLL loss scaled by rpo_alpha
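A minimal sketch of the 2x repetition described in the first bullet, assuming simplified tensor shapes; concat_multimodal and the batch keys here are illustrative, not the InternVL source:

```python
# Minimal sketch (not the InternVL source): chosen and rejected sequences are
# stacked along the batch dimension, so the shared vision tensors must be
# repeated 2x to stay aligned with the doubled batch.
import torch

def concat_multimodal(batch: dict) -> dict:
    # Assumes chosen/rejected input_ids were already padded to a common length.
    return {
        'concatenated_input_ids': torch.cat(
            [batch['chosen_input_ids'], batch['rejected_input_ids']], dim=0
        ),
        # One copy of the image tiles per response in the doubled batch.
        'pixel_values': batch['pixel_values'].repeat(2, 1, 1, 1),
        'image_flags': batch['image_flags'].repeat(2),
    }
```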
Usage
Used in the MPO training script (internvl_chat_mpo.py). Requires a policy model and a frozen reference model loaded from the same checkpoint.
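A hedged sketch of that two-model setup; the checkpoint name and loading calls are illustrative assumptions, not the internvl_chat_mpo.py script itself:

```python
# Illustrative sketch: load a trainable policy and a frozen reference copy
# from the same checkpoint. The checkpoint name is a placeholder.
import torch
from transformers import AutoModel

checkpoint = 'OpenGVLab/InternVL2-8B'  # placeholder
model = AutoModel.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)
ref_model = AutoModel.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)
ref_model.requires_grad_(False)  # the reference model receives no gradients
ref_model.eval()
```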
Code Reference
Source Location
- Repository: InternVL
- File: internvl_chat/internvl/train/trainer_dpo.py
- Lines: L25-302
Signature
```python
class MultimodalDPOTrainer(DPOTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        is_vision_model: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Concatenate chosen and rejected inputs with multimodal data."""

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Forward pass computing log-probs for both chosen and rejected."""

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal['train', 'eval'] = 'train',
    ):
        """Compute composite loss with multiple loss types and NLL mixing."""
```
Import
```python
from internvl.train.trainer_dpo import MultimodalDPOTrainer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | InternVLChatModel | Yes | Policy model (trainable) |
| ref_model | InternVLChatModel | Yes | Reference model (frozen) |
| args | DPOConfig | Yes | Training config with loss_type, beta, rpo_alpha |
| train_dataset | Dataset | Yes | DPO dataset with chosen/rejected pairs |
| data_collator | Callable | Yes | dpo_concat_pad_data_collator |
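A hedged DPOConfig sketch matching the table above; all values are illustrative, and the mechanism for supplying per-component loss weights is not shown here:

```python
# Illustrative config sketch: loss_type uses the comma-split convention the
# trainer parses; beta and rpo_alpha are DPOConfig fields. Values are
# assumptions, not taken from the InternVL training recipes.
from trl import DPOConfig

dpo_config = DPOConfig(
    output_dir='./mpo_checkpoints',   # placeholder
    loss_type='sigmoid,bco_pair',     # comma-split composite loss string
    beta=0.1,                         # DPO temperature
    rpo_alpha=1.0,                    # weight of the mixed-in NLL term
    per_device_train_batch_size=1,
)
```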
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Composite loss: w_sigmoid * DPO + w_bco * BCO + rpo_alpha * NLL |
| metrics | Dict | chosen_rewards, rejected_rewards, per-component losses |
| checkpoints | Files | Model checkpoints at configured intervals |
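A worked instance of the composite loss formula in the loss row; every number and weight name below is a placeholder chosen only to make the arithmetic concrete:

```python
# Sketch of the weighted sum described in the Outputs table.
import torch

dpo_loss = torch.tensor(0.65)   # sigmoid preference loss component
bco_loss = torch.tensor(0.40)   # BCO pairwise component
nll_loss = torch.tensor(1.20)   # NLL on chosen tokens

w_sigmoid, w_bco, rpo_alpha = 0.8, 0.2, 1.0  # assumed weights
loss = w_sigmoid * dpo_loss + w_bco * bco_loss + rpo_alpha * nll_loss
print(loss)  # tensor(1.8000)
```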
Usage Examples
MPO Training Setup
```python
from internvl.train.trainer_dpo import MultimodalDPOTrainer
from internvl.patch.pad_data_collator import dpo_concat_pad_data_collator

trainer = MultimodalDPOTrainer(
    model=model,
    ref_model=ref_model,
    args=dpo_config,
    train_dataset=dpo_dataset,
    tokenizer=tokenizer,
    data_collator=dpo_concat_pad_data_collator,
)
trainer.train()
trainer.save_model()
```