Implementation:FlagOpen FlagEmbedding BGE VL Flag Dataset
| Knowledge Sources | |
|---|---|
| Domains | Vision-Language, Dataset, Multimodal Data Loading |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
PyTorch dataset and collator classes for loading multimodal image-text data for BGE-VL training and evaluation.
Description
This module provides dataset classes for multimodal (image + text) data in the BGE-VL framework. MMIT_Dataset loads image-text pairs and applies CLIP image preprocessing; Image_Dataset loads image-only data. The corresponding collators, MMIT_Collator and Image_Collator, assemble individual items into batches. Both datasets run images through CLIP's image processor for consistent preprocessing, and they accept flexible image ID formats by detecting the file extension automatically.
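The automatic extension detection mentioned above can be pictured with a minimal sketch. The helper name `resolve_image_path` and the candidate extension list are hypothetical and chosen only for illustration; the actual logic and the extensions tried in flag_dataset.py may differ.

```python
import os
from typing import Optional, Sequence

# Hypothetical candidate list; the extensions tried by flag_dataset.py may differ.
CANDIDATE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def resolve_image_path(
    image_dir: str,
    image_id: str,
    extensions: Sequence[str] = CANDIDATE_EXTENSIONS,
) -> Optional[str]:
    """Return an existing file path for image_id, or None if none is found."""
    direct = os.path.join(image_dir, image_id)
    # An ID that already carries an extension is used as-is.
    if os.path.splitext(image_id)[1]:
        return direct if os.path.isfile(direct) else None
    # Otherwise try each candidate extension in order.
    for ext in extensions:
        candidate = direct + ext
        if os.path.isfile(candidate):
            return candidate
    return None
```

With a helper like this, an ID such as `img1` and an ID such as `img1.jpg` would both resolve to the same file when `img1.jpg` exists in `image_dir`.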
Usage
Use these classes when training or evaluating BGE-VL multimodal embedding models, when creating dataloaders for image-text retrieval tasks, or when preprocessing images with the CLIP image processor. Use MMIT_Dataset for paired image-text data and Image_Dataset for encoding image-only corpora.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/BGE_VL/eval/flag_dataset.py
- Lines: 1-109
Signature
class MMIT_Dataset(Dataset):
    def __init__(self, captions, image_ids, image_dir, image_processor) -> None:
        """Dataset for multimodal image-text pairs"""

    def __getitem__(self, item):
        """Returns (caption, processed_image)"""

class MMIT_Collator:
    def __init__(self, tokenizer, caption_max_len):
        """Collator for batching image-text pairs"""

    def __call__(self, features):
        """Returns (caption_tokens, image_tensor)"""

class Image_Dataset(Dataset):
    def __init__(self, image_ids, image_dir, image_processor) -> None:
        """Dataset for image-only data"""

class Image_Collator:
    def __init__(self, tokenizer, caption_max_len):
        """Collator for batching images"""

    def __call__(self, features):
        """Returns image_tensor"""
Import
from flag_dataset import MMIT_Dataset, MMIT_Collator, Image_Dataset, Image_Collator
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| captions | List[str] | Yes | Text captions/descriptions |
| image_ids | List[str] | Yes | Image filenames or paths |
| image_dir | str | Yes | Directory containing images |
| image_processor | CLIPImageProcessor | Yes | CLIP image processor instance |
| tokenizer | PreTrainedTokenizer | Yes | Text tokenizer for collator |
| caption_max_len | int | Yes | Maximum caption length |
Outputs
| Name | Type | Description |
|---|---|---|
| caption_tokens | BatchEncoding | Tokenized and padded captions |
| image_tensor | torch.Tensor | Batch of preprocessed images (B, C, H, W) |
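The first step a collator of this shape typically performs is transposing a list of per-item tuples into per-field lists. The sketch below shows only that step, in plain Python, with arbitrary objects standing in for image tensors; `split_features` is a hypothetical name and is not taken from flag_dataset.py.

```python
from typing import Any, List, Sequence, Tuple


def split_features(
    features: Sequence[Tuple[str, Any]],
) -> Tuple[List[str], List[Any]]:
    """Transpose [(caption, image), ...] into ([captions], [images])."""
    if not features:
        return [], []
    # zip(*features) regroups the tuples by position.
    captions, images = zip(*features)
    return list(captions), list(images)


# Stand-in items: real items would hold preprocessed image tensors.
batch = [("A dog playing in the park", "image_0"), ("A cat sleeping on couch", "image_1")]
captions, images = split_features(batch)
```

After this split, a collator matching the contract above would presumably tokenize `captions` with padding and truncation to `caption_max_len`, and stack `images` into a single (B, C, H, W) tensor.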
Usage Examples
# Example 1: Creating multimodal dataset
from flag_dataset import MMIT_Dataset, MMIT_Collator
from transformers import AutoTokenizer, CLIPImageProcessor
from torch.utils.data import DataLoader

# Initialize processors
tokenizer = AutoTokenizer.from_pretrained("BAAI/BGE-VL-large")
image_processor = CLIPImageProcessor.from_pretrained("BAAI/BGE-VL-large")

# Create dataset
captions = ["A dog playing in the park", "A cat sleeping on couch"]
image_ids = ["img1.jpg", "img2.jpg"]
image_dir = "/path/to/images"
dataset = MMIT_Dataset(
    captions=captions,
    image_ids=image_ids,
    image_dir=image_dir,
    image_processor=image_processor
)

# Create collator and dataloader
collator = MMIT_Collator(tokenizer, caption_max_len=77)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    collate_fn=collator,
    num_workers=4
)

# Iterate
for caption_tokens, image_tensor in dataloader:
    print(caption_tokens['input_ids'].shape)  # (batch, seq_len)
    print(image_tensor.shape)  # (batch, 3, 224, 224)

# Example 2: Image-only dataset
# Reuses tokenizer, image_processor, and DataLoader from Example 1
from flag_dataset import Image_Dataset, Image_Collator

image_ids = ["img1.jpg", "img2.jpg", "img3.jpg"]
image_dataset = Image_Dataset(
    image_ids=image_ids,
    image_dir="/path/to/images",
    image_processor=image_processor
)
image_collator = Image_Collator(tokenizer, caption_max_len=77)
image_loader = DataLoader(
    image_dataset,
    batch_size=64,
    collate_fn=image_collator
)

for images in image_loader:
    print(images.shape)  # (batch, 3, 224, 224)