Implementation:Junyanz Pytorch CycleGAN and pix2pix ImagePool Query
| Knowledge Sources | pytorch-CycleGAN-and-pix2pix |
|---|---|
| Domains | Image-to-Image Translation, GAN Training Stabilization, Generative Adversarial Networks |
| Last Updated | 2026-02-09 |
Overview
The ImagePool class in util/image_pool.py implements a fixed-size buffer of previously generated images used to stabilize discriminator training in CycleGAN.
Description
The ImagePool class maintains an internal list of up to pool_size image tensors. When query() is called with a batch of newly generated images, it handles each image independently. While the pool holds fewer than pool_size images, the image is stored and returned unchanged. Once the pool is full, the image is returned unchanged with 50% probability. Otherwise it overwrites a randomly chosen slot, and the image previously stored in that slot is returned in its place. The returned batch has the same shape as the input and is used to compute the discriminator loss.
If pool_size is set to 0, the pool is disabled and query() returns the input images unchanged. pix2pix sets pool_size=0 by default during training, so its discriminator sees only the current generator's outputs.
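The replacement policy described above can be sketched in a few lines. This is a minimal illustration, not the repository's code: HistoryPool is a hypothetical name, plain Python objects stand in for image tensors, and the injectable random.Random is an addition for reproducibility. The repository's class operates on (N, C, H, W) tensors and concatenates the selected images back into a single batch.

```python
import random
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class HistoryPool(Generic[T]):
    """Hypothetical sketch of ImagePool's replacement policy on plain Python objects."""

    def __init__(self, pool_size: int, rng: Optional[random.Random] = None) -> None:
        self.pool_size = pool_size
        self.items: List[T] = []  # stored "historical" items, at most pool_size
        self.rng = rng if rng is not None else random.Random()

    def query(self, batch: List[T]) -> List[T]:
        if self.pool_size == 0:  # pool disabled: hand the batch back untouched
            return batch
        out: List[T] = []
        for item in batch:
            if len(self.items) < self.pool_size:
                # Fill phase: store the new item and return it unchanged.
                self.items.append(item)
                out.append(item)
            elif self.rng.uniform(0, 1) > 0.5:
                # 50%: put the new item in a random slot and return the old occupant.
                slot = self.rng.randint(0, self.pool_size - 1)  # randint is inclusive
                out.append(self.items[slot])
                self.items[slot] = item
            else:
                # Other 50%: return the new item; the pool is left unchanged.
                out.append(item)
        return out


# Example: with pool_size=2, the first two items only fill the buffer.
pool: HistoryPool[str] = HistoryPool(pool_size=2, rng=random.Random(0))
print(pool.query(["g1", "g2"]))  # ['g1', 'g2']
```

Because a swap returns the old image and stores the new one, the buffer keeps turning over, and within a single batch some entries may be historical while others are current.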
Usage
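During training, CycleGANModel.__init__ creates two pools, fake_A_pool and fake_B_pool, with pool_size=opt.pool_size (default 50). backward_D_A and backward_D_B query them and pass the returned images to backward_D_basic, which computes the discriminator loss.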
Code Reference
Source Location
| File | Lines |
|---|---|
| util/image_pool.py | L5-54 |
Signature
```python
class ImagePool():
    def __init__(self, pool_size):
        """Initialize the ImagePool class.

        Parameters:
            pool_size (int) -- the size of image buffer; if pool_size=0, no buffer will be created
        """

    def query(self, images):
        """Return an image from the pool.

        Parameters:
            images (torch.Tensor) -- the latest generated images from the generator

        Returns:
            return_images (torch.Tensor) -- images from the buffer (may include historical images)

        By 50/50 chance, the buffer will return a previously stored image
        and insert the current image into the buffer.
        """
```
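The 50/50 rule in the docstring also determines how long an image stays in the buffer. A short derivation, assuming the pool is already full and the replacement slot is drawn uniformly at random from the $P$ = pool_size slots:

$$
\begin{aligned}
\Pr[\text{a given slot is overwritten by one incoming image}] &= \tfrac{1}{2}\cdot\tfrac{1}{P} = \tfrac{1}{2P},\\
\Pr[\text{a stored image survives } m \text{ incoming images}] &= \left(1 - \tfrac{1}{2P}\right)^{m},\\
\mathbb{E}[\text{incoming images until it is overwritten}] &= 2P.
\end{aligned}
$$

So once the pool is full, each returned image is historical with probability 1/2, and with pool_size=50 a stored image survives 100 incoming generated images in expectation; the chance it is still present after 100 more is $(0.99)^{100} \approx 0.37$.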
Import
```python
from util.image_pool import ImagePool

fake_A_pool = ImagePool(pool_size=50)
fake_B_pool = ImagePool(pool_size=50)
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| pool_size | int | Maximum number of images to store in the buffer (0 disables the pool) |
| images | torch.Tensor | Batch of generated images from the generator, shape (N, C, H, W) |

| Output | Type | Description |
|---|---|---|
| return_images | torch.Tensor | Batch of images for the discriminator update, same shape as images; once the pool is full, each entry may be a historical image swapped out of the pool |
Usage Examples
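```python
import torch
import torch.nn as nn
from models.networks import GANLoss
from util.image_pool import ImagePool

# Hypothetical stand-ins so the snippet runs; CycleGANModel builds its real networks in models.networks
netG_A = nn.Conv2d(3, 3, kernel_size=3, padding=1)            # G_A: A -> B
netD_A = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)  # D_A: judges G_A(A) vs. real B
criterionGAN = GANLoss('lsgan')
real_A = torch.randn(1, 3, 64, 64)

# Create pools for CycleGAN (one per direction)
fake_A_pool = ImagePool(pool_size=50)
fake_B_pool = ImagePool(pool_size=50)
# During training, query the pool before the discriminator update
fake_B = netG_A(real_A)                    # generate fake B from real A
fake_B_pooled = fake_B_pool.query(fake_B)  # may return historical fake_B images
# Use pooled images for the "fake" term of D_A's loss
pred_fake = netD_A(fake_B_pooled.detach())
loss_D_A_fake = criterionGAN(pred_fake, False)
# Disabled pool (the pix2pix default): the input is returned unchanged
no_pool = ImagePool(pool_size=0)
same_images = no_pool.query(fake_B)
assert same_images is fake_B
```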