

Implementation:Junyanz Pytorch CycleGAN and pix2pix Create Dataset

From Leeroopedia
Revision as of 15:20, 16 February 2026 by Admin (auto-imported from implementations/Junyanz_Pytorch_CycleGAN_and_pix2pix_Create_Dataset.md)



Overview

A concrete tool, provided by the pytorch-CycleGAN-and-pix2pix framework, for dynamically loading image datasets and wrapping them in a PyTorch DataLoader that loads batches in parallel worker processes. The implementation lives in data/__init__.py and exposes the create_dataset() factory function along with the CustomDatasetDataLoader wrapper class.

Code Reference

Source file: data/__init__.py (create_dataset and CustomDatasetDataLoader at lines 50--107; helper find_dataset_using_name at lines 22--41)

Import:

from data import create_dataset

find_dataset_using_name (L22--41)

def find_dataset_using_name(dataset_name):

Dynamically imports the module data.{dataset_name}_dataset using importlib.import_module, then scans the module's namespace for a class whose lowercased name equals {dataset_name}dataset (with any underscores removed from dataset_name, so a mode named my_mode matches a class named MyModeDataset) and that is a subclass of BaseDataset. Raises NotImplementedError if no matching class is found.
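The matching rule can be sketched in isolation. This is a minimal, self-contained approximation, not the repository's code: the helper find_class_in_module is hypothetical, it takes an already-imported module instead of calling importlib.import_module, BaseDataset is a local stand-in for data.base_dataset.BaseDataset, and the extra isinstance(obj, type) guard is an addition of this sketch.

```python
import types


class BaseDataset:
    """Stand-in for data.base_dataset.BaseDataset (only used as a marker base class here)."""


def find_class_in_module(module, dataset_name):
    """Return the BaseDataset subclass in `module` whose name matches `dataset_name`.

    Hypothetical helper: mirrors the matching rule described above, but takes an
    already-imported module instead of calling importlib.import_module itself.
    """
    target = dataset_name.replace("_", "") + "dataset"
    found = None
    for name, obj in vars(module).items():
        # Case-insensitive name match, restricted to classes derived from BaseDataset.
        if name.lower() == target.lower() and isinstance(obj, type) and issubclass(obj, BaseDataset):
            found = obj
    if found is None:
        raise NotImplementedError(
            f"In {module.__name__}.py, there should be a subclass of BaseDataset "
            f"with class name that matches {target} in lowercase."
        )
    return found


# Usage: build an in-memory module instead of importing a real data/*_dataset.py file.
fake = types.ModuleType("data.unaligned_dataset")


class UnalignedDataset(BaseDataset):
    pass


class MyModeDataset(BaseDataset):
    pass


fake.BaseDataset = BaseDataset  # "basedataset" does not match these targets, so it is skipped
fake.UnalignedDataset = UnalignedDataset
fake.MyModeDataset = MyModeDataset

print(find_class_in_module(fake, "unaligned").__name__)  # UnalignedDataset
print(find_class_in_module(fake, "my_mode").__name__)    # MyModeDataset
```

Because the comparison lowercases both sides, the class name's capitalization does not matter; only its letters (and the Dataset suffix) do.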

create_dataset (L50--62)

def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
        This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset

Creates a CustomDatasetDataLoader instance and immediately calls load_data(), which returns the loader itself. This is the primary public entry point.

CustomDatasetDataLoader (L65--107)

class CustomDatasetDataLoader:
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

__init__(self, opt) (L68--88)

  1. Calls find_dataset_using_name(opt.dataset_mode) to resolve the dataset class.
  2. Instantiates the dataset class with opt.
  3. Checks for the LOCAL_RANK environment variable to determine if DDP is active.
  4. If DDP is active, creates a DistributedSampler with shuffle=not opt.serial_batches and sets the DataLoader's own shuffle to False, since DataLoader does not accept shuffle=True together with a sampler.
  5. Otherwise, sets shuffle=not opt.serial_batches directly on the DataLoader.
  6. Constructs a torch.utils.data.DataLoader with batch_size=opt.batch_size, num_workers=opt.num_threads, and the resolved sampler/shuffle configuration.
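The shuffling decision in steps 3 to 5 can be sketched without torch. The helper resolve_shuffling below is hypothetical: it returns plain settings instead of constructing a sampler or DataLoader, and it assumes that the mere presence of LOCAL_RANK in the environment signals DDP (the exact check in data/__init__.py may differ).

```python
import os
from typing import Mapping, Optional


def resolve_shuffling(serial_batches: bool, env: Optional[Mapping[str, str]] = None) -> dict:
    """Decide whether a DistributedSampler or the DataLoader itself should shuffle.

    Hypothetical helper sketching steps 3-5 above. It returns plain settings
    rather than building torch objects, and it assumes that the presence of
    LOCAL_RANK in the environment means DDP is active.
    """
    env = os.environ if env is None else env
    shuffle = not serial_batches
    if "LOCAL_RANK" in env:
        # Under DDP the sampler owns shuffling; the DataLoader must not also shuffle.
        return {"distributed_sampler": True, "sampler_shuffle": shuffle, "loader_shuffle": False}
    # Single process: no sampler, the DataLoader shuffles directly.
    return {"distributed_sampler": False, "sampler_shuffle": None, "loader_shuffle": shuffle}


print(resolve_shuffling(serial_batches=False, env={}))
print(resolve_shuffling(serial_batches=False, env={"LOCAL_RANK": "0"}))
```

In a real loader these values would feed torch.utils.data.DistributedSampler(dataset, shuffle=...) and torch.utils.data.DataLoader(dataset, batch_size=..., shuffle=..., sampler=..., num_workers=...). Keeping exactly one component responsible for shuffling is what makes the two paths equivalent from the caller's point of view.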

__len__(self) (L93--95)

def __len__(self):
    return min(len(self.dataset), self.opt.max_dataset_size)

Returns the effective dataset length in samples (not batches): the size of the underlying dataset, capped at opt.max_dataset_size.

__iter__(self) (L97--102)

def __iter__(self):
    for i, data in enumerate(self.dataloader):
        if i * self.opt.batch_size >= self.opt.max_dataset_size:
            break
        yield data

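Iterates over the underlying DataLoader and stops before batch i once i * batch_size reaches opt.max_dataset_size. Because the check is made per batch, the last yielded batch can carry the total slightly past the cap (by at most batch_size - 1 samples).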
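The interaction between __len__ and __iter__ is easiest to see with small numbers. In this sketch, capped_batches reuses the stopping rule shown above on any iterable, while make_batches is a hypothetical stand-in for a DataLoader with drop_last=False (no torch required).

```python
import math


def capped_batches(batches, batch_size, max_dataset_size):
    """Apply the same stopping rule as CustomDatasetDataLoader.__iter__ to any iterable of batches."""
    for i, data in enumerate(batches):
        if i * batch_size >= max_dataset_size:
            break
        yield data


def make_batches(n_samples, batch_size):
    """Hypothetical stand-in for a DataLoader with drop_last=False: lists of consecutive indices."""
    return [list(range(s, min(s + batch_size, n_samples))) for s in range(0, n_samples, batch_size)]


n, b, cap = 20, 4, 10
out = list(capped_batches(make_batches(n, b), b, cap))
print(len(out), sum(len(batch) for batch in out))  # 3 12 -> two samples past the cap
print(min(n, cap))                                 # 10, the value __len__ reports
print(math.ceil(min(n, cap) / b))                  # 3, batches per epoch
```

With the default max_dataset_size of float("inf"), the condition i * batch_size >= max_dataset_size never holds, so every batch from the DataLoader is yielded.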

set_epoch(self, epoch) (L104--107)

def set_epoch(self, epoch):
    if self.sampler is not None:
        self.sampler.set_epoch(epoch)

Forwards the epoch number to the DistributedSampler, which derives its shuffle order from its seed plus the epoch; without this call, every epoch in DDP training reuses the same sample order. This is a no-op when no sampler is present.
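Why set_epoch matters can be shown with a pure-Python analogue of DistributedSampler's index selection. This is an approximation for illustration only: shard_for_epoch is a hypothetical helper that uses Python's random module, so its permutations do not match PyTorch's torch.randperm output. It only mirrors the structure: shuffle with seed + epoch, pad to a multiple of num_replicas, then hand each rank every num_replicas-th index.

```python
import math
import random


def shard_for_epoch(n, num_replicas, rank, epoch, seed=0):
    """Pure-Python analogue of DistributedSampler index selection (not PyTorch's exact
    permutation): shuffle with seed + epoch, pad to a multiple of num_replicas by
    repeating leading indices, then give each rank every num_replicas-th index."""
    indices = list(range(n))
    random.Random(seed + epoch).shuffle(indices)
    total = math.ceil(n / num_replicas) * num_replicas
    indices += indices[: total - n]
    return indices[rank:total:num_replicas]


# No set_epoch call: epoch stays 0, so each epoch repeats the same order.
print(shard_for_epoch(10, 2, 0, epoch=0) == shard_for_epoch(10, 2, 0, epoch=0))  # True
# Calling set_epoch(e) changes the seed, so the order differs between epochs.
print(shard_for_epoch(10, 2, 0, epoch=0))
print(shard_for_epoch(10, 2, 0, epoch=1))
```

Every rank computes the same permutation from the same seed and epoch, which is why the shards stay disjoint without any communication between processes.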

I/O Contract

Inputs

The opt parameter is an argparse.Namespace (or equivalent) object. The following fields are read by the factory or by the dataset classes it instantiates:

Field             Type   Description
dataset_mode      str    Name of the dataset mode (e.g., "unaligned", "aligned", "single", "colorization").
dataroot          str    Root directory containing the image data.
batch_size        int    Number of samples per mini-batch.
num_threads       int    Number of DataLoader worker processes (passed as num_workers).
serial_batches    bool   If True, samples are taken in order; otherwise they are shuffled.
max_dataset_size  int    Maximum number of samples to use from the dataset (defaults to float("inf"), i.e. no cap).
phase             str    Current phase ("train", "test", "val"), used by dataset classes to locate subdirectories.
load_size         int    Scale images to this size before cropping.
crop_size         int    Crop images to this size after scaling.
preprocess        str    Preprocessing strategy (e.g., "resize_and_crop", "scale_width_and_crop", "none").
no_flip           bool   If True, disables random horizontal flipping.
input_nc          int    Number of input image channels.
output_nc         int    Number of output image channels.
direction         str    Translation direction ("AtoB" or "BtoA").
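For notebooks or unit tests, an options object can be assembled by hand instead of being parsed from the command line. This is a hypothetical sketch: the values are illustrative rather than authoritative defaults, REQUIRED_FIELDS simply restates the table above, and a given dataset class may read further fields not listed there.

```python
from argparse import Namespace

# Fields from the table above; a dataset class may also read fields not listed here.
REQUIRED_FIELDS = [
    "dataset_mode", "dataroot", "batch_size", "num_threads", "serial_batches",
    "max_dataset_size", "phase", "load_size", "crop_size", "preprocess",
    "no_flip", "input_nc", "output_nc", "direction",
]

# Hypothetical hand-built options (illustrative values, not authoritative defaults).
opt = Namespace(
    dataset_mode="unaligned",
    dataroot="./datasets/horse2zebra",
    batch_size=1,
    num_threads=4,
    serial_batches=False,
    max_dataset_size=float("inf"),
    phase="train",
    load_size=286,
    crop_size=256,
    preprocess="resize_and_crop",
    no_flip=False,
    input_nc=3,
    output_nc=3,
    direction="AtoB",
)

missing = [name for name in REQUIRED_FIELDS if not hasattr(opt, name)]
print(missing)  # []
```

Inside the repository, such an object would then be passed to create_dataset(opt) exactly like a parsed one.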

Outputs

create_dataset(opt) returns a CustomDatasetDataLoader instance that is:

  • Iterable: Each iteration yields a dict whose keys depend on the dataset mode. For example:
    • unaligned: {'A': tensor, 'B': tensor, 'A_paths': list, 'B_paths': list}
    • aligned: {'A': tensor, 'B': tensor, 'A_paths': list, 'B_paths': list}
    • single: {'A': tensor, 'A_paths': list}
    • colorization: {'A': tensor, 'B': tensor, 'A_paths': list, 'B_paths': list}, where A holds the L channel and B the ab channels of a Lab image
  • Length-aware: len(dataset) returns min(len(underlying_dataset), opt.max_dataset_size).
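The path entries are lists because the DataLoader collates per-sample dicts key by key: tensors are stacked into a batch tensor, while strings are gathered into a list of strings. The helper below is a hypothetical, simplified stand-in for that collation that groups every value into a list (it does not stack tensors), with tensors replaced by plain lists.

```python
def collate_dicts(samples):
    """Hypothetical, simplified stand-in for default collation of dict samples:
    group each key's values into a list. (PyTorch's default_collate additionally
    stacks tensors into one batch tensor; strings stay a list of strings.)"""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


# Two per-sample dicts as an unaligned dataset might return them (tensors replaced by lists).
samples = [
    {"A": [0.1], "B": [0.2], "A_paths": "trainA/1.jpg", "B_paths": "trainB/7.jpg"},
    {"A": [0.3], "B": [0.4], "A_paths": "trainA/2.jpg", "B_paths": "trainB/3.jpg"},
]
batch = collate_dicts(samples)
print(batch["A_paths"])  # ['trainA/1.jpg', 'trainA/2.jpg']
```

As a result, len(data['A_paths']) equals the size of the current batch, which is useful when saving outputs under their source file names.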

Usage Examples

CycleGAN (unaligned mode)

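from data import create_dataset
from options.train_options import TrainOptions

# Parse options from the command line, e.g.
#   python train.py --dataroot ./datasets/horse2zebra --model cycle_gan
# Defaults give opt.dataset_mode = 'unaligned', opt.batch_size = 1, opt.num_threads = 4
opt = TrainOptions().parse()
dataset = create_dataset(opt)
for i, data in enumerate(dataset):
    real_A = data['A']   # batch of images from domain A
    real_B = data['B']   # batch of images from domain B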
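In unaligned mode the two domains are not paired by file name, so each sample has to choose a B image for its A image. The sketch below shows one way to do this and is assumed, not verified line by line, to mirror the repository's UnalignedDataset: A is indexed cyclically, B is indexed cyclically when serial_batches is set and drawn at random otherwise, and the dataset length is the size of the larger domain. The helper unaligned_pair and the file names are hypothetical.

```python
import random


def unaligned_pair(index, A_paths, B_paths, serial_batches, rng=random):
    """Sketch of unaligned pairing: A is indexed cyclically; B is indexed cyclically
    in serial mode, otherwise drawn at random so that pairs are not fixed."""
    A_path = A_paths[index % len(A_paths)]
    if serial_batches:
        B_path = B_paths[index % len(B_paths)]
    else:
        B_path = B_paths[rng.randint(0, len(B_paths) - 1)]
    return {"A_paths": A_path, "B_paths": B_path}


A = ["a0.jpg", "a1.jpg", "a2.jpg"]
B = ["b0.jpg", "b1.jpg"]
# Length is the larger domain, so every image in both domains can be visited.
length = max(len(A), len(B))
print([unaligned_pair(i, A, B, serial_batches=True)["B_paths"] for i in range(length)])
# ['b0.jpg', 'b1.jpg', 'b0.jpg']
```

Random B selection keeps the generator from memorizing fixed A/B pairs, which fits the unpaired setting CycleGAN is designed for.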

pix2pix (aligned mode)

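from data import create_dataset
from options.train_options import TrainOptions

# e.g. python train.py --dataroot ./datasets/facades --model pix2pix --batch_size 4
# The pix2pix model sets opt.dataset_mode = 'aligned'; opt.num_threads defaults to 4
opt = TrainOptions().parse()
dataset = create_dataset(opt)
for i, data in enumerate(dataset):
    image_A = data['A']   # left half of each AB image (the input when direction='AtoB')
    image_B = data['B']   # right half of each AB image (the target when direction='AtoB')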
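Since the aligned dataset always returns the left half as 'A' and the right half as 'B', the model has to map that batch to (input, target) according to opt.direction. The helper select_pair below is a hypothetical sketch of that convention, assumed to match what the pix2pix model does when it receives a batch.

```python
def select_pair(data, direction):
    """Map an aligned batch to (input, target) according to opt.direction.

    Hypothetical sketch: the dataset always returns the left half as 'A' and
    the right half as 'B'; 'BtoA' simply swaps the two roles.
    """
    if direction not in ("AtoB", "BtoA"):
        raise ValueError(f"unknown direction: {direction!r}")
    a_to_b = direction == "AtoB"
    real_A = data["A" if a_to_b else "B"]
    real_B = data["B" if a_to_b else "A"]
    return real_A, real_B


batch = {"A": "left half", "B": "right half"}
print(select_pair(batch, "BtoA"))  # ('right half', 'left half')
```

Keeping the swap in one place means the same AB image files serve both translation directions.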

DDP Training (multi-GPU)

When launched via torchrun, the factory automatically detects DDP and attaches a DistributedSampler. Call set_epoch at the start of each epoch:

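from data import create_dataset
from options.train_options import TrainOptions

opt = TrainOptions().parse()
dataset = create_dataset(opt)
# Epoch range as used by the repository's train.py
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
    dataset.set_epoch(epoch)  # reseeds the DistributedSampler; no-op without DDP
    for i, data in enumerate(dataset):
        pass  # training step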
