Implementation:Microsoft DeepSpeedExamples Build Dataset
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (Pattern Doc) |
| Title | Build_Dataset |
| Repository | Microsoft/DeepSpeedExamples |
| Application | DeepSpeed-VisualChat |
| File | applications/DeepSpeed-VisualChat/utils/data/builder.py |
| Lines | 23-140 |
| Language | Python |
| Status | Active |
Overview
Concrete tool for building a unified VQA training dataset from 13 supported sources for DeepSpeed-VisualChat. It dispatches a dataset name (or list of names) to the matching dataset class, optionally subsamples, and returns a PyTorch Dataset ready for the training loop.
Code Reference
Builder Function (Lines 23-140)
```python
def build_dataset(data_path, data_debug_path, dataset_name, dataset_sample,
                  dataset_concatenate_samples, max_num_image_per_sample, **kwargs):
    # A list of names recurses per-dataset, then concatenates the results.
    if isinstance(dataset_name, list):
        datasets = [build_dataset(data_path, data_debug_path,
                                  dataset_name[i], dataset_sample[i],
                                  dataset_concatenate_samples[i],
                                  max_num_image_per_sample,
                                  **kwargs) for i in range(len(dataset_name))]
        return ConcatDataset(datasets)

    if dataset_name == "aokvqa":
        dataset = AOKVQADataset(data_path, data_debug_path,
                                dataset_concatenate_samples, **kwargs)
    elif dataset_name == "coco_caption":
        dataset = COCOCaptionDataset(data_path, data_debug_path,
                                     dataset_concatenate_samples, **kwargs)
    elif dataset_name == "llava":
        dataset = LlavaDataset(data_path, data_debug_path,
                               dataset_concatenate_samples, **kwargs)
    elif dataset_name == "llava_dial":
        dataset = DialDataset(dataset_name, data_path, data_debug_path,
                              dataset_concatenate_samples, **kwargs)
    elif dataset_name == "llava_otter_blend":
        dataset = LlavaOtterBlendDataset(data_path, data_debug_path,
                                         dataset_concatenate_samples,
                                         followup=False, **kwargs)
    elif dataset_name == "minigpt4":
        dataset = CcSbuAlignDataset(data_path, data_debug_path,
                                    dataset_concatenate_samples, **kwargs)
    elif dataset_name == "ocr_vqa":
        dataset = OCRVQADataset(data_path, data_debug_path,
                                dataset_concatenate_samples, **kwargs)
    elif dataset_name == "otter_mimicit_cgd":
        dataset = OtterMimicitCgdDataset(data_path, data_debug_path,
                                         dataset_concatenate_samples, **kwargs)
    elif dataset_name == "otter_mimicit_sd":
        dataset = OtterMimicitSdDataset(data_path, data_debug_path,
                                        dataset_concatenate_samples, **kwargs)
    elif dataset_name == "otter_mimicit_sn":
        dataset = OtterMimicitSnDataset(data_path, data_debug_path,
                                        dataset_concatenate_samples,
                                        max_num_image_per_sample, **kwargs)
    elif dataset_name == "otter_mimicit_tvc":
        dataset = OtterMimicitTvcDataset(data_path, data_debug_path,
                                         dataset_concatenate_samples,
                                         max_num_image_per_sample, **kwargs)
    elif dataset_name == "otter_mimicit_vst":
        dataset = OtterMimicitVstDataset(data_path, data_debug_path,
                                         dataset_concatenate_samples,
                                         max_num_image_per_sample, **kwargs)
    elif dataset_name == "sparkles_dialogue":
        dataset = SparklesDialogueDataset(data_path, data_debug_path,
                                          dataset_concatenate_samples, **kwargs)
    else:
        raise NotImplementedError

    # Subsampling if dataset_sample is not "all"
    if dataset_sample != 'all':
        dataset_sample = int(dataset_sample)
        random_indices = np.random.choice(len(dataset),
                                          min(dataset_sample, len(dataset)),
                                          replace=False)
        subsample_dataset = torch.utils.data.Subset(dataset, random_indices)
        subsample_dataset.collater = dataset.collater
        print_rank_0(f"[DATA] Built dataset {dataset_name} "
                     f"with {len(subsample_dataset)} samples.")
        return subsample_dataset
    else:
        print_rank_0(f"[DATA] Built dataset {dataset_name} "
                     f"with all {len(dataset)} samples.")
        return dataset
```
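The subsampling branch can be exercised in isolation. This is a minimal sketch (the `ToyDataset` class is a stand-in, not part of the repository) showing why the `collater` attribute must be re-attached: `torch.utils.data.Subset` wraps the original dataset but does not forward custom attributes.

```python
import numpy as np
from torch.utils.data import Dataset, Subset

class ToyDataset(Dataset):
    """Stand-in for a VQA dataset class that carries a collater attribute."""
    def __init__(self, n):
        self.data = list(range(n))
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
    def collater(self, samples):
        return list(samples)

dataset = ToyDataset(100)
dataset_sample = "10"  # an integer string, as passed on the command line

if dataset_sample != "all":
    k = int(dataset_sample)
    # Draw k unique indices, capped at the dataset size, exactly as in builder.py
    indices = np.random.choice(len(dataset), min(k, len(dataset)), replace=False)
    subset = Subset(dataset, indices)
    # Subset does not inherit custom attributes from the wrapped dataset,
    # so the collater must be copied over by hand.
    subset.collater = dataset.collater

print(len(subset))  # 10
```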
I/O Contract
Function Signature
```python
def build_dataset(
    data_path,                    # str: root data directory
    data_debug_path,              # str or None: debug output directory
    dataset_name,                 # str or list[str]: dataset name(s)
    dataset_sample,               # str: "all" or integer count
    dataset_concatenate_samples,  # int or list[int]: QA pairs per sample
    max_num_image_per_sample,     # int: max images per sample
    **kwargs,                     # vis_processor, tokenizer, etc.
) -> Dataset
```
| Direction | Parameter | Type | Description |
|---|---|---|---|
| Input | data_path | str | Root directory containing dataset files and images |
| Input | data_debug_path | str or None | If provided, saves sample debug outputs |
| Input | dataset_name | str or list[str] | One of 13 supported dataset names, or a list for multi-dataset training |
| Input | dataset_sample | str | "all" to use all samples, or an integer string for subsampling |
| Input | dataset_concatenate_samples | int or list[int] | Number of QA pairs to concatenate per data point |
| Input | max_num_image_per_sample | int | Maximum number of images per training sample (up to 8) |
| Input (kwargs) | vis_processor | CLIPImageProcessor | Image preprocessing pipeline |
| Input (kwargs) | tokenizer | AutoTokenizer | Text tokenizer with special tokens |
| Output | (return) | Dataset or ConcatDataset | PyTorch Dataset producing dicts with image, input_ids, attention_mask, labels, image_num |
Output Sample Format
Each sample returned by __getitem__ is a dictionary:
| Key | Type | Shape | Description |
|---|---|---|---|
| image | list[torch.Tensor] | [N, 3, H, W] | List of N preprocessed images |
| input_ids | list[int] | [seq_len] | Token IDs with image placeholders |
| attention_mask | list[int] | [seq_len] | 1 for real tokens, 0 for padding |
| labels | list[int] | [seq_len] | Token IDs for answer positions, -100 elsewhere |
| image_num | int | scalar | Number of images in this sample |
| instruction | str | -- | Formatted instruction text |
| answer | str | -- | Ground-truth answer text |
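The invariants implied by this table can be checked on a toy sample. The token IDs below are illustrative placeholders, not real tokenizer output:

```python
import torch

# Hypothetical sample in the format described above (values are illustrative).
sample = {
    "image": [torch.zeros(3, 224, 224)],   # one preprocessed image
    "input_ids": [1, 32001, 523, 907, 2],  # token IDs incl. an image placeholder
    "attention_mask": [1, 1, 1, 1, 1],     # all real tokens, no padding yet
    "labels": [-100, -100, -100, 907, 2],  # loss computed only on answer tokens
    "image_num": 1,
}

# The three token-level fields share one sequence length, and image_num
# matches the number of image tensors.
assert len(sample["input_ids"]) == len(sample["attention_mask"]) == len(sample["labels"])
assert sample["image_num"] == len(sample["image"])
```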
Supported Datasets
| Name String | Dataset Class | Receives max_num_image? |
|---|---|---|
| aokvqa | AOKVQADataset | No |
| coco_caption | COCOCaptionDataset | No |
| llava | LlavaDataset | No |
| llava_dial | DialDataset | No |
| llava_otter_blend | LlavaOtterBlendDataset | No |
| minigpt4 | CcSbuAlignDataset | No |
| ocr_vqa | OCRVQADataset | No |
| otter_mimicit_cgd | OtterMimicitCgdDataset | No |
| otter_mimicit_sd | OtterMimicitSdDataset | No |
| otter_mimicit_sn | OtterMimicitSnDataset | Yes |
| otter_mimicit_tvc | OtterMimicitTvcDataset | Yes |
| otter_mimicit_vst | OtterMimicitVstDataset | Yes |
| sparkles_dialogue | SparklesDialogueDataset | No |
Note that the otter_mimicit_sn, otter_mimicit_tvc, and otter_mimicit_vst datasets receive max_num_image_per_sample as a positional argument, since these inherently support multi-image scenes.
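The if/elif chain above is equivalent to a name-to-class registry. The sketch below shows that alternative shape; the class names are stand-in strings here (the real code imports the classes directly), and the boolean flags which datasets receive max_num_image_per_sample:

```python
# Hypothetical registry mirroring the dispatch table above.
# (String names stand in for the real classes imported in builder.py.)
DATASET_REGISTRY = {
    "aokvqa": ("AOKVQADataset", False),
    "coco_caption": ("COCOCaptionDataset", False),
    "llava": ("LlavaDataset", False),
    "llava_dial": ("DialDataset", False),
    "llava_otter_blend": ("LlavaOtterBlendDataset", False),
    "minigpt4": ("CcSbuAlignDataset", False),
    "ocr_vqa": ("OCRVQADataset", False),
    "otter_mimicit_cgd": ("OtterMimicitCgdDataset", False),
    "otter_mimicit_sd": ("OtterMimicitSdDataset", False),
    "otter_mimicit_sn": ("OtterMimicitSnDataset", True),    # multi-image
    "otter_mimicit_tvc": ("OtterMimicitTvcDataset", True),  # multi-image
    "otter_mimicit_vst": ("OtterMimicitVstDataset", True),  # multi-image
    "sparkles_dialogue": ("SparklesDialogueDataset", False),
}

def lookup(name):
    """Return (class_name, needs_max_num_image), mirroring the else branch."""
    if name not in DATASET_REGISTRY:
        raise NotImplementedError(name)
    return DATASET_REGISTRY[name]

print(lookup("otter_mimicit_sn"))  # ('OtterMimicitSnDataset', True)
```

A registry keeps the name-to-class mapping in one data structure, at the cost of diverging from the upstream source; the repository's explicit chain is easier to diff against the file itself.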
Supporting Classes
DST.Prompter
The Prompter class formats questions into the DST template:
```python
class Prompter:
    def __call__(self, question, with_image=True, first_message=False,
                 num_images=-1, options=None):
        if with_image:
            res = TEMPLATE["prompt_qa_with_image"].replace(
                DEFAULT_QUESTION_TOKEN, question)
            if num_images >= 1:
                # Replace single image marker with multi-image markers
                ...
        else:
            res = TEMPLATE["prompt_qa_without_image"].replace(
                DEFAULT_QUESTION_TOKEN, question)
        if first_message:
            res = DEFAULT_PROMPT + res
        return res
```
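The multi-image replacement elided by the "..." above can be sketched as follows. The template string and marker tokens here are illustrative placeholders, not the real DST constants:

```python
# Hedged sketch of the elided multi-image marker expansion.
# These token and template values are hypothetical, for illustration only.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_QUESTION_TOKEN = "<question>"
TEMPLATE = {"prompt_qa_with_image": "<image>\n<question>\nAnswer:"}

def format_with_images(question, num_images=-1):
    res = TEMPLATE["prompt_qa_with_image"].replace(DEFAULT_QUESTION_TOKEN, question)
    if num_images >= 1:
        # Expand the single image marker into numbered per-image markers,
        # so the model can distinguish images in a multi-image sample.
        markers = "".join(f"<image {i + 1}>" for i in range(num_images))
        res = res.replace(DEFAULT_IMAGE_TOKEN, markers)
    return res

print(format_with_images("What is shown?", num_images=2))
# <image 1><image 2>
# What is shown?
# Answer:
```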
VQADataset Base Class
All dataset classes inherit from VQADataset, which provides:
- process_image(ann) -- Loads and preprocesses an image from an annotation dict
- process_text(ann) -- Formats question/answer text using the DST Prompter
- tokenize(text) -- Tokenizes instruction + answer, applies label masking
- merge_all_images(res_list) -- Combines multiple QA pairs into a single multi-image sample with numbered image markers
- collater(samples) -- Pads a batch of samples to uniform length
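The label masking performed by tokenize can be sketched in a few lines. The token IDs are illustrative; the real class uses a HuggingFace tokenizer, but the masking principle is the same: positions belonging to the instruction are set to -100 so only answer tokens contribute to the loss.

```python
# Hedged sketch of the label masking applied by tokenize().
IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss ignores this target value

def build_labels(instruction_ids, answer_ids):
    """Concatenate instruction and answer; mask instruction positions."""
    input_ids = list(instruction_ids) + list(answer_ids)
    labels = [IGNORE_INDEX] * len(instruction_ids) + list(answer_ids)
    return input_ids, labels

input_ids, labels = build_labels([5, 6, 7], [8, 9])
print(labels)  # [-100, -100, -100, 8, 9]
```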
Usage Example
In Training Script (training/main.py)
```python
# Expand sample and concatenation args to match dataset count
args.dataset_concatenate_samples = [int(i) for i in args.dataset_concatenate_samples]

dataset = build_dataset(
    args.data_path,
    args.data_debug_path,
    args.dataset_names,                # e.g., ["llava", "coco_caption"]
    args.dataset_samples,              # e.g., ["all", "512"]
    args.dataset_concatenate_samples,  # e.g., [1, 1]
    args.max_num_image_per_sample,     # e.g., 8
    vis_processor=image_processor,
    tokenizer=tokenizer,
)

# Shuffle and split
np_rng = np.random.RandomState(seed=args.seed)
dataset = shuffle_dataset(dataset, np_rng)
train_dataset, eval_dataset = split_dataset(dataset, args.data_train_split_ratio)

# Create DataLoader with distributed sampler
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.per_device_train_batch_size,
    sampler=DistributedSampler(train_dataset, shuffle=True, drop_last=True),
    collate_fn=DataCollatorPadToMaxLen(args.max_seq_len, tokenizer.pad_token_id),
)
```
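The shuffle_dataset and split_dataset helpers come from the repository's utilities; the sketches below are plausible Subset-based stand-ins for illustration, not the actual implementations:

```python
import numpy as np
from torch.utils.data import Dataset, Subset

# Hypothetical stand-ins for the repo's shuffle_dataset / split_dataset helpers.
def shuffle_dataset(dataset, np_rng):
    """Return a view of the dataset in a seeded random order."""
    indices = np_rng.permutation(len(dataset))
    return Subset(dataset, indices.tolist())

def split_dataset(dataset, train_ratio):
    """Split the (already shuffled) dataset into train/eval views."""
    cut = int(len(dataset) * train_ratio)
    return Subset(dataset, range(cut)), Subset(dataset, range(cut, len(dataset)))

class Toy(Dataset):
    def __len__(self):
        return 10
    def __getitem__(self, i):
        return i

rng = np.random.RandomState(seed=0)
shuffled = shuffle_dataset(Toy(), rng)
train, evaluation = split_dataset(shuffled, 0.8)
print(len(train), len(evaluation))  # 8 2
```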
Command-Line Multi-Dataset Configuration
```bash
deepspeed training/main.py \
    --data_path ./data/ \
    --dataset_names llava coco_caption otter_mimicit_cgd \
    --dataset_samples all 5000 all \
    --dataset_concatenate_samples 1 1 1 \
    --max_num_image_per_sample 8 \
    ...
```
Dependencies
- numpy -- Random subsampling
- torch.utils.data -- Dataset, Subset, ConcatDataset
- 13 dataset-specific classes imported from sibling modules
- utils.utils.print_rank_0 -- Rank-aware printing utility
Related Pages
- Principle:Microsoft_DeepSpeedExamples_Multi_Dataset_VQA_Preparation -- The theoretical basis for multi-dataset VQA preparation
- Implementation:Microsoft_DeepSpeedExamples_Create_DSVL_Model -- The model that consumes the built datasets
- Implementation:Microsoft_DeepSpeedExamples_DeepSpeed_Initialize_VisualChat -- The training loop using these datasets
- Environment:Microsoft_DeepSpeedExamples_VisualChat_Training_Environment