
Implementation:Huggingface Diffusers DreamBooth Dataset Class

From Leeroopedia
Metadata
Last Updated 2026-02-13 00:00 GMT

Overview

The DreamBoothDataset class and its companion collate_fn function implement the paired instance-class data loading pipeline for DreamBooth training. The dataset handles image loading, EXIF-aware preprocessing, prompt tokenization, and instance-class pairing, while the collate function concatenates instance and class samples for efficient single-pass training.

Description

The DreamBoothDataset extends PyTorch's Dataset class. On initialization, it:

  1. Enumerates all image files in instance_data_root.
  2. Optionally enumerates class images in class_data_root.
  3. Sets the dataset length to max(num_instance_images, num_class_images).
  4. Builds an image transform pipeline: Resize, CenterCrop/RandomCrop, ToTensor, Normalize.

Each __getitem__ call returns a dictionary containing:

  • instance_images -- Preprocessed instance image tensor.
  • instance_prompt_ids -- Tokenized instance prompt (or pre-computed encoder hidden states).
  • class_images -- Preprocessed class image tensor (when class data is provided).
  • class_prompt_ids -- Tokenized class prompt (when class data is provided).
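Because the dataset length is max(num_instance_images, num_class_images), each index is wrapped with a modulo so the shorter image collection is cycled rather than exhausted. A hypothetical helper illustrating just that indexing behavior:

```python
# Hypothetical sketch (not the library code): index wrap-around used when
# instance and class image counts differ.
def pick_paths(index, instance_paths, class_paths=None):
    instance = instance_paths[index % len(instance_paths)]
    class_image = class_paths[index % len(class_paths)] if class_paths else None
    return instance, class_image
```

For example, with 3 instance images and 200 class images, index 5 selects instance image 5 % 3 = 2 and class image 5.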

The companion collate_fn handles batching by stacking pixel values and concatenating token IDs. When with_prior_preservation=True, class examples are appended after instance examples in the batch, doubling the effective batch size.
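A condensed sketch of that collation logic follows; the dictionary keys assume the sample format produced by DreamBoothDataset.__getitem__ as described above.

```python
import torch

# Condensed sketch of the collate behavior described above.
def collate_fn(examples, with_prior_preservation=False):
    input_ids = [ex["instance_prompt_ids"] for ex in examples]
    pixel_values = [ex["instance_images"] for ex in examples]

    if with_prior_preservation:
        # Class samples are appended after the instance samples,
        # doubling the effective batch size.
        input_ids += [ex["class_prompt_ids"] for ex in examples]
        pixel_values += [ex["class_images"] for ex in examples]

    return {
        "pixel_values": torch.stack(pixel_values).float(),
        "input_ids": torch.cat(input_ids, dim=0),
    }
```

Keeping the instance and class halves in a fixed order lets the training loop later split the model prediction back into the two halves with a single chunk along the batch dimension.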

Usage

from torch.utils.data import DataLoader

train_dataset = DreamBoothDataset(
    instance_data_root="./my_dog",
    instance_prompt="a photo of sks dog",
    tokenizer=tokenizer,
    class_data_root="./dog_class",
    class_prompt="a photo of dog",
    class_num=200,
    size=512,
    center_crop=False,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, with_prior_preservation=True),
    num_workers=0,
)

Code Reference

Source Location

  • Repository: huggingface/diffusers
  • File: examples/dreambooth/train_dreambooth_lora.py (lines 568--701)

Signature

class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        ...

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index) -> dict:
        ...


def collate_fn(examples, with_prior_preservation=False) -> dict:
    """
    Collate function that concatenates instance and class examples
    for prior preservation training.
    """
    ...

Import

from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from PIL.ImageOps import exif_transpose

I/O Contract

Inputs (DreamBoothDataset.__init__)

Input Contract

  • instance_data_root (Path, required) -- Path to the directory containing instance (subject) images.
  • instance_prompt (str, required) -- Text prompt with the identifier token for instance images.
  • tokenizer (PreTrainedTokenizer, required) -- Tokenizer for encoding text prompts into token IDs.
  • class_data_root (Path, default None) -- Path to the directory containing class-prior images.
  • class_prompt (str, default None) -- Generic class prompt for class-prior images.
  • class_num (int, default None) -- Maximum number of class images to use; if None, all available images are used.
  • size (int, default 512) -- Target resolution for image preprocessing (resize and crop).
  • center_crop (bool, default False) -- Use center crop when True, random crop when False.
  • encoder_hidden_states (Tensor, default None) -- Pre-computed text encoder hidden states for the instance prompt (bypasses tokenization).
  • class_prompt_encoder_hidden_states (Tensor, default None) -- Pre-computed text encoder hidden states for the class prompt.
  • tokenizer_max_length (int, default None) -- Maximum tokenizer sequence length; defaults to the tokenizer's own maximum.
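The interaction between tokenizer_max_length and the tokenizer's own limit can be sketched as follows. This is a hedged illustration, not the script's exact helper; the keyword arguments are the standard Hugging Face tokenizer call signature.

```python
# Sketch: tokenizer_max_length overrides the tokenizer's model_max_length
# when provided; otherwise the tokenizer's own maximum applies.
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length
    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
```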

Outputs (__getitem__)

Output Contract

  • instance_images (Tensor [C, H, W]) -- Preprocessed instance image, normalized to [-1, 1].
  • instance_prompt_ids (Tensor [seq_len]) -- Tokenized instance prompt IDs (or pre-computed hidden states).
  • instance_attention_mask (Tensor [seq_len]) -- Attention mask for instance prompt tokens (when not using pre-computed embeddings).
  • class_images (Tensor [C, H, W]) -- Preprocessed class image (only when class_data_root is provided).
  • class_prompt_ids (Tensor [seq_len]) -- Tokenized class prompt IDs (only when class data is provided).
  • class_attention_mask (Tensor [seq_len]) -- Attention mask for class prompt tokens.

Outputs (collate_fn)

Collated Batch Contract

  • pixel_values (Tensor [B, C, H, W]) -- Stacked pixel values. When with_prior_preservation=True, the shape is [2*batch_size, C, H, W] (instance + class concatenated).
  • input_ids (Tensor [B, seq_len]) -- Concatenated token IDs. When with_prior_preservation=True, the shape is [2*batch_size, seq_len].
  • attention_mask (Tensor [B, seq_len]) -- Concatenated attention masks (present only when the text encoder uses attention masks).

Usage Examples

Example 1: Dataset Without Prior Preservation

dataset = DreamBoothDataset(
    instance_data_root="./my_subject",
    instance_prompt="a photo of sks dog",
    tokenizer=tokenizer,
    size=512,
)
print(f"Dataset length: {len(dataset)}")  # equals num_instance_images

sample = dataset[0]
print(sample["instance_images"].shape)      # torch.Size([3, 512, 512])
print(sample["instance_prompt_ids"].shape)  # torch.Size([1, 77])

Example 2: Dataset With Prior Preservation and Collation

dataset = DreamBoothDataset(
    instance_data_root="./my_subject",
    instance_prompt="a photo of sks dog",
    tokenizer=tokenizer,
    class_data_root="./dog_class",
    class_prompt="a photo of dog",
    class_num=200,
    size=512,
    center_crop=True,
)
print(f"Dataset length: {len(dataset)}")  # max(num_instance, num_class)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=lambda ex: collate_fn(ex, with_prior_preservation=True),
)

batch = next(iter(dataloader))
print(batch["pixel_values"].shape)  # torch.Size([8, 3, 512, 512]) -- 4 instance + 4 class
print(batch["input_ids"].shape)     # torch.Size([8, 77])
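Downstream, the fixed instance-then-class ordering in the collated batch is what allows the training loop to separate the two loss terms. The sketch below illustrates that split; dreambooth_loss and prior_loss_weight are hypothetical names mirroring the script's prior-preservation flag, not the script verbatim.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch: the prediction over the doubled batch is chunked back
# into its instance and class halves, and the class half contributes a
# weighted prior-preservation loss.
def dreambooth_loss(model_pred, target, prior_loss_weight=1.0):
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    return loss + prior_loss_weight * prior_loss
```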

Related Pages

Requires Environment
