
Implementation:Microsoft DeepSpeedExamples Build Dataset

From Leeroopedia



Metadata

Field Value
Page Type Implementation (Pattern Doc)
Title Build_Dataset
Repository Microsoft/DeepSpeedExamples
Application DeepSpeed-VisualChat
File applications/DeepSpeed-VisualChat/utils/data/builder.py
Lines 23-140
Language Python
Status Active

Overview

build_dataset is the factory routine that assembles a unified visual question answering (VQA) training dataset for DeepSpeed-VisualChat. Given a dataset name (or a list of names), it instantiates the matching dataset class from 13 supported sources, optionally subsamples it to a fixed number of examples, and returns a PyTorch Dataset (or a ConcatDataset for multi-dataset training).

Code Reference

Builder Function (Lines 23-140)

def build_dataset(data_path, data_debug_path, dataset_name, dataset_sample,
                  dataset_concatenate_samples, max_num_image_per_sample, **kwargs):
    if isinstance(dataset_name, list):
        datasets = [build_dataset(data_path, data_debug_path,
                                  dataset_name[i], dataset_sample[i],
                                  dataset_concatenate_samples[i],
                                  max_num_image_per_sample,
                                  **kwargs) for i in range(len(dataset_name))]
        return ConcatDataset(datasets)

    if dataset_name == "aokvqa":
        dataset = AOKVQADataset(data_path, data_debug_path,
                                dataset_concatenate_samples, **kwargs)
    elif dataset_name == "coco_caption":
        dataset = COCOCaptionDataset(data_path, data_debug_path,
                                     dataset_concatenate_samples, **kwargs)
    elif dataset_name == "llava":
        dataset = LlavaDataset(data_path, data_debug_path,
                                dataset_concatenate_samples, **kwargs)
    elif dataset_name == "llava_dial":
        dataset = DialDataset(dataset_name, data_path, data_debug_path,
                               dataset_concatenate_samples, **kwargs)
    elif dataset_name == "llava_otter_blend":
        dataset = LlavaOtterBlendDataset(data_path, data_debug_path,
                                          dataset_concatenate_samples,
                                          followup=False, **kwargs)
    elif dataset_name == "minigpt4":
        dataset = CcSbuAlignDataset(data_path, data_debug_path,
                                     dataset_concatenate_samples, **kwargs)
    elif dataset_name == "ocr_vqa":
        dataset = OCRVQADataset(data_path, data_debug_path,
                                 dataset_concatenate_samples, **kwargs)
    elif dataset_name == "otter_mimicit_cgd":
        dataset = OtterMimicitCgdDataset(data_path, data_debug_path,
                                          dataset_concatenate_samples, **kwargs)
    elif dataset_name == "otter_mimicit_sd":
        dataset = OtterMimicitSdDataset(data_path, data_debug_path,
                                         dataset_concatenate_samples, **kwargs)
    elif dataset_name == "otter_mimicit_sn":
        dataset = OtterMimicitSnDataset(data_path, data_debug_path,
                                         dataset_concatenate_samples,
                                         max_num_image_per_sample, **kwargs)
    elif dataset_name == "otter_mimicit_tvc":
        dataset = OtterMimicitTvcDataset(data_path, data_debug_path,
                                          dataset_concatenate_samples,
                                          max_num_image_per_sample, **kwargs)
    elif dataset_name == "otter_mimicit_vst":
        dataset = OtterMimicitVstDataset(data_path, data_debug_path,
                                          dataset_concatenate_samples,
                                          max_num_image_per_sample, **kwargs)
    elif dataset_name == "sparkles_dialogue":
        dataset = SparklesDialogueDataset(data_path, data_debug_path,
                                           dataset_concatenate_samples, **kwargs)
    else:
        raise NotImplementedError

    # Subsampling if dataset_sample is not "all"
    if dataset_sample != 'all':
        dataset_sample = int(dataset_sample)
        random_indices = np.random.choice(
            len(dataset),
            min(dataset_sample, len(dataset)),
            replace=False)
        subsample_dataset = torch.utils.data.Subset(dataset, random_indices)
        subsample_dataset.collater = dataset.collater
        print_rank_0(
            f"[DATA] Built dataset {dataset_name} "
            f"with {len(subsample_dataset)} samples.")
        return subsample_dataset
    else:
        print_rank_0(
            f"[DATA] Built dataset {dataset_name} "
            f"with all {len(dataset)} samples.")
        return dataset
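
Note that torch.utils.data.Subset does not inherit custom attributes from the wrapped dataset, which is why the builder manually reattaches collater after subsampling. Below is a minimal, self-contained sketch of that pattern; the TinyDataset class and the sample count are illustrative assumptions, not code from the repository.

import numpy as np
from torch.utils.data import Dataset, Subset

class TinyDataset(Dataset):
    """Hypothetical stand-in for a VQA dataset that carries its own collater."""
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return {"input_ids": [idx]}

    def collater(self, samples):
        # Custom batching logic lives on the dataset object itself.
        return {"input_ids": [s["input_ids"] for s in samples]}

dataset = TinyDataset()
sample_count = 128  # corresponds to passing dataset_sample="128"

# Draw a random subset without replacement, capped at the dataset size.
indices = np.random.choice(len(dataset), min(sample_count, len(dataset)), replace=False)
subset = Subset(dataset, indices)

# Subset drops the collater attribute, so copy it over for the DataLoader to use later.
subset.collater = dataset.collater
assert len(subset) == sample_count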

I/O Contract

Function Signature

def build_dataset(
    data_path,                      # str: root data directory
    data_debug_path,                # str or None: debug output directory
    dataset_name,                   # str or list[str]: dataset name(s)
    dataset_sample,                 # str or list[str]: "all" or integer count
    dataset_concatenate_samples,    # int or list[int]: QA pairs per sample
    max_num_image_per_sample,       # int: max images per sample
    **kwargs                        # vis_processor, tokenizer, etc.
) -> Dataset
Direction       | Parameter                   | Type                     | Description
Input           | data_path                   | str                      | Root directory containing dataset files and images
Input           | data_debug_path             | str or None              | If provided, saves sample debug outputs to this directory
Input           | dataset_name                | str or list[str]         | One of 13 supported dataset names, or a list for multi-dataset training
Input           | dataset_sample              | str or list[str]         | "all" to use every sample, or an integer string for random subsampling
Input           | dataset_concatenate_samples | int or list[int]         | Number of QA pairs to concatenate per data point
Input           | max_num_image_per_sample    | int                      | Maximum number of images per training sample (up to 8)
Input (kwargs)  | vis_processor               | CLIPImageProcessor       | Image preprocessing pipeline
Input (kwargs)  | tokenizer                   | AutoTokenizer            | Text tokenizer with special tokens
Output (return) | --                          | Dataset or ConcatDataset | PyTorch Dataset producing dicts with image, input_ids, attention_mask, labels, and image_num
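
The vis_processor and tokenizer kwargs are forwarded untouched to every dataset class. A hedged sketch of wiring them up for a single dataset follows; the checkpoint names are illustrative assumptions, not the ones hard-coded by DeepSpeed-VisualChat.

from transformers import AutoTokenizer, CLIPImageProcessor

# Illustrative checkpoints; the real training script derives these from its own CLI arguments.
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)

# build_dataset comes from applications/DeepSpeed-VisualChat/utils/data/builder.py
dataset = build_dataset(
    "./data/",        # data_path
    None,             # data_debug_path: disable debug dumps
    "coco_caption",   # single dataset name
    "all",            # keep every sample
    1,                # one QA pair per data point
    8,                # max images per sample
    vis_processor=image_processor,
    tokenizer=tokenizer,
)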

Output Sample Format

Each sample returned by __getitem__ is a dictionary:

Key            | Type               | Shape        | Description
image          | list[torch.Tensor] | [N, 3, H, W] | List of N preprocessed images
input_ids      | list[int]          | [seq_len]    | Token IDs with image placeholders
attention_mask | list[int]          | [seq_len]    | 1 for real tokens, 0 for padding
labels         | list[int]          | [seq_len]    | Token IDs for answer positions, -100 elsewhere
image_num      | int                | scalar       | Number of images in this sample
instruction    | str                | --           | Formatted instruction text
answer         | str                | --           | Ground-truth answer text
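
Because input_ids, attention_mask, and labels vary in length across samples, they must be padded before batching. The sketch below is an illustrative pad-to-batch-max collater over these keys, not the repository's DataCollatorPadToMaxLen.

import torch

def pad_collate(samples, pad_token_id):
    """Illustrative collater: pad text fields to the longest sequence in the batch."""
    max_len = max(len(s["input_ids"]) for s in samples)

    def pad(seq, value):
        return seq + [value] * (max_len - len(seq))

    return {
        # Prompt and answer tokens padded with the tokenizer's pad id.
        "input_ids": torch.tensor([pad(s["input_ids"], pad_token_id) for s in samples]),
        # 0 marks padding positions in the attention mask.
        "attention_mask": torch.tensor([pad(s["attention_mask"], 0) for s in samples]),
        # -100 keeps padded positions out of the loss, matching the label-masking rule above.
        "labels": torch.tensor([pad(s["labels"], -100) for s in samples]),
        # Images stay as a flat list of tensors; image_num records how many belong to each sample.
        "image": [img for s in samples for img in s["image"]],
        "image_num": [s["image_num"] for s in samples],
    }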

Supported Datasets

Name String        | Dataset Class           | Receives max_num_image?
aokvqa             | AOKVQADataset           | No
coco_caption       | COCOCaptionDataset      | No
llava              | LlavaDataset            | No
llava_dial         | DialDataset             | No
llava_otter_blend  | LlavaOtterBlendDataset  | No
minigpt4           | CcSbuAlignDataset       | No
ocr_vqa            | OCRVQADataset           | No
otter_mimicit_cgd  | OtterMimicitCgdDataset  | No
otter_mimicit_sd   | OtterMimicitSdDataset   | No
otter_mimicit_sn   | OtterMimicitSnDataset   | Yes
otter_mimicit_tvc  | OtterMimicitTvcDataset  | Yes
otter_mimicit_vst  | OtterMimicitVstDataset  | Yes
sparkles_dialogue  | SparklesDialogueDataset | No

Note that the otter_mimicit_sn, otter_mimicit_tvc, and otter_mimicit_vst datasets receive max_num_image_per_sample as a positional argument, since these inherently support multi-image scenes.

Supporting Classes

DST.Prompter

The Prompter class formats questions into the DST template:

class Prompter:
    def __call__(self, question, with_image=True, first_message=False,
                 num_images=-1, options=None):
        if with_image:
            res = TEMPLATE["prompt_qa_with_image"].replace(
                DEFAULT_QUESTION_TOKEN, question)
            if num_images >= 1:
                # Replace single image marker with multi-image markers
                ...
        else:
            res = TEMPLATE["prompt_qa_without_image"].replace(
                DEFAULT_QUESTION_TOKEN, question)
        if first_message:
            res = DEFAULT_PROMPT + res
        return res
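
A hedged usage sketch of the Prompter follows; the exact template strings live in the DST module and are not reproduced here.

prompter = Prompter()

# First turn of a conversation about two images: the template inserts the DST image
# markers and prepends the system prompt before the question.
first_turn = prompter("What differs between these two photos?",
                      with_image=True, first_message=True, num_images=2)

# A text-only follow-up question uses the no-image template and omits the system prompt.
follow_up = prompter("Why do you think so?", with_image=False)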

VQADataset Base Class

All dataset classes inherit from VQADataset, which provides:

  • process_image(ann) -- Loads and preprocesses an image from an annotation dict
  • process_text(ann) -- Formats question/answer text using the DST Prompter
  • tokenize(text) -- Tokenizes instruction + answer, applies label masking (see the sketch after this list)
  • merge_all_images(res_list) -- Combines multiple QA pairs into a single multi-image sample with numbered image markers
  • collater(samples) -- Pads a batch of samples to uniform length
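
The key behavior in tokenize is label masking: only answer tokens contribute to the loss. The helper below is an illustrative sketch of that idea, not the repository's exact implementation.

def build_labels(instruction_ids, answer_ids, ignore_index=-100):
    """Illustrative label masking: supervise only the answer tokens.

    instruction_ids / answer_ids are token-id lists for the formatted prompt
    and the ground-truth answer, respectively.
    """
    input_ids = list(instruction_ids) + list(answer_ids)
    # Instruction positions are ignored by the loss (-100); answer positions
    # keep their token ids as prediction targets.
    labels = [ignore_index] * len(instruction_ids) + list(answer_ids)
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}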

Usage Example

In Training Script (training/main.py)

# Convert per-dataset QA-concatenation counts from command-line strings to ints
args.dataset_concatenate_samples = [int(i) for i in args.dataset_concatenate_samples]

dataset = build_dataset(
    args.data_path,
    args.data_debug_path,
    args.dataset_names,              # e.g., ["llava", "coco_caption"]
    args.dataset_samples,            # e.g., ["all", "512"]
    args.dataset_concatenate_samples, # e.g., [1, 1]
    args.max_num_image_per_sample,   # e.g., 8
    vis_processor=image_processor,
    tokenizer=tokenizer,
)

# Shuffle and split
np_rng = np.random.RandomState(seed=args.seed)
dataset = shuffle_dataset(dataset, np_rng)
train_dataset, eval_dataset = split_dataset(dataset, args.data_train_split_ratio)

# Create DataLoader with distributed sampler
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.per_device_train_batch_size,
    sampler=DistributedSampler(train_dataset, shuffle=True, drop_last=True),
    collate_fn=DataCollatorPadToMaxLen(args.max_seq_len, tokenizer.pad_token_id),
)

Command-Line Multi-Dataset Configuration

deepspeed training/main.py \
    --data_path ./data/ \
    --dataset_names llava coco_caption otter_mimicit_cgd \
    --dataset_samples all 5000 all \
    --dataset_concatenate_samples 1 1 1 \
    --max_num_image_per_sample 8 \
    ...
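
For the space-separated multi-value flags above to parse, the training script's argument parser has to accept lists. The declarations below are a plausible sketch; the exact argparse definitions in training/main.py may differ.

import argparse

parser = argparse.ArgumentParser()
# nargs="+" lets each flag take one value per dataset, e.g.
#   --dataset_names llava coco_caption otter_mimicit_cgd
parser.add_argument("--dataset_names", nargs="+", default=["llava"])
parser.add_argument("--dataset_samples", nargs="+", default=["all"])
parser.add_argument("--dataset_concatenate_samples", nargs="+", default=["1"])
parser.add_argument("--max_num_image_per_sample", type=int, default=8)
args = parser.parse_args()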

Dependencies

  • numpy -- Random subsampling
  • torch.utils.data -- Dataset, Subset, ConcatDataset
  • 13 dataset-specific classes imported from sibling modules
  • utils.utils.print_rank_0 -- Rank-aware printing utility
