Implementation:Junyanz Pytorch CycleGAN and pix2pix TestModel Forward

From Leeroopedia


Knowledge Sources
Domains: Vision, Inference
Last Updated: 2026-02-09 16:00 GMT

Overview

TestModel is the concrete tool that the pytorch-CycleGAN-and-pix2pix framework provides for single-direction image translation inference. The class wraps a single generator network and exposes a minimal interface for loading a checkpoint, ingesting input images, and producing translated outputs. It is designed for use with the test.py driver script, which orchestrates the full inference loop: dataset creation, HTML gallery output, and an optional evaluation mode.

Code Reference

Source Files

| File | Lines | Description |
| --- | --- | --- |
| models/test_model.py | L1-69 | TestModel class: single-direction inference model |
| models/base_model.py | L139-147 | BaseModel.test(): wraps forward in no_grad() |
| models/base_model.py | L80-131 | BaseModel.setup(): loads checkpoint, wraps with DDP if needed |
| test.py | L1-79 | Main inference driver script |

Class Signature

class TestModel(BaseModel):
    """Single-direction inference model for CycleGAN/pix2pix."""

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        ...

    def __init__(self, opt):
        ...

    def set_input(self, input):
        ...

    def forward(self):
        ...

    def optimize_parameters(self):
        pass

Import

from models import create_model
# Internally:
# from .base_model import BaseModel
# from . import networks

Key Methods

modify_commandline_options(parser, is_train) (L13-30): A static method that enforces test-only usage (assert not is_train), sets the default dataset_mode to 'single', and adds the --model_suffix argument. The suffix determines which generator checkpoint is loaded (e.g., _A loads latest_net_G_A.pth, _B loads latest_net_G_B.pth).
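The option handling described above can be sketched with plain argparse. This is a minimal standalone illustration, not the repository's code: the base parser and its 'unaligned' default for --dataset_mode are assumptions, included only so that the override is visible.

```python
import argparse


def modify_commandline_options(parser, is_train=True):
    """Sketch of the test-only option hook (stand-in, not the repository code)."""
    # Refuse to be used from a training run.
    assert not is_train, "TestModel cannot be used during training"
    # Override the dataset default so images are read from a single folder.
    parser.set_defaults(dataset_mode="single")
    # The suffix selects which generator checkpoint is loaded, e.g. "_A" or "_B".
    parser.add_argument("--model_suffix", type=str, default="")
    return parser


# Assumed base parser: the 'unaligned' default is illustrative only.
base = argparse.ArgumentParser()
base.add_argument("--dataset_mode", type=str, default="unaligned")

parser = modify_commandline_options(base, is_train=False)
opt = parser.parse_args(["--model_suffix", "_A"])
print(opt.dataset_mode, opt.model_suffix)
```

Because the default is changed with set_defaults rather than hard-coded, a value passed explicitly on the command line would still take precedence in this sketch.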

__init__(self, opt) (L32-50): Asserts opt.isTrain is False. Calls BaseModel.__init__. Sets self.loss_names = [] (no losses at test time). Sets self.visual_names = ['real', 'fake'] for gallery output. Sets self.model_names = ['G' + opt.model_suffix] so that only the single target generator is loaded. Creates the generator network via networks.define_G(). Uses setattr to assign the network as self.netG_{suffix} so that BaseModel.setup() can locate it by the suffixed name for checkpoint loading.
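The setattr step is easier to see in isolation. The sketch below uses stand-in classes only: FakeGenerator, SuffixedModel, and networks_by_name are hypothetical names, and the lookup by 'net' + name is an assumption modeled on how this page describes BaseModel.setup() locating networks.

```python
class FakeGenerator:
    """Stand-in for the network object returned by networks.define_G()."""


class SuffixedModel:
    """Hypothetical model showing only the naming trick, not real inference."""

    def __init__(self, model_suffix):
        # Only the single target generator is registered for loading.
        self.model_names = ["G" + model_suffix]
        # forward() uses the plain attribute name ...
        self.netG = FakeGenerator()
        # ... while the suffixed alias points at the same object.
        setattr(self, "netG" + model_suffix, self.netG)

    def networks_by_name(self):
        # Assumed lookup scheme: attribute name is 'net' + registered name.
        return {name: getattr(self, "net" + name) for name in self.model_names}


model = SuffixedModel("_A")
nets = model.networks_by_name()
print(model.model_names, nets["G_A"] is model.netG)
```

Both attribute names refer to one object, so weights loaded through the suffixed name are the weights used by the forward pass.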

set_input(self, input) (L52-61): Unpacks only the 'A' key from the input dictionary, moving the tensor to the model device as self.real. Stores the image paths from input['A_paths'] into self.image_paths.

forward(self) (L63-65): Runs the generator forward pass: self.fake = self.netG(self.real).

BaseModel.test() (base_model.py L139-147): Wraps the call to self.forward() and self.compute_visuals() inside a torch.no_grad() context manager so that no gradient computation occurs.
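The interplay of set_input, forward, and test can be sketched with a toy model. This is an illustration under assumptions, not the framework's classes: TinyTestModel is a hypothetical name, the one-layer convolution followed by Tanh merely stands in for a real generator, and compute_visuals() is omitted.

```python
import torch
import torch.nn as nn


class TinyTestModel:
    """Hypothetical stand-in mirroring the set_input / forward / test flow."""

    def __init__(self, device="cpu"):
        self.device = torch.device(device)
        # Toy generator: a single conv layer plus Tanh, so outputs lie in [-1, 1].
        self.netG = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        ).to(self.device)

    def set_input(self, input):
        # Only the 'A' image and its paths are read from the batch dictionary.
        self.real = input["A"].to(self.device)
        self.image_paths = input["A_paths"]

    def forward(self):
        self.fake = self.netG(self.real)

    def test(self):
        # No autograd graph is built inside this context.
        with torch.no_grad():
            self.forward()


model = TinyTestModel()
batch = {"A": torch.rand(1, 3, 8, 8) * 2 - 1, "A_paths": ["example.jpg"]}
model.set_input(batch)
model.test()
print(model.fake.shape, model.fake.requires_grad)
```

Since the forward pass runs under torch.no_grad(), the output tensor does not require gradients even though the generator's parameters do.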

BaseModel.setup(opt) (base_model.py L80-131): Iterates over self.model_names, initializes each network, loads checkpoint weights from disk if not training (or if continue_train), moves networks to device, and wraps with DistributedDataParallel if distributed training is initialized.

I/O Contract

Inputs

| Input | Type | Description |
| --- | --- | --- |
| opt | Namespace | Parsed options object. Must include model_suffix (str, e.g., '_A'), isTrain=False, and all standard BaseOptions/TestOptions fields (checkpoints_dir, name, epoch, input_nc, output_nc, ngf, netG, norm, no_dropout, init_type, init_gain, etc.). |
| input (to set_input) | dict | Dictionary with keys: 'A' (torch.Tensor of shape [1, C, H, W], the input image) and 'A_paths' (list of str, filesystem paths to the source images). |

Outputs

| Output | Type | Description |
| --- | --- | --- |
| self.fake | torch.Tensor | Generated output image tensor of shape [1, C, H, W] with values in [-1, 1] (Tanh output). |
| self.image_paths | list[str] | Filesystem paths of the corresponding input images, returned by get_image_paths(). |
| get_current_visuals() | OrderedDict | Dictionary mapping 'real' to self.real and 'fake' to self.fake. |
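Because self.fake holds Tanh-range values, it has to be rescaled before it can be written as an 8-bit image. The arithmetic is sketched below for a single value; tanh_to_uint8 is a hypothetical helper for illustration and not the repository's conversion utility, and truncation toward zero is an assumption of this sketch.

```python
def tanh_to_uint8(value):
    """Map a Tanh-range value in [-1, 1] to an 8-bit intensity in [0, 255]."""
    # Clamp first so out-of-range inputs cannot overflow the 8-bit range.
    clamped = max(-1.0, min(1.0, value))
    # Shift [-1, 1] to [0, 1], scale to [0, 255], then truncate to an integer.
    return int((clamped + 1.0) / 2.0 * 255.0)


pixels = [tanh_to_uint8(v) for v in (-1.0, 0.0, 1.0)]
print(pixels)  # [0, 127, 255]
```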

Usage Examples

CycleGAN: Single-Direction Inference (A to B)

Translate horse images to zebra images using a pretrained CycleGAN generator G_A:

python test.py \
    --dataroot datasets/horse2zebra/testA \
    --name horse2zebra_pretrained \
    --model test \
    --model_suffix _A \
    --no_dropout

This invokes TestModel with model_suffix='_A', which loads latest_net_G_A.pth from checkpoints/horse2zebra_pretrained/. The dataset mode is automatically set to single, so images are read directly from the testA directory. Results are saved to results/horse2zebra_pretrained/test_latest/.
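The two locations mentioned above follow from the option values. A minimal sketch of the path construction is given here; checkpoint_path and results_dir are hypothetical helpers, and the 'test' phase and 'latest' epoch are taken from the directory names quoted in this example rather than read from the framework.

```python
import os


def checkpoint_path(checkpoints_dir, name, epoch, model_suffix):
    """Hypothetical helper: where the generator weights are expected to live."""
    filename = f"{epoch}_net_G{model_suffix}.pth"
    return os.path.join(checkpoints_dir, name, filename)


def results_dir(results_root, name, phase, epoch):
    """Hypothetical helper: where the output gallery is written."""
    return os.path.join(results_root, name, f"{phase}_{epoch}")


ckpt = checkpoint_path("checkpoints", "horse2zebra_pretrained", "latest", "_A")
out = results_dir("results", "horse2zebra_pretrained", "test", "latest")
print(ckpt)
print(out)
```

Switching --model_suffix to _B changes only the checkpoint file name in this sketch; the results directory stays the same, so a second run would write into the same gallery folder.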

CycleGAN: Single-Direction Inference (B to A)

Translate zebra images back to horse images using generator G_B:

python test.py \
    --dataroot datasets/horse2zebra/testB \
    --name horse2zebra_pretrained \
    --model test \
    --model_suffix _B \
    --no_dropout

Pix2pix: Paired Inference

Translate facade label maps to photos using a trained pix2pix generator:

python test.py \
    --dataroot ./datasets/facades \
    --name facades_pix2pix \
    --model pix2pix \
    --direction BtoA

This uses the full Pix2PixModel rather than TestModel. The generator runs in the same manner (forward-only under no_grad()), but the dataset mode is aligned and both input and ground truth are available for visual comparison in the output gallery.

Programmatic Usage

The test loop in test.py follows this structure:

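import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html

opt = TestOptions().parse()
dataset = create_dataset(opt)
model = create_model(opt)      # Returns TestModel when --model test
model.setup(opt)               # Loads checkpoint weights

# HTML gallery written under <results_dir>/<name>/<phase>_<epoch>/
web_dir = os.path.join(opt.results_dir, opt.name, f"{opt.phase}_{opt.epoch}")
webpage = html.HTML(web_dir, f"Experiment = {opt.name}, Phase = {opt.phase}, Epoch = {opt.epoch}")

if opt.eval:
    model.eval()               # Switch BatchNorm/Dropout to eval mode

for i, data in enumerate(dataset):
    if i >= opt.num_test:      # Stop after opt.num_test images
        break
    model.set_input(data)      # Unpack data['A'] -> self.real
    model.test()               # no_grad + forward + compute_visuals
    visuals = model.get_current_visuals()   # {'real': tensor, 'fake': tensor}
    img_path = model.get_image_paths()
    save_images(webpage, visuals, img_path)

webpage.save()                 # Write the HTML index for the gallery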
