Implementation:Junyanz Pytorch CycleGAN and pix2pix Create Dataset
Overview
Concrete tool, provided by the pytorch-CycleGAN-and-pix2pix framework, that resolves an image dataset class by name at runtime and wraps it in a PyTorch DataLoader with parallel worker processes. The implementation lives in data/__init__.py and exposes the create_dataset() factory function along with the CustomDatasetDataLoader wrapper class.
Code Reference
Source file: data/__init__.py (lines 50--107)
Import:
```python
from data import create_dataset
```
find_dataset_using_name (L22--41)
```python
def find_dataset_using_name(dataset_name):
```
Dynamically imports the module data.{dataset_name}_dataset using importlib.import_module, then scans the module's namespace for a class whose lowercased name equals {dataset_name}dataset (with any underscores removed from dataset_name) and that is a subclass of BaseDataset. Raises NotImplementedError if no matching class is found. For example, dataset_mode "unaligned" resolves to the UnalignedDataset class in data/unaligned_dataset.py.
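The lookup rule can be reproduced without the rest of the package. The following is a minimal sketch, not the repository code: BaseDataset is a local stand-in for data.base_dataset.BaseDataset, find_dataset_class and find_dataset_using_name_sketch are hypothetical names, and an in-memory module replaces a real data/*_dataset.py file. Like the upstream code, it drops underscores from the mode name before comparing.

```python
import importlib
import types


class BaseDataset:
    """Stand-in for data.base_dataset.BaseDataset (assumption for this sketch)."""


def find_dataset_class(module, dataset_name, base=BaseDataset):
    """Return the class in `module` named <dataset_name>Dataset, case-insensitively."""
    # "my_data" -> "mydatadataset"; the comparison is done in lowercase.
    target = (dataset_name.replace("_", "") + "dataset").lower()
    found = None
    for name, obj in vars(module).items():
        # Accept only classes that derive from `base` and match the target name.
        if name.lower() == target and isinstance(obj, type) and issubclass(obj, base):
            found = obj
    if found is None:
        raise NotImplementedError(
            f"In {module.__name__}, expected a subclass of {base.__name__} "
            f"whose lowercased name is {target!r}"
        )
    return found


def find_dataset_using_name_sketch(dataset_name, package="data"):
    """Hypothetical wrapper: import <package>.<dataset_name>_dataset, then search it."""
    module = importlib.import_module(f"{package}.{dataset_name}_dataset")
    return find_dataset_class(module, dataset_name)


# Usage with an in-memory module instead of data/unaligned_dataset.py.
fake = types.ModuleType("data.unaligned_dataset")


class UnalignedDataset(BaseDataset):
    pass


class MyDataDataset(BaseDataset):
    pass


fake.UnalignedDataset = UnalignedDataset
fake.MyDataDataset = MyDataDataset
fake.BaseDataset = BaseDataset  # imported names also live in the module namespace
print(find_dataset_class(fake, "unaligned").__name__)  # UnalignedDataset
```

Searching the module's namespace keeps the factory open for extension: adding a new mode only requires dropping a correctly named file and class into the data package.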
create_dataset (L50--62)
```python
def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset
```
Creates a CustomDatasetDataLoader instance and immediately calls load_data(), which returns the loader itself. This is the primary public entry point.
CustomDatasetDataLoader (L65--107)
```python
class CustomDatasetDataLoader:
    """Wrapper class of Dataset class that performs multi-threaded data loading"""
```
__init__(self, opt) (L68--88)
- Calls find_dataset_using_name(opt.dataset_mode) to resolve the dataset class.
- Instantiates the dataset class with opt.
- Checks for the LOCAL_RANK environment variable to determine whether DDP is active.
- If DDP is active, creates a DistributedSampler with shuffle=not opt.serial_batches and disables the DataLoader's own shuffle, because PyTorch's DataLoader does not accept shuffle=True together with a sampler.
- Otherwise, sets shuffle=not opt.serial_batches directly on the DataLoader.
- Constructs a torch.utils.data.DataLoader with batch_size=opt.batch_size, num_workers=opt.num_threads, and the resolved sampler/shuffle configuration.
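The sampler/shuffle decision can be isolated as a small pure function. This is a sketch, not the repository code: resolve_loader_config is a hypothetical helper, it assumes DDP counts as active whenever LOCAL_RANK is present, and it returns a plain dict instead of building torch objects so it runs without torch or an initialized process group.

```python
import os
from types import SimpleNamespace


def resolve_loader_config(opt, environ=None):
    """Hypothetical helper: decide DataLoader settings as the list above describes."""
    environ = os.environ if environ is None else environ
    # Assumption: DDP is considered active whenever LOCAL_RANK is set (torchrun sets it).
    ddp_active = "LOCAL_RANK" in environ
    want_shuffle = not opt.serial_batches
    return {
        "batch_size": opt.batch_size,
        "num_workers": int(opt.num_threads),
        # Under DDP the sampler owns shuffling ...
        "sampler": {"type": "DistributedSampler", "shuffle": want_shuffle} if ddp_active else None,
        # ... because DataLoader rejects shuffle=True combined with a sampler.
        "shuffle": False if ddp_active else want_shuffle,
    }


opt = SimpleNamespace(serial_batches=False, batch_size=4, num_threads=2)
print(resolve_loader_config(opt, environ={}))
print(resolve_loader_config(opt, environ={"LOCAL_RANK": "0"}))
```

Passing environ explicitly keeps the decision testable; in real code the resulting values would feed DistributedSampler(dataset, shuffle=...) and DataLoader(dataset, batch_size=..., shuffle=..., sampler=..., num_workers=...).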
__len__(self) (L93--95)
```python
def __len__(self):
    return min(len(self.dataset), self.opt.max_dataset_size)
```
Returns the effective dataset length in samples (not batches), capped at opt.max_dataset_size.
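Because len(dataset) counts samples, the number of batches per epoch has to be derived from it. A minimal arithmetic sketch, assuming a single process (no DDP), the DataLoader's default drop_last=False, and max_dataset_size given either as an int or as float("inf") for "no cap"; effective_len and batches_per_epoch are hypothetical names.

```python
import math


def effective_len(n_samples, max_dataset_size=float("inf")):
    """What len(dataset) reports: a sample count capped at max_dataset_size."""
    return min(n_samples, max_dataset_size)


def batches_per_epoch(n_samples, batch_size, max_dataset_size=float("inf")):
    """Batches the wrapper's __iter__ yields when drop_last=False."""
    # min(ceil(N / bs), ceil(M / bs)) == ceil(min(N, M) / bs), since ceil is monotone.
    return math.ceil(effective_len(n_samples, max_dataset_size) / batch_size)


print(effective_len(1001))             # 1001
print(effective_len(1001, 500))        # 500
print(batches_per_epoch(1001, 4))      # 251
print(batches_per_epoch(1001, 4, 10))  # 3
```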
__iter__(self) (L97--102)
```python
def __iter__(self):
    for i, data in enumerate(self.dataloader):
        if i * self.opt.batch_size >= self.opt.max_dataset_size:
            break
        yield data
```
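Iterates over the underlying DataLoader and stops before batch i once i * batch_size reaches opt.max_dataset_size. Because whole batches are yielded, the final batch can push the total past max_dataset_size by at most batch_size - 1 samples.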
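The stopping rule does not depend on torch, so it can be replayed on a plain list of batches. In this sketch, capped_batches is a hypothetical name and a list of lists stands in for a DataLoader over 10 samples with batch_size=4. Since the check happens before each whole batch is yielded, the total number of samples can exceed max_dataset_size.

```python
def capped_batches(batches, batch_size, max_dataset_size):
    """Same stopping rule as CustomDatasetDataLoader.__iter__ (plain-Python sketch)."""
    for i, batch in enumerate(batches):
        # Stop once the batches already yielded cover max_dataset_size samples.
        if i * batch_size >= max_dataset_size:
            break
        yield batch


# Stand-in for a DataLoader over 10 samples with batch_size=4 (the last batch is short).
samples = list(range(10))
loader = [samples[k:k + 4] for k in range(0, len(samples), 4)]

out = list(capped_batches(loader, batch_size=4, max_dataset_size=5))
print(out)                       # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(sum(len(b) for b in out))  # 8, i.e. more than max_dataset_size=5
```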
set_epoch(self, epoch) (L104--107)
```python
def set_epoch(self, epoch):
    if self.sampler is not None:
        self.sampler.set_epoch(epoch)
```
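Forwards the epoch number to the DistributedSampler. The sampler seeds its shuffle with seed + epoch, so calling set_epoch before each epoch gives every epoch a different ordering; without it, every epoch reuses the same order. This is a no-op when no sampler is present (non-DDP runs).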
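The effect of set_epoch can be seen with a pure-Python analogue of DistributedSampler. This is a sketch, not torch's implementation: torch draws torch.randperm from a generator seeded with seed + epoch, while EpochShuffledShardSketch (a hypothetical class) uses random.Random(seed + epoch). It mimics the default drop_last=False behavior of padding by repeating indices so every rank receives the same count.

```python
import math
import random


class EpochShuffledShardSketch:
    """Pure-Python analogue of DistributedSampler's epoch-seeded shuffling."""

    def __init__(self, n_samples, num_replicas, rank, shuffle=True, seed=0):
        self.n_samples = n_samples
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        indices = list(range(self.n_samples))
        if self.shuffle:
            # Every rank computes the same permutation because all share seed + epoch.
            random.Random(self.seed + self.epoch).shuffle(indices)
        # Pad by repeating indices so the list splits evenly across replicas.
        total = math.ceil(self.n_samples / self.num_replicas) * self.num_replicas
        indices = (indices * math.ceil(total / self.n_samples))[:total]
        # Each rank keeps a strided slice of the padded permutation.
        return iter(indices[self.rank:total:self.num_replicas])


def set_epoch(sampler, epoch):
    """Mirrors CustomDatasetDataLoader.set_epoch: a no-op without a sampler."""
    if sampler is not None:
        sampler.set_epoch(epoch)


rank0 = EpochShuffledShardSketch(100, num_replicas=2, rank=0)
rank1 = EpochShuffledShardSketch(100, num_replicas=2, rank=1)
epoch0 = list(rank0)
set_epoch(rank0, 1)
set_epoch(rank1, 1)
epoch1 = list(rank0)
```

Because the permutation depends only on seed + epoch, ranks agree on it without communicating; changing the epoch changes the permutation for all ranks at once.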
I/O Contract
Inputs
The opt parameter is an argparse.Namespace (or any object exposing the same attributes). The following fields are read by the factory and by the dataset classes it instantiates:
| Field | Type | Description |
|---|---|---|
| dataset_mode | str | Name of the dataset mode (e.g., "unaligned", "aligned", "single", "colorization"). |
| dataroot | str | Root directory containing the image data. |
| batch_size | int | Number of samples per mini-batch. |
| num_threads | int | Number of DataLoader worker processes (passed as num_workers). |
| serial_batches | bool | If True, disables shuffling; in unaligned mode it also pairs B images by index instead of at random. |
| max_dataset_size | int | Maximum number of samples to use from the dataset; the default, float("inf"), means no cap. |
| phase | str | Current phase ("train", "test", "val"), used by dataset classes to locate subdirectories. |
| load_size | int | Scale images to this size before cropping. |
| crop_size | int | Crop images to this size after scaling. |
| preprocess | str | Preprocessing strategy (e.g., "resize_and_crop", "scale_width_and_crop", "none"). |
| no_flip | bool | If True, disables random horizontal flipping. |
| input_nc | int | Number of input image channels. |
| output_nc | int | Number of output image channels. |
| direction | str | Translation direction ("AtoB" or "BtoA"). |
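In train.py and test.py, opt comes from the repository's option parsers. For quick experiments, an equivalent object can be assembled by hand. The sketch below assumes that the fields in the table are sufficient; individual dataset classes may read further fields, and all values are illustrative.

```python
from argparse import Namespace

# Hypothetical minimal options for calling create_dataset() from a script.
opt = Namespace(
    dataset_mode="unaligned",
    dataroot="./datasets/horse2zebra",
    phase="train",
    batch_size=1,
    num_threads=4,
    serial_batches=False,
    max_dataset_size=float("inf"),  # no cap
    load_size=286,
    crop_size=256,
    preprocess="resize_and_crop",
    no_flip=False,
    input_nc=3,
    output_nc=3,
    direction="AtoB",
)

# Fields listed in the table above; report any that are missing.
TABLE_FIELDS = (
    "dataset_mode", "dataroot", "batch_size", "num_threads", "serial_batches",
    "max_dataset_size", "phase", "load_size", "crop_size", "preprocess",
    "no_flip", "input_nc", "output_nc", "direction",
)
missing = [name for name in TABLE_FIELDS if not hasattr(opt, name)]
print(missing)  # []
```

With the repository on the Python path, this object would then be passed to create_dataset(opt) like a parsed options object.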
Outputs
create_dataset(opt) returns a CustomDatasetDataLoader instance that is:
- Iterable: each iteration yields a batch dict whose keys depend on the dataset mode. Image entries are batched tensors; path entries are lists of strings, one per sample. For example:
  - unaligned: {'A': tensor, 'B': tensor, 'A_paths': list, 'B_paths': list}
  - aligned: {'A': tensor, 'B': tensor, 'A_paths': list, 'B_paths': list}
  - single: {'A': tensor, 'A_paths': list}
  - colorization: {'A': tensor, 'B': tensor, 'A_paths': list, 'B_paths': list}, where A holds the L channel and B the ab channels of the Lab image
- Length-aware: len(dataset) returns min(len(underlying_dataset), opt.max_dataset_size), a sample count rather than a batch count.
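Path entries arrive as lists because each sample returns a single path string and the DataLoader's default collate gathers strings from all samples into a list, while tensors are stacked along a new batch dimension. The sketch below imitates that with a simplified collate_sketch (a hypothetical stand-in that gathers every value into a list) and checks keys for the first three modes listed above; placeholder strings replace image tensors and the file names are hypothetical.

```python
def collate_sketch(samples):
    """Simplified stand-in for torch's default collate on dict samples.

    torch stacks tensors along a new batch dimension and gathers strings into
    a list; this sketch simply gathers every value into a list.
    """
    return {key: [sample[key] for sample in samples] for key in samples[0]}


# Keys for three modes, taken from the list above.
EXPECTED_KEYS = {
    "unaligned": {"A", "B", "A_paths", "B_paths"},
    "aligned": {"A", "B", "A_paths", "B_paths"},
    "single": {"A", "A_paths"},
}


def check_batch(mode, batch):
    """Raise KeyError if `batch` lacks a key expected for `mode`."""
    missing = EXPECTED_KEYS[mode] - set(batch)
    if missing:
        raise KeyError(f"{mode!r} batch is missing {sorted(missing)}")
    return batch


samples = [
    {"A": "imgA0", "B": "imgB0", "A_paths": "trainA/0.jpg", "B_paths": "trainB/7.jpg"},
    {"A": "imgA1", "B": "imgB1", "A_paths": "trainA/1.jpg", "B_paths": "trainB/3.jpg"},
]
batch = check_batch("unaligned", collate_sketch(samples))
print(batch["A_paths"])  # ['trainA/0.jpg', 'trainA/1.jpg']
```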
Usage Examples
CycleGAN (unaligned mode)
```python
from data import create_dataset
from options.train_options import TrainOptions

opt = TrainOptions().parse()  # parses command-line flags such as:
# opt.dataset_mode = 'unaligned'
# opt.dataroot = './datasets/horse2zebra'
# opt.batch_size = 1
# opt.num_threads = 4
dataset = create_dataset(opt)
for i, data in enumerate(dataset):
    real_A = data['A']  # images from domain A
    real_B = data['B']  # images from domain B
```
pix2pix (aligned mode)
```python
from data import create_dataset
from options.train_options import TrainOptions

opt = TrainOptions().parse()  # parses command-line flags such as:
# opt.dataset_mode = 'aligned'
# opt.dataroot = './datasets/facades'
# opt.batch_size = 4
# opt.num_threads = 4
dataset = create_dataset(opt)
for i, data in enumerate(dataset):
    input_image = data['A']  # input side of the pair
    target_image = data['B']  # target side of the pair
```
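Which side of the pair serves as the input depends on opt.direction. The following minimal sketch is modeled on how the repository's pix2pix model reads the batch in set_input, as we understand it; treat the details as an assumption. split_by_direction is a hypothetical name, and the placeholder strings stand for the left and right halves that aligned mode cuts from each combined image.

```python
def split_by_direction(data, direction="AtoB"):
    """Return (input, target, input_paths) from an aligned batch dict.

    With direction "BtoA" the roles of data['A'] and data['B'] are swapped.
    """
    if direction not in ("AtoB", "BtoA"):
        raise ValueError(f"unknown direction: {direction!r}")
    a_to_b = direction == "AtoB"
    real_input = data["A" if a_to_b else "B"]
    real_target = data["B" if a_to_b else "A"]
    input_paths = data["A_paths" if a_to_b else "B_paths"]
    return real_input, real_target, input_paths


batch = {"A": "left_half", "B": "right_half", "A_paths": ["1.jpg"], "B_paths": ["1.jpg"]}
print(split_by_direction(batch, "BtoA")[:2])  # ('right_half', 'left_half')
```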
DDP Training (multi-GPU)
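When the script is launched via torchrun, the factory detects DDP through the LOCAL_RANK environment variable that torchrun sets and attaches a DistributedSampler. Call set_epoch at the start of each epoch so that every epoch is shuffled differently: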
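```python
dataset = create_dataset(opt)
for epoch in range(start_epoch, n_epochs):  # start_epoch, n_epochs: your epoch range
    dataset.set_epoch(epoch)  # reseed the DistributedSampler's shuffle for this epoch
    for i, data in enumerate(dataset):
        # training step
        pass
```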