Implementation:Volcengine Verl Extract Multi Modal Inputs
| Field | Value |
|---|---|
| Knowledge Sources | verl source code, model utility module |
| Domains | VLM Actor Forward, Multimodal Input Processing, Batch Extraction |
| Last Updated | 2026-02-07 |
Overview
Description
extract_multi_modal_inputs is a utility function that collects multimodal input tensors (such as pixel_values, image_grid_thw, video_grid_thw) from a batch of data items and merges them, key by key, into a single dict ready for model consumption.
This function is used in the VLM actor forward pass to prepare the multimodal keyword arguments that are passed alongside input_ids and attention_mask to the model's forward method.
The function handles several important cases:
- Mixed batches -- Supports batches containing both pure-text and multimodal samples. If a sample's multimodal inputs are None, the sample is simply skipped.
- Selective indexing -- An optional indices parameter allows extracting inputs for only a subset of the batch (used during micro-batching).
- MiniCPM-O compatibility -- When image_bound is present in the inputs, the function returns lists of tensors instead of concatenated tensors, which is required by the MiniCPM-O model architecture.
- Standard case -- For most VLMs, tensors with the same key are concatenated along dimension 0 using torch.cat.
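The four cases above can be summarized in a short sketch. It is written only from the behavior described on this page and is not the verl source: the name extract_multi_modal_inputs_sketch is hypothetical, it switches to list mode if any selected sample contains an image_bound key (one reading of the description), and details of the real function may differ.

```python
from typing import Optional, Union

import torch


def extract_multi_modal_inputs_sketch(
    batch_data: list[Optional[dict[str, torch.Tensor]]],
    indices: Optional[list[int]] = None,
) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]:
    """Hypothetical re-creation of the documented behavior (not the verl source)."""
    # Selective indexing: keep only the requested samples, if indices are given.
    selected = batch_data if indices is None else [batch_data[i] for i in indices]

    collected: dict[str, list[torch.Tensor]] = {}
    has_image_bound = False
    for sample in selected:
        if sample is None:  # Mixed batch: pure-text sample, nothing to collect.
            continue
        if "image_bound" in sample:
            has_image_bound = True
        for key, value in sample.items():
            collected.setdefault(key, []).append(value)

    if has_image_bound:
        # MiniCPM-O compatibility: keep per-sample tensors as lists.
        return dict(collected)
    # Standard case: concatenate tensors sharing a key along dim 0.
    return {key: torch.cat(values, dim=0) for key, values in collected.items()}


batch = [
    {"pixel_values": torch.zeros(4, 8), "image_grid_thw": torch.tensor([[1, 2, 2]])},
    None,  # pure-text sample
    {"pixel_values": torch.ones(6, 8), "image_grid_thw": torch.tensor([[1, 2, 3]])},
]
out = extract_multi_modal_inputs_sketch(batch)
print(out["pixel_values"].shape)       # torch.Size([10, 8])
print(out["image_grid_thw"].tolist())  # [[1, 2, 2], [1, 2, 3]]
```

Collecting values into per-key lists first and choosing between torch.cat and raw lists only at the end keeps a single pass over the samples and confines the MiniCPM-O special case to one branch.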
Usage
Called during the actor's forward pass when processing VLM training batches. The returned dict is unpacked as keyword arguments to the model's forward method.
Code Reference
| Field | Value |
|---|---|
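| Source Location | verl/utils/model.py, Lines 696-738 |
| Signature | extract_multi_modal_inputs(batch_data: list[dict[str, torch.Tensor]], indices: Optional[list[int]] = None) -> dict[str, torch.Tensor \| list[torch.Tensor]] |
| Import | from verl.utils.model import extract_multi_modal_inputs |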
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| batch_data | list[dict[str, torch.Tensor]] | A list of per-sample dicts, each containing multimodal tensors. Keys vary by model (e.g., "pixel_values", "image_grid_thw", "video_grid_thw", "image_bound"). May contain None entries for pure-text samples. |
| indices | Optional[list[int]] | If provided, only extract multimodal inputs at these indices within the batch. Used for micro-batch slicing. |
Outputs
| Return | Type | Description |
|---|---|---|
| multi_modal_inputs | dict[str, torch.Tensor \| list[torch.Tensor]] | A dict mapping tensor key names to their collected values. For standard VLMs, values are torch.Tensor (concatenated along dim 0). For MiniCPM-O (when image_bound is present), values are list[torch.Tensor]. |
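This page does not say why MiniCPM-O needs list-valued outputs. The following only illustrates a general property of torch.cat that matters when consuming the result: concatenating along dim 0 requires every other dimension to match, whereas a list can hold tensors of any shape. The tensors here are made up for illustration.

```python
import torch

# Two made-up per-sample tensors whose trailing dimensions differ.
a = torch.zeros(2, 3)
b = torch.zeros(4, 5)

# torch.cat along dim 0 requires all non-concatenated dimensions to match.
try:
    torch.cat([a, b], dim=0)
    concatenated = True
except RuntimeError:
    concatenated = False

# A list-valued entry (the MiniCPM-O return style) keeps both tensors unchanged.
as_list = [a, b]
print(concatenated, [tuple(t.shape) for t in as_list])
```

Code that may receive either form can branch on isinstance(value, torch.Tensor) before treating a value as a single tensor.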
Common output keys (after concatenation, the leading dimension counts rows from all selected samples):
| Key | Shape | Description |
|---|---|---|
| pixel_values | (N, C, H, W) or (N, num_patches, hidden_dim); Qwen2-VL uses a flattened (total_patches, patch_dim) layout | Processed image pixel values; the format depends on the model's image processor. |
| image_grid_thw | (num_images, 3) | Temporal, height, and width grid dimensions for each image (used by Qwen2-VL). |
| video_grid_thw | (num_videos, 3) | Grid dimensions for video frames (if applicable). |
| image_bound | varies | Image boundary indices (MiniCPM-O specific). |
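Concatenating along dim 0 keeps the keys consistent with one another as long as every key is concatenated in the same sample order. The sketch below uses made-up grids and assumes the Qwen2-VL processor layout, in which each image contributes t * h * w patch rows to pixel_values, each row holding 3 * 2 * 14 * 14 = 1176 values; other models lay out pixel_values differently.

```python
import torch

# Made-up (t, h, w) grids for two samples, Qwen2-VL style.
grid_0 = torch.tensor([[1, 4, 6]])             # one image: 24 patches
grid_2 = torch.tensor([[1, 2, 2], [1, 8, 8]])  # two images: 4 + 64 patches

# Assumed Qwen2-VL layout: one row per patch, 3 * 2 * 14 * 14 = 1176 values each.
PATCH_DIM = 3 * 2 * 14 * 14
pixels_0 = torch.randn(int(grid_0.prod(dim=-1).sum()), PATCH_DIM)
pixels_2 = torch.randn(int(grid_2.prod(dim=-1).sum()), PATCH_DIM)

# Standard (non-MiniCPM-O) path: concatenate every key along dim 0, same order.
pixel_values = torch.cat([pixels_0, pixels_2], dim=0)
image_grid_thw = torch.cat([grid_0, grid_2], dim=0)

# Rows of pixel_values still match the patch count implied by image_grid_thw.
assert pixel_values.shape[0] == int(image_grid_thw.prod(dim=-1).sum())
print(pixel_values.shape, image_grid_thw.shape)  # torch.Size([92, 1176]) torch.Size([3, 3])
```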
Usage Examples
Basic extraction in VLM actor forward pass:
```python
from verl.utils.model import extract_multi_modal_inputs

# batch_data is a list of per-sample multimodal input dicts
# from the DataProto batch. pixel_tensor_*, grid_tensor_*, model,
# input_ids, attention_mask and position_ids are placeholders.
batch_data = [
    {"pixel_values": pixel_tensor_0, "image_grid_thw": grid_tensor_0},
    None,  # pure text sample, no multimodal data
    {"pixel_values": pixel_tensor_2, "image_grid_thw": grid_tensor_2},
]

# Extract and concatenate all multimodal inputs
multi_modal_inputs = extract_multi_modal_inputs(batch_data)
# multi_modal_inputs = {
#     "pixel_values": torch.cat([pixel_tensor_0, pixel_tensor_2], dim=0),
#     "image_grid_thw": torch.cat([grid_tensor_0, grid_tensor_2], dim=0),
# }

# Pass to model forward
output = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    **multi_modal_inputs,
)
```
Selective extraction with indices for micro-batching:
```python
from verl.utils.model import extract_multi_modal_inputs

# Only extract multimodal inputs for specific batch indices
micro_batch_indices = [0, 2, 5]
multi_modal_inputs = extract_multi_modal_inputs(
    batch_data=full_batch_data,
    indices=micro_batch_indices,
)

# Use in micro-batch forward pass
output = model(
    input_ids=micro_batch_input_ids,
    attention_mask=micro_batch_attention_mask,
    **multi_modal_inputs,
)
```
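Whatever indices are chosen, the same list should slice input_ids, attention_mask, and the multimodal data, so that row i of the text tensors still lines up with sample i's images. A minimal sketch of that pairing, using a hypothetical make_micro_batch_indices helper that produces contiguous chunks (verl's own micro-batching may group rows differently) and plain Python lists as stand-ins for tensors:

```python
def make_micro_batch_indices(batch_size: int, micro_batch_size: int) -> list[list[int]]:
    """Hypothetical helper: split positions 0..batch_size-1 into contiguous chunks."""
    return [
        list(range(start, min(start + micro_batch_size, batch_size)))
        for start in range(0, batch_size, micro_batch_size)
    ]


# Stand-ins: entry i of batch_data belongs to row i of input_ids.
input_rows = ["ids_0", "ids_1", "ids_2", "ids_3", "ids_4"]
batch_data = [{"img": 0}, None, {"img": 2}, None, {"img": 4}]

for micro_batch_indices in make_micro_batch_indices(len(input_rows), micro_batch_size=2):
    # Slice text rows and multimodal entries with the SAME index list.
    micro_ids = [input_rows[i] for i in micro_batch_indices]
    micro_mm = [batch_data[i] for i in micro_batch_indices]
    print(micro_batch_indices, micro_ids, micro_mm)
```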
Handling mixed text and multimodal batches:
```python
from verl.utils.model import extract_multi_modal_inputs

# In a mixed batch, some samples have no multimodal data
batch_data = [
    {"pixel_values": img_tensor},   # VLM sample
    None,                           # text-only sample
    None,                           # text-only sample
    {"pixel_values": img_tensor2},  # VLM sample
]
multi_modal_inputs = extract_multi_modal_inputs(batch_data)

if multi_modal_inputs:
    # Only pass multimodal inputs if any exist
    output = model(input_ids=ids, **multi_modal_inputs)
else:
    # Pure text batch
    output = model(input_ids=ids)
```
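Because an all-text batch produces an empty dict, the branch above is a readability choice rather than a requirement of Python: unpacking **{} passes no keyword arguments at all. A sketch with a hypothetical toy_forward standing in for a model's forward method:

```python
def toy_forward(input_ids, attention_mask=None, pixel_values=None, image_grid_thw=None):
    """Hypothetical stand-in for a VLM forward method."""
    return "multimodal" if pixel_values is not None else "text-only"


ids = [[101, 2023, 102]]

# Empty dict: the call is identical to passing no multimodal kwargs.
assert toy_forward(input_ids=ids, **{}) == toy_forward(input_ids=ids) == "text-only"

# Non-empty dict: each key is matched to the parameter of the same name.
print(toy_forward(input_ids=ids, **{"pixel_values": object()}))  # multimodal
```

Since keys are matched by name, a key the forward method does not accept raises TypeError, so the dict's keys must correspond to the model's forward parameters.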