Implementation: Hugging Face Diffusers DreamBooth Dataset Class
| Knowledge Sources | |
|---|---|
| Domains | |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
The DreamBoothDataset class and its companion collate_fn function implement the paired instance-class data loading pipeline for DreamBooth training. The dataset handles image loading, EXIF-aware preprocessing, prompt tokenization, and instance-class pairing, while the collate function concatenates instance and class samples for efficient single-pass training.
Description
The DreamBoothDataset extends PyTorch's Dataset class. On initialization, it:
- Enumerates all image files in instance_data_root.
- Optionally enumerates class images in class_data_root.
- Sets the dataset length to max(num_instance_images, num_class_images).
- Builds an image transform pipeline: Resize, CenterCrop/RandomCrop, ToTensor, Normalize.
Each __getitem__ call returns a dictionary containing:
- instance_images -- Preprocessed instance image tensor.
- instance_prompt_ids -- Tokenized instance prompt (or pre-computed encoder hidden states).
- class_images -- Preprocessed class image tensor (when class data is provided).
- class_prompt_ids -- Tokenized class prompt (when class data is provided).
The companion collate_fn handles batching by stacking pixel values and concatenating token IDs. When with_prior_preservation=True, class examples are appended after instance examples in the batch, doubling the effective batch size.
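A minimal sketch of that collation logic, assuming each example dict uses the keys listed above (the real collate_fn in the script also propagates attention masks):

```python
import torch

def collate_fn_sketch(examples, with_prior_preservation=False):
    # Instance samples come first in the batch.
    input_ids = [ex["instance_prompt_ids"] for ex in examples]
    pixel_values = [ex["instance_images"] for ex in examples]

    # Appending class samples doubles the effective batch size, so the
    # prior-preservation loss can be computed in a single forward pass
    # and split back into instance/class halves afterwards.
    if with_prior_preservation:
        input_ids += [ex["class_prompt_ids"] for ex in examples]
        pixel_values += [ex["class_images"] for ex in examples]

    pixel_values = torch.stack(pixel_values).float()
    input_ids = torch.cat(input_ids, dim=0)
    return {"input_ids": input_ids, "pixel_values": pixel_values}
```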
Usage
```python
from torch.utils.data import DataLoader

train_dataset = DreamBoothDataset(
    instance_data_root="./my_dog",
    instance_prompt="a photo of sks dog",
    tokenizer=tokenizer,
    class_data_root="./dog_class",
    class_prompt="a photo of dog",
    class_num=200,
    size=512,
    center_crop=False,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, with_prior_preservation=True),
    num_workers=0,
)
```
Code Reference
Source Location
- Repository: huggingface/diffusers
- File: examples/dreambooth/train_dreambooth_lora.py (lines 568--701)
Signature
```python
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        ...

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index) -> dict:
        ...


def collate_fn(examples, with_prior_preservation=False) -> dict:
    """
    Collate function that concatenates instance and class examples
    for prior preservation training.
    """
    ...
```
Import
```python
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from PIL.ImageOps import exif_transpose
```
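These imports reflect the EXIF-aware loading mentioned in the overview: exif_transpose applies the EXIF orientation tag before the tensor transforms, so photos taken on rotated phones are not trained sideways. A hedged sketch of that loading step (load_image is an illustrative helper name, not from the script):

```python
from PIL import Image
from PIL.ImageOps import exif_transpose

def load_image(path):
    # Open the file, honor the EXIF orientation flag, and force RGB so
    # grayscale or RGBA inputs yield a consistent 3-channel tensor later.
    image = Image.open(path)
    image = exif_transpose(image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
```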
I/O Contract
Inputs (DreamBoothDataset.__init__)
| Name | Type | Default | Description |
|---|---|---|---|
| instance_data_root | Path | (required) | Path to directory containing instance (subject) images. |
| instance_prompt | str | (required) | Text prompt with identifier token for instance images. |
| tokenizer | PreTrainedTokenizer | (required) | Tokenizer for encoding text prompts into token IDs. |
| class_data_root | Path | None | Path to directory containing class-prior images. |
| class_prompt | str | None | Generic class prompt for class-prior images. |
| class_num | int | None | Maximum number of class images to use. If None, uses all available. |
| size | int | 512 | Target resolution for image preprocessing (resize and crop). |
| center_crop | bool | False | Whether to use center crop (True) or random crop (False). |
| encoder_hidden_states | Tensor | None | Pre-computed text encoder hidden states for instance prompt (bypasses tokenization). |
| class_prompt_encoder_hidden_states | Tensor | None | Pre-computed text encoder hidden states for class prompt. |
| tokenizer_max_length | int | None | Maximum tokenizer sequence length. Defaults to the tokenizer's own max length. |
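The tokenization these parameters control can be sketched as follows, assuming a standard transformers tokenizer with padding and truncation to a fixed length (the exact call in the script may differ):

```python
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    # Fall back to the tokenizer's own limit (77 for CLIP-style
    # tokenizers) when no explicit tokenizer_max_length is given.
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
```

Padding every prompt to the same length is what lets collate_fn concatenate token IDs with a plain torch.cat.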
Outputs (__getitem__)
| Name | Type | Description |
|---|---|---|
| instance_images | Tensor [C, H, W] | Preprocessed instance image, normalized to [-1, 1]. |
| instance_prompt_ids | Tensor [seq_len] | Tokenized instance prompt IDs (or pre-computed hidden states). |
| instance_attention_mask | Tensor [seq_len] | Attention mask for instance prompt tokens (when not using pre-computed embeddings). |
| class_images | Tensor [C, H, W] | Preprocessed class image (only when class_data_root is provided). |
| class_prompt_ids | Tensor [seq_len] | Tokenized class prompt IDs (only when class data is provided). |
| class_attention_mask | Tensor [seq_len] | Attention mask for class prompt tokens. |
Outputs (collate_fn)
| Name | Type | Description |
|---|---|---|
| pixel_values | Tensor [B, C, H, W] | Stacked pixel values. When with_prior_preservation=True, shape is [2*batch_size, C, H, W] (instance + class concatenated). |
| input_ids | Tensor [B, seq_len] | Concatenated token IDs. When with_prior_preservation=True, shape is [2*batch_size, seq_len]. |
| attention_mask | List[Tensor] | Attention masks (only present when the text encoder uses attention masks). |
Usage Examples
Example 1: Dataset Without Prior Preservation
```python
dataset = DreamBoothDataset(
    instance_data_root="./my_subject",
    instance_prompt="a photo of sks dog",
    tokenizer=tokenizer,
    size=512,
)
print(f"Dataset length: {len(dataset)}")    # equals num_instance_images

sample = dataset[0]
print(sample["instance_images"].shape)      # torch.Size([3, 512, 512])
print(sample["instance_prompt_ids"].shape)  # torch.Size([1, 77])
```
Example 2: Dataset With Prior Preservation and Collation
```python
dataset = DreamBoothDataset(
    instance_data_root="./my_subject",
    instance_prompt="a photo of sks dog",
    tokenizer=tokenizer,
    class_data_root="./dog_class",
    class_prompt="a photo of dog",
    class_num=200,
    size=512,
    center_crop=True,
)
print(f"Dataset length: {len(dataset)}")  # max(num_instance, num_class)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=lambda ex: collate_fn(ex, with_prior_preservation=True),
)
batch = next(iter(dataloader))
print(batch["pixel_values"].shape)  # torch.Size([8, 3, 512, 512]) -- 4 instance + 4 class
print(batch["input_ids"].shape)     # torch.Size([8, 77])
```