Principle:Junyanz Pytorch CycleGAN and pix2pix Dataset Factory Loading

From Leeroopedia
Revision as of 18:07, 16 February 2026 by Admin (talk | contribs) (Auto-imported from principles/Junyanz_Pytorch_CycleGAN_and_pix2pix_Dataset_Factory_Loading.md)


Overview

A factory pattern that dynamically discovers, instantiates, and wraps dataset classes into multi-worker data loaders based on string configuration.

Description

The framework uses importlib to dynamically load dataset classes by name at runtime. Given a string value in opt.dataset_mode (e.g., "unaligned", "aligned", "single"), the factory performs the following steps:

  1. Constructs the module path data.{dataset_mode}_dataset and imports it via importlib.import_module.
  2. Scans the imported module's namespace for a class whose name, compared case-insensitively, equals {dataset_mode}dataset and which is a subclass of BaseDataset.
  3. Instantiates the discovered class with the opt namespace, which carries all experiment configuration flags (e.g., dataroot, load_size, crop_size, preprocess).
  4. Wraps the instantiated dataset in a CustomDatasetDataLoader, which creates a PyTorch DataLoader with configurable batch_size, num_workers, and optional shuffling.
  5. When running under Distributed Data Parallel (DDP), detected via the LOCAL_RANK environment variable, a DistributedSampler is attached to the DataLoader to partition data across processes.

The public entry point is the create_dataset(opt) function, which returns a CustomDatasetDataLoader instance that is iterable and length-aware.

Usage

This factory is invoked at the start of every training or testing run to load the dataset format appropriate to the selected model. In train.py and test.py, the single call dataset = create_dataset(opt) handles all dataset discovery, construction, and DataLoader wrapping. The caller needs no further setup, apart from the per-epoch set_epoch call required under DDP (see DistributedSampler for Multi-GPU Training).
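From the caller's side, the returned object only has to be iterable and support len(). The following is a minimal sketch of that contract, not the framework's code: ToyLoader, the num_samples field, and this create_dataset are hypothetical stand-ins, and the choice to report the sample count from len() is an assumption of the sketch.

```python
from types import SimpleNamespace


class ToyLoader:
    """Hypothetical stand-in for CustomDatasetDataLoader: iterable and length-aware."""

    def __init__(self, opt):
        self.samples = list(range(opt.num_samples))  # pretend dataset items
        self.batch_size = opt.batch_size

    def __len__(self):
        # Assumption of this sketch: report the number of samples, not batches.
        return len(self.samples)

    def __iter__(self):
        # Yield consecutive batches; the last one may be shorter.
        for start in range(0, len(self.samples), self.batch_size):
            yield self.samples[start:start + self.batch_size]


def create_dataset(opt):
    """Stand-in entry point; the real factory also resolves opt.dataset_mode."""
    return ToyLoader(opt)


opt = SimpleNamespace(dataset_mode="unaligned", batch_size=4, num_samples=10)
dataset = create_dataset(opt)            # the single call the script makes
batches = [batch for batch in dataset]   # a training loop would feed the model here
print(f"{len(dataset)} samples in {len(batches)} batches")
```

The loop body never names a concrete dataset class, which is the point of routing everything through one factory call.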

Theoretical Basis

Factory Pattern

The Dataset Factory Loading principle applies the factory method pattern to decouple the training loop from concrete dataset implementations. The caller never directly imports or references a specific dataset class. Instead, it passes a string identifier (opt.dataset_mode) to the factory, which resolves and returns the correct implementation. This allows new dataset formats to be added by simply creating a new Python file with the correct naming convention, without modifying any existing code.
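The extensibility claim can be illustrated with a small, self-contained experiment. The sketch below writes a throwaway package to a temporary directory, drops in one new file that follows the {mode}_dataset.py naming convention, and resolves it by string. The package name plugin_sketch_data, the video mode, and the resolve helper are all hypothetical; they stand in for the framework's data package and its lookup function.

```python
import importlib
import sys
import tempfile
from pathlib import Path

PACKAGE = "plugin_sketch_data"  # hypothetical package, standing in for data/

BASE_SOURCE = '''
class BaseDataset:
    def __init__(self, opt):
        self.opt = opt
'''

# The "new dataset format": one file named <mode>_dataset.py with a matching class.
VIDEO_SOURCE = f'''
from {PACKAGE}.base_dataset import BaseDataset

class VideoDataset(BaseDataset):
    def __len__(self):
        return 0
'''


def resolve(mode):
    """Resolve a mode string to a class purely by naming convention."""
    module = importlib.import_module(f"{PACKAGE}.{mode}_dataset")
    base = importlib.import_module(f"{PACKAGE}.base_dataset").BaseDataset
    target = f"{mode}dataset".lower()
    for name, obj in vars(module).items():
        if name.lower() == target and isinstance(obj, type) and issubclass(obj, base):
            return obj
    raise NotImplementedError(f"no BaseDataset subclass matching {target!r}")


# Build the package on disk in a temporary directory and make it importable.
root = Path(tempfile.mkdtemp())
package_dir = root / PACKAGE
package_dir.mkdir()
(package_dir / "__init__.py").write_text("")
(package_dir / "base_dataset.py").write_text(BASE_SOURCE)
(package_dir / "video_dataset.py").write_text(VIDEO_SOURCE)
sys.path.insert(0, str(root))
importlib.invalidate_caches()

VideoDataset = resolve("video")  # resolve() itself was not edited for the new mode
print(VideoDataset.__name__)
```

Adding the file was the only step needed; the lookup code stayed untouched.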

Dataset Modes

The framework ships with four dataset modes, each corresponding to a BaseDataset subclass:

Mode | Class | Model | Description
unaligned | UnalignedDataset | CycleGAN | Loads unpaired images from two separate directories (trainA/, trainB/). Images from domain A and domain B are sampled independently.
aligned | AlignedDataset | pix2pix | Loads paired images where each image file contains both the input and target side-by-side, split at the midpoint.
single | SingleDataset | Test / Inference | Loads images from a single directory for one-sided generation (e.g., applying a trained CycleGAN to new images).
colorization | ColorizationDataset | pix2pix (colorization) | Loads RGB images and converts them to Lab color space, yielding (L, ab) pairs for colorization training.
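Two of the behaviors in the table are easy to picture in code: the midpoint split used for aligned data and the independent sampling used for unaligned data. The sketch below uses plain Python lists instead of image objects, and the helper names and file names are hypothetical. It illustrates the idea only and is not taken from the dataset classes themselves.

```python
import random


def split_aligned(rows):
    """Split a side-by-side image (a list of pixel rows) at the horizontal midpoint."""
    mid = len(rows[0]) // 2
    left = [row[:mid] for row in rows]
    right = [row[mid:] for row in rows]
    return left, right


def pick_unaligned_pair(index, paths_a, paths_b, rng):
    """Pick A by index and B at random, so the two domains are not paired."""
    path_a = paths_a[index % len(paths_a)]
    path_b = paths_b[rng.randrange(len(paths_b))]
    return path_a, path_b


combined = [[1, 2, 3, 4],
            [5, 6, 7, 8]]  # 2 rows, 4 columns: two 2x2 images side by side
left, right = split_aligned(combined)

pair = pick_unaligned_pair(
    5,
    ["a0.jpg", "a1.jpg", "a2.jpg"],   # hypothetical trainA/ files
    ["b0.jpg", "b1.jpg"],             # hypothetical trainB/ files
    random.Random(0),
)
```

Wrapping the A index with a modulo lets the two directories hold different numbers of images.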

Dynamic Import via importlib

The find_dataset_using_name function uses Python's importlib.import_module to load the module data.{name}_dataset at runtime. It then performs a case-insensitive scan of the module's __dict__ to locate the target class, verifying it is a subclass of BaseDataset. This avoids maintaining a static registry and enables plug-in extensibility.
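A minimal sketch of this lookup follows. It is written from the description above and is not the framework's source. To stay runnable without files on disk, it registers in-memory modules under a hypothetical package named sketch_data; the package parameter, the toy and fake modes, and the minimal BaseDataset are assumptions of the example.

```python
import importlib
import sys
import types


class BaseDataset:
    """Minimal stand-in for the framework's BaseDataset."""

    def __init__(self, opt):
        self.opt = opt


def find_dataset_using_name(dataset_name, package="sketch_data"):
    """Import <package>.<dataset_name>_dataset and return the matching class.

    `package` is added for this sketch; the text above uses the `data` package.
    """
    module = importlib.import_module(f"{package}.{dataset_name}_dataset")
    target = f"{dataset_name}dataset".lower()
    for name, obj in module.__dict__.items():
        # Case-insensitive name match plus a subclass check.
        if name.lower() == target and isinstance(obj, type) and issubclass(obj, BaseDataset):
            return obj
    raise NotImplementedError(
        f"{module.__name__} defines no BaseDataset subclass named like {target!r}"
    )


class ToyDataset(BaseDataset):
    pass


# Register in-memory modules so import_module can find them via sys.modules.
package = types.ModuleType("sketch_data")
package.__path__ = []  # marks the module as a package
good = types.ModuleType("sketch_data.toy_dataset")
good.ToyDataset = ToyDataset
fake = types.ModuleType("sketch_data.fake_dataset")
fake.FakeDataset = type("FakeDataset", (), {})  # right name, wrong base class
sys.modules["sketch_data"] = package
sys.modules["sketch_data.toy_dataset"] = good
sys.modules["sketch_data.fake_dataset"] = fake

dataset_class = find_dataset_using_name("toy")
dataset = dataset_class(opt={"dataroot": "./datasets/example"})
```

The subclass check is what rejects FakeDataset here: a class can carry the expected name and still be refused if it does not derive from BaseDataset.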

DataLoader Configuration

The CustomDatasetDataLoader constructs a torch.utils.data.DataLoader with the following parameters derived from opt:

  • batch_size: Number of samples per batch (opt.batch_size).
  • shuffle: Enabled by default unless opt.serial_batches is set or DDP is active (in which case the sampler handles shuffling).
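  • num_workers: Number of data-loading worker subprocesses (opt.num_threads). Despite the option's name, PyTorch's DataLoader runs its workers as separate processes, not threads.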
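The mapping from opt to loader arguments can be sketched without importing PyTorch. The helper below is hypothetical (loader_settings is not a framework function) and only encodes the three rules listed above. The environ parameter is added so the DDP case can be exercised without changing the real environment.

```python
import os
from types import SimpleNamespace


def loader_settings(opt, environ=None):
    """Derive DataLoader keyword arguments from opt, following the rules above."""
    environ = os.environ if environ is None else environ
    distributed = "LOCAL_RANK" in environ
    return {
        "batch_size": opt.batch_size,
        # Under DDP the sampler owns the ordering, so the loader must not shuffle.
        "shuffle": (not opt.serial_batches) and not distributed,
        "num_workers": int(opt.num_threads),
    }


opt = SimpleNamespace(batch_size=1, serial_batches=False, num_threads=4)
single_process = loader_settings(opt, environ={})
ddp = loader_settings(opt, environ={"LOCAL_RANK": "0"})
# With PyTorch installed, a dict like this could be unpacked into
# torch.utils.data.DataLoader(dataset, **settings), plus sampler=... under DDP.
```

Turning shuffle off under DDP is not optional: torch.utils.data.DataLoader rejects shuffle=True when an explicit sampler is also supplied.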

DistributedSampler for Multi-GPU Training

When the LOCAL_RANK environment variable is present (indicating a DDP launch via torchrun), the loader creates a DistributedSampler. This sampler partitions the dataset across all participating processes so that each GPU works on its own shard; when the dataset size is not divisible by the number of processes, PyTorch's default behavior pads the shards with a few repeated samples so they are equal in length. The set_epoch(epoch) method must be called at the start of each epoch to re-seed the sampler's shuffling; otherwise every epoch reuses the same ordering.
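To see why set_epoch matters, it helps to look at how such a sampler can derive its indices. The function below is a simplified pure-Python stand-in, not the code of torch.utils.data.DistributedSampler. The name shard_indices and its parameters are hypothetical, and it uses Python's random module instead of a torch generator.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch, seed=0, shuffle=True):
    """Return the indices one process would read in a given epoch (simplified)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Seeding with seed + epoch gives every process the same permutation
        # within an epoch, and a new permutation when the epoch changes.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices so every process receives the same count.
    per_replica = math.ceil(dataset_len / num_replicas)
    total = per_replica * num_replicas
    repeats = math.ceil(total / dataset_len)
    indices = (indices * repeats)[:total]
    # Strided slice: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total:num_replicas]


# Ten samples over four processes: each shard has three indices, two are repeats.
shards = [shard_indices(10, num_replicas=4, rank=r, epoch=0) for r in range(4)]
```

Because the permutation depends only on seed + epoch, all ranks agree on it without communicating, and an epoch counter that is never advanced reproduces the same shards every time.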
