Implementation:Junyanz Pytorch CycleGAN and pix2pix Pix2PixModel Optimize Parameters

From Leeroopedia


Field | Value
source | Repo: pytorch-CycleGAN-and-pix2pix, Paper: pix2pix
domains | Vision, GAN, Training
last_updated | 2026-02-09 16:00 GMT

Overview

A concrete implementation of a single pix2pix conditional GAN training step that combines an adversarial loss with an L1 reconstruction loss.

The optimize_parameters() method on Pix2PixModel executes one complete training iteration: it runs the forward pass through the U-Net generator, updates the PatchGAN discriminator on real and fake pairs, then updates the generator with combined GAN and L1 losses. This method is called once per batch by the training loop.

Code Reference

Source file: models/pix2pix_model.py (lines 6–127)

Class: Pix2PixModel(BaseModel)

Primary method:

def optimize_parameters(self):
    self.forward()                          # compute fake images: G(A)
    # update D
    self.set_requires_grad(self.netD, True) # enable backprop for D
    self.optimizer_D.zero_grad()            # set D's gradients to zero
    self.backward_D()                       # calculate gradients for D
    self.optimizer_D.step()                 # update D's weights
    # update G
    self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G
    self.optimizer_G.zero_grad()             # set G's gradients to zero
    self.backward_G()                        # calculate gradients for G
    self.optimizer_G.step()                  # update G's weights
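
The two set_requires_grad calls above are easier to follow with a small experiment. The following is a minimal sketch, not the repository's code: the single-layer netG and netD are toy stand-ins (assumptions) for the U-Net and PatchGAN, and set_requires_grad is re-implemented here as a free function. It shows that while D is frozen, a loss computed through D still sends gradients to G but leaves D's parameters without gradients.

```python
import torch
import torch.nn as nn

def set_requires_grad(net, requires_grad):
    """Toggle gradient tracking for every parameter of `net`."""
    for param in net.parameters():
        param.requires_grad = requires_grad

# Toy stand-ins (assumptions): single conv layers instead of U-Net / PatchGAN.
netG = nn.Conv2d(3, 3, kernel_size=3, padding=1)
netD = nn.Conv2d(6, 1, kernel_size=3, padding=1)  # sees input and output stacked on channels

real_A = torch.randn(1, 3, 8, 8)
fake_B = netG(real_A)

# Generator phase: D is frozen, but the loss still flows through D into G.
set_requires_grad(netD, False)
score = netD(torch.cat((real_A, fake_B), dim=1)).mean()
score.backward()

g_has_grads = all(p.grad is not None for p in netG.parameters())
d_has_grads = any(p.grad is not None for p in netD.parameters())
print(g_has_grads, d_has_grads)
```

Freezing D during the generator update avoids computing discriminator gradients that would be discarded anyway, since only optimizer_G.step() runs in that phase.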

Import and instantiation:

from models.pix2pix_model import Pix2PixModel

model = Pix2PixModel(opt)
model.setup(opt)

Key internal methods (called by optimize_parameters):

Method | Lines | Purpose
__init__(self, opt) | 40–71 | Creates netG (U-Net), netD (PatchGAN with input_nc + output_nc input channels), GANLoss (vanilla BCE), L1Loss, and Adam optimizers
modify_commandline_options(parser, is_train) | 17–38 | Sets defaults: norm=batch, netG=unet_256, dataset_mode=aligned, pool_size=0, gan_mode=vanilla, lambda_L1=100.0
set_input(self, input) | 73–84 | Unpacks real_A and real_B tensors from the data loader dictionary; respects opt.direction
forward(self) | 86–88 | fake_B = netG(real_A)
backward_D(self) | 90–102 | Concatenates real_A with fake_B (fake pair) and real_A with real_B (real pair) along the channel dimension, then computes the discriminator's GAN loss on both pairs
backward_G(self) | 104–114 | Computes the generator loss: GAN loss (fool D) + lambda_L1 * L1(fake_B, real_B)
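
The loss assembly described for backward_D and backward_G can be sketched as two standalone functions. This is an illustrative sketch under stated assumptions, not the upstream implementation: the networks are one-layer stand-ins, the function names discriminator_loss, generator_loss, and gan_loss are hypothetical, and two details are assumed rather than taken from this page (the fake pair is detached in the discriminator pass, and the two discriminator terms are averaged with a factor of 0.5).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_nc, output_nc, lambda_L1 = 3, 3, 100.0

# Toy stand-ins (assumptions) for the U-Net generator and PatchGAN discriminator.
netG = nn.Conv2d(input_nc, output_nc, kernel_size=3, padding=1)
netD = nn.Conv2d(input_nc + output_nc, 1, kernel_size=4, stride=2, padding=1)
criterion_gan = nn.BCEWithLogitsLoss()   # "vanilla" GAN loss on raw logits
criterion_l1 = nn.L1Loss()               # mean absolute error

def gan_loss(pred, target_is_real):
    """Compare a map of patch logits against an all-ones or all-zeros target."""
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    return criterion_gan(pred, target)

def discriminator_loss(real_A, real_B, fake_B):
    # detach() keeps the discriminator loss from backpropagating into G (assumed detail)
    fake_AB = torch.cat((real_A, fake_B), 1).detach()
    real_AB = torch.cat((real_A, real_B), 1)
    return 0.5 * (gan_loss(netD(fake_AB), False) + gan_loss(netD(real_AB), True))

def generator_loss(real_A, real_B, fake_B):
    fake_AB = torch.cat((real_A, fake_B), 1)
    loss_gan = gan_loss(netD(fake_AB), True)              # G wants D to say "real"
    loss_l1 = criterion_l1(fake_B, real_B) * lambda_L1    # weighted reconstruction term
    return loss_gan + loss_l1

real_A = torch.randn(2, input_nc, 16, 16)
real_B = torch.randn(2, output_nc, 16, 16)
fake_B = netG(real_A)
patch_logits_shape = tuple(netD(torch.cat((real_A, real_B), 1)).shape)

loss_D = discriminator_loss(real_A, real_B, fake_B)
loss_D.backward()
g_grad_after_D = [p.grad for p in netG.parameters()]   # expected to be all None

loss_G = generator_loss(real_A, real_B, fake_B)
loss_G.backward()   # in the real method, D is frozen during this phase
```

The discriminator here returns a spatial map of logits rather than a single score, which is why the target tensor is built with the same shape as the prediction.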

I/O Contract

Inputs

Input | Type | Description
opt | argparse.Namespace | Experiment configuration object passed at construction. Key fields: input_nc, output_nc, ngf, ndf, netG, netD, norm, no_dropout, init_type, init_gain, lr, beta1, lambda_L1, gan_mode, direction, device
data (via set_input) | dict | Dictionary with keys "A" (input tensor), "B" (target tensor), "A_paths", "B_paths". Tensors have shape [B, C, H, W] (batch, channels, height, width).
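
How opt.direction affects unpacking can be illustrated without any tensors. The sketch below is an assumption about the selection logic, not a copy of set_input: the helper name unpack_pair is hypothetical, strings stand in for image tensors, and the real method is also expected to move tensors to the configured device.

```python
def unpack_pair(batch, direction="AtoB"):
    """Return (real_A, real_B, image_paths) for the requested translation direction."""
    a_to_b = direction == "AtoB"
    real_A = batch["A" if a_to_b else "B"]            # network input
    real_B = batch["B" if a_to_b else "A"]            # ground-truth target
    image_paths = batch["A_paths" if a_to_b else "B_paths"]
    return real_A, real_B, image_paths

# Placeholder strings stand in for [B, C, H, W] tensors (illustration only).
batch = {
    "A": "photo_tensor",
    "B": "label_tensor",
    "A_paths": ["a/1.jpg"],
    "B_paths": ["b/1.jpg"],
}
src, tgt, paths = unpack_pair(batch, direction="BtoA")
print(src, tgt, paths)
```

With direction="BtoA" the dictionary's "B" entry becomes the generator input, so the same aligned dataset can be trained in either direction.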

Outputs

Updated model weights (in-place on netG and netD).

Losses (stored as instance attributes, accessible via get_current_losses()):

Loss Name | Description
G_GAN | Generator adversarial loss: how well G fools D
G_L1 | Weighted L1 reconstruction loss: lambda_L1 * ||fake_B - real_B||_1
D_real | Discriminator loss on real pairs
D_fake | Discriminator loss on fake pairs
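
A small worked example may help with the scale of G_L1. This sketch assumes the L1 term is a mean absolute error (the default "mean" reduction of torch.nn.L1Loss), so the norm is averaged over elements rather than summed. The pixel values are hypothetical.

```python
lambda_L1 = 100.0

# Hypothetical flattened pixel values, for illustration only.
fake_B = [0.5, -0.5, 0.0, 1.0]
real_B = [1.0, -1.0, 0.0, 0.0]

# Mean absolute error over all elements.
mae = sum(abs(f - r) for f, r in zip(fake_B, real_B)) / len(fake_B)
loss_G_L1 = lambda_L1 * mae
print(mae, loss_G_L1)
```

Because lambda_L1 defaults to 100.0, the logged G_L1 value is typically much larger than G_GAN early in training, which is expected rather than a sign of a bug.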

Visuals (stored as instance attributes, accessible via get_current_visuals()):

Visual Name | Description
real_A | Input image (domain A)
fake_B | Generated output image: G(real_A)
real_B | Ground-truth target image (domain B)

Usage Examples

Standard training loop:

from options.train_options import TrainOptions
from data import create_dataset
from models import create_model

# Parse options (pix2pix defaults applied via modify_commandline_options)
opt = TrainOptions().parse()

# Create aligned dataset and model
dataset = create_dataset(opt)
model = create_model(opt)    # returns Pix2PixModel when --model pix2pix
model.setup(opt)             # load/print networks, create schedulers

for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
    for i, data in enumerate(dataset):
        model.set_input(data)            # unpack real_A, real_B
        model.optimize_parameters()      # forward + backward_D + backward_G

        if i % opt.print_freq == 0:
            losses = model.get_current_losses()
            # losses = {'G_GAN': ..., 'G_L1': ..., 'D_real': ..., 'D_fake': ...}

        if i % opt.display_freq == 0:
            visuals = model.get_current_visuals()
            # visuals = {'real_A': tensor, 'fake_B': tensor, 'real_B': tensor}

    model.update_learning_rate()         # decay LR per epoch
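
The epoch range in the loop above, n_epochs + n_epochs_decay, pairs with a learning-rate schedule that holds the rate constant and then decays it. The sketch below assumes a linear decay policy (believed to be the repository default, but not stated on this page). The function name linear_decay_factor and the default values are illustrative, and epoch is the scheduler's zero-based step counter.

```python
def linear_decay_factor(epoch, epoch_count=1, n_epochs=100, n_epochs_decay=100):
    """Multiplier applied to the base learning rate at a given scheduler step.

    Stays at 1.0 for the first n_epochs, then falls linearly to 0.0
    over the following n_epochs_decay epochs.
    """
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

base_lr = 0.0002  # hypothetical base learning rate
for step in (0, 99, 150, 200):
    print(step, base_lr * linear_decay_factor(step))
```

Under these assumptions the multiplier is 1.0 through step 99 and reaches 0.0 at step 200.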

Command-line invocation:

python train.py --dataroot ./datasets/facades --name facades_pix2pix \
    --model pix2pix --direction BtoA --netG unet_256 --norm batch \
    --lambda_L1 100 --gan_mode vanilla --dataset_mode aligned
