

Implementation:LaurentMazare Tch rs GAN Example

From Leeroopedia


Knowledge Sources
Domains Deep Learning, Generative Models, Computer Vision
Last Updated 2026-02-08 00:00 GMT

Overview

Implements a relativistic deep convolutional GAN (relativistic DCGAN) for image generation, training separate generator and discriminator networks adversarially.

Description

This example implements a relativistic DCGAN based on the approach from the AlexiaJM/RelativisticGAN repository. The architecture consists of:

Generator: A five-layer transposed convolutional network that maps a 128-dimensional latent vector to a 64x64 RGB image. The architecture is:

  • TransConv2D(128 -> 1024, stride=1, padding=0) + BatchNorm + ReLU
  • TransConv2D(1024 -> 512, stride=2, padding=1) + BatchNorm + ReLU
  • TransConv2D(512 -> 256, stride=2, padding=1) + BatchNorm + ReLU
  • TransConv2D(256 -> 128, stride=2, padding=1) + BatchNorm + ReLU
  • TransConv2D(128 -> 3, stride=2, padding=1) + Tanh
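The layer list gives strides and padding but not kernel sizes. The sketch below assumes every layer uses a 4x4 kernel, the size the first layer uses in the usage example further down this page. Under that assumption, the transposed-convolution size formula (dilation 1, no output padding) traces the path from the 1x1 latent input to a 64x64 image. `conv_transpose_out` and `generator_sizes` are hypothetical helpers, not part of the example:

```rust
/// Output height/width of a 2-D transposed convolution with dilation 1 and no
/// output padding: (input - 1) * stride - 2 * padding + kernel.
fn conv_transpose_out(input: i64, kernel: i64, stride: i64, padding: i64) -> i64 {
    (input - 1) * stride - 2 * padding + kernel
}

/// Spatial size after each of the five generator layers, starting from the
/// 1x1 latent input. Assumes every layer uses a 4x4 kernel.
fn generator_sizes() -> Vec<i64> {
    // (stride, padding) pairs, in the order of the layer list above.
    let layers = [(1, 0), (2, 1), (2, 1), (2, 1), (2, 1)];
    let mut size = 1;
    layers
        .iter()
        .map(|&(stride, padding)| {
            size = conv_transpose_out(size, 4, stride, padding);
            size
        })
        .collect()
}
```

The first layer (stride 1, padding 0) expands the 1x1 latent into a 4x4 map. After that, each stride-2, padding-1 layer with a 4x4 kernel exactly doubles the resolution: 4, 8, 16, 32, 64.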

Discriminator: A five-layer convolutional network that maps a 64x64 RGB image to a scalar score. It applies leaky ReLU (negative slope 0.2) after every layer except the last, and batch normalization on the three middle layers:

  • Conv2D(3 -> 128, stride=2, padding=1) + LeakyReLU
  • Conv2D(128 -> 256, stride=2, padding=1) + BatchNorm + LeakyReLU
  • Conv2D(256 -> 512, stride=2, padding=1) + BatchNorm + LeakyReLU
  • Conv2D(512 -> 1024, stride=2, padding=1) + BatchNorm + LeakyReLU
  • Conv2D(1024 -> 1, stride=1, padding=0)
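The same check works for the discriminator. This sketch again assumes 4x4 kernels for every layer, since the list gives only strides and padding. Under that assumption, four stride-2 convolutions halve 64x64 down to 4x4, and the final stride-1, padding-0 convolution collapses that map to one 1x1 score per image. `conv_out` and `discriminator_sizes` are hypothetical helpers:

```rust
/// Output height/width of a 2-D convolution with dilation 1:
/// floor((input + 2 * padding - kernel) / stride) + 1.
/// Integer division equals floor here because the numerator is non-negative.
fn conv_out(input: i64, kernel: i64, stride: i64, padding: i64) -> i64 {
    (input + 2 * padding - kernel) / stride + 1
}

/// Spatial size after each of the five discriminator layers for a 64x64 input.
/// Assumes every layer uses a 4x4 kernel.
fn discriminator_sizes() -> Vec<i64> {
    // (stride, padding) pairs, in the order of the layer list above.
    let layers = [(2, 1), (2, 1), (2, 1), (2, 1), (1, 0)];
    let mut size = 64;
    layers
        .iter()
        .map(|&(stride, padding)| {
            size = conv_out(size, 4, stride, padding);
            size
        })
        .collect()
}
```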

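Training uses a relativistic average least-squares loss built from MSE terms:

  • Discriminator loss: the scores of real images are pulled toward the batch mean of the fake scores plus 1, and the scores of fake images toward the batch mean of the real scores minus 1.
  • Generator loss: the offsets are swapped, so real scores are pulled toward the fake mean minus 1 and fake scores toward the real mean plus 1.

The generator and discriminator have separate VarStores. Within each batch, the discriminator is updated with the generator frozen, then the generator is updated with the discriminator frozen. Both networks use the Adam optimizer with betas (0.5, 0.999) and a learning rate of 1e-4.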
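The two loss formulas can be checked on plain numbers. This is a minimal std-only sketch that operates on raw discriminator scores stored in slices rather than tch tensors. For simplicity it computes both losses from the same pair of batches, which the training loop need not do. `mse_to`, `mean`, and `relativistic_losses` are hypothetical helpers:

```rust
/// Mean squared error between each prediction and one scalar target.
fn mse_to(preds: &[f64], target: f64) -> f64 {
    preds.iter().map(|p| (p - target).powi(2)).sum::<f64>() / preds.len() as f64
}

/// Arithmetic mean of a non-empty slice.
fn mean(xs: &[f64]) -> f64 {
    xs.iter().sum::<f64>() / xs.len() as f64
}

/// Relativistic average least-squares losses from discriminator scores on a
/// batch of real images (`real`) and a batch of generated images (`fake`).
/// Returns (discriminator_loss, generator_loss).
fn relativistic_losses(real: &[f64], fake: &[f64]) -> (f64, f64) {
    let (mean_real, mean_fake) = (mean(real), mean(fake));
    // Discriminator: real should exceed the average fake by 1,
    // and fake should trail the average real by 1.
    let d_loss = mse_to(real, mean_fake + 1.0) + mse_to(fake, mean_real - 1.0);
    // Generator: the same two terms with the offsets swapped.
    let g_loss = mse_to(real, mean_fake - 1.0) + mse_to(fake, mean_real + 1.0);
    (d_loss, g_loss)
}
```

For example, real scores of 1.0 and fake scores of 0.0 give a discriminator loss of 0 and a generator loss of 8. Identical scores on both sides give 2 for each loss.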

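Every 1000 batches, the generator is applied to a fixed batch of noise vectors and a 4x4 grid of the resulting images is saved. Because the noise is fixed, successive grids show how the output for the same latent inputs evolves during training.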
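The grid is assembled from the generator's tanh outputs. This minimal std-only sketch of the tiling rests on several assumptions, and the example's actual ordering and channel handling may differ:

  • Images are single-channel and stored row-major.
  • Image k is placed at grid row k / sz, column k % sz.
  • Values in [-1, 1] are mapped to 8-bit pixels with (v + 1) * 127.5, clamped to [0, 255].

`tile_grid` and `to_pixel` are hypothetical helpers:

```rust
/// Tiles `sz * sz` equally sized single-channel images (each `h` rows by `w`
/// columns, stored row-major) into one `(sz * h) x (sz * w)` image.
/// Image `k` goes to grid row `k / sz`, grid column `k % sz` (assumed ordering).
fn tile_grid(images: &[Vec<u8>], sz: usize, h: usize, w: usize) -> Vec<u8> {
    assert!(images.len() >= sz * sz, "need at least sz * sz images");
    let out_w = sz * w;
    let mut out = vec![0u8; sz * h * out_w];
    for (k, img) in images.iter().take(sz * sz).enumerate() {
        let (grid_row, grid_col) = (k / sz, k % sz);
        for y in 0..h {
            for x in 0..w {
                out[(grid_row * h + y) * out_w + grid_col * w + x] = img[y * w + x];
            }
        }
    }
    out
}

/// Maps a generator output in [-1, 1] (the tanh range) to an 8-bit pixel value.
fn to_pixel(v: f64) -> u8 {
    ((v + 1.0) * 127.5).clamp(0.0, 255.0) as u8
}
```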

Usage

Use this example to learn how to implement GANs with tch-rs, including adversarial training with separate variable stores, transposed convolutions for image generation, and relativistic GAN losses.

Code Reference

Source Location

Signature

fn tr2d(p: nn::Path, c_in: i64, c_out: i64, padding: i64, stride: i64) -> nn::ConvTranspose2D
fn conv2d(p: nn::Path, c_in: i64, c_out: i64, padding: i64, stride: i64) -> nn::Conv2D
fn generator(p: nn::Path) -> impl nn::ModuleT
fn leaky_relu(xs: &Tensor) -> Tensor
fn discriminator(p: nn::Path) -> impl nn::ModuleT
fn mse_loss(x: &Tensor, y: &Tensor) -> Tensor
fn image_matrix(imgs: &Tensor, sz: i64) -> Result<Tensor>
pub fn main() -> Result<()>
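The `leaky_relu` helper operates on tensors. One way to write leaky ReLU with slope 0.2 is max(x, 0.2x), which equals the piecewise definition for every x because the slope lies between 0 and 1. This is an assumption about how such a helper can be written, not necessarily how the example does it. Here it is as a plain-slice sketch, using a hypothetical `leaky_relu_slice` rather than the tch-rs function:

```rust
/// Leaky ReLU with negative slope 0.2, written as max(x, 0.2 * x).
/// For x >= 0 the max picks x; for x < 0, 0.2 * x is the larger value.
/// The identity holds only for slopes between 0 and 1.
fn leaky_relu_slice(xs: &[f64]) -> Vec<f64> {
    xs.iter().map(|&x| x.max(0.2 * x)).collect()
}
```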

Import

// Standalone binary example. Run with:
// cargo run --example gan -- <image-dataset-dir>
use anyhow::{bail, Result};
use tch::{kind, nn, nn::OptimizerConfig, Device, Kind, Tensor};

I/O Contract

Inputs

Name               Type                    Required  Description
image-dataset-dir  CLI argument (String)   Yes       Path to a directory of training images; each image is loaded and resized to 64x64.
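The only input is a single positional argument. This minimal std-only sketch validates it with a hypothetical helper, `image_dir_from_args`. It returns a plain String error instead of using anyhow. The example imports anyhow's `bail`, so it presumably reports a usage error in a similar way:

```rust
/// Extracts the single required positional argument (the image directory)
/// from a full argument list whose first entry is the program name.
fn image_dir_from_args(args: &[String]) -> Result<&str, String> {
    match args {
        // Exactly one argument after the program name.
        [_, dir] => Ok(dir.as_str()),
        // Missing or extra arguments.
        _ => Err("usage: gan <image-dataset-dir>".to_string()),
    }
}
```

In a binary, the argument list would come from `std::env::args().collect::<Vec<String>>()`. The slice pattern `[_, dir]` rejects both a missing argument and extra ones.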

Outputs

Name         Type         Description
reloutN.png  Image files  Generated 4x4 image grids, saved every 1000 batches.
stdout       Text         The batch index, printed every 100 batches.
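The output schedule can be written as a function of the batch index. This sketch assumes that N in reloutN.png is the batch index at which the grid was written. `batch_outputs` is a hypothetical helper, not part of the example:

```rust
/// Side effects of training batch `index`: an optional grid file name
/// (every 1000 batches) and whether the index is printed (every 100 batches).
/// Assumes N in reloutN.png is the batch index.
fn batch_outputs(index: u64) -> (Option<String>, bool) {
    let grid = if index % 1000 == 0 {
        Some(format!("relout{}.png", index))
    } else {
        None
    };
    (grid, index % 100 == 0)
}
```

Over the first 2500 batches, this schedule writes relout0.png, relout1000.png, and relout2000.png, and prints 25 batch indices.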

Usage Examples

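use anyhow::{bail, Result};
use tch::{kind, nn, nn::OptimizerConfig, Device, Kind, Tensor};

const IMG_SIZE: i64 = 64;
const LATENT_DIM: i64 = 128;
const BATCH_SIZE: i64 = 32;
const LEARNING_RATE: f64 = 1e-4;

// Generator: maps [N, LATENT_DIM, 1, 1] latent vectors to [N, 3, 64, 64] images in [-1, 1].
fn generator(p: nn::Path) -> impl nn::ModuleT {
    // Transposed-convolution config: stride `s`, padding `pad`, no bias.
    let cfg = |s, pad| nn::ConvTransposeConfig { stride: s, padding: pad, bias: false, ..Default::default() };
    nn::seq_t()
        .add(nn::conv_transpose2d(&p / "tr1", LATENT_DIM, 1024, 4, cfg(1, 0))) // 1x1 -> 4x4
        .add(nn::batch_norm2d(&p / "bn1", 1024, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::conv_transpose2d(&p / "tr2", 1024, 512, 4, cfg(2, 1))) // 4x4 -> 8x8
        .add(nn::batch_norm2d(&p / "bn2", 512, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::conv_transpose2d(&p / "tr3", 512, 256, 4, cfg(2, 1))) // 8x8 -> 16x16
        .add(nn::batch_norm2d(&p / "bn3", 256, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::conv_transpose2d(&p / "tr4", 256, 128, 4, cfg(2, 1))) // 16x16 -> 32x32
        .add(nn::batch_norm2d(&p / "bn4", 128, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::conv_transpose2d(&p / "tr5", 128, 3, 4, cfg(2, 1))) // 32x32 -> 64x64
        .add_fn(|xs| xs.tanh()) // squash pixel values into [-1, 1]
}

// Adversarial training with separate variable stores
pub fn main() -> Result<()> {
    let device = Device::cuda_if_available();

    // Each network has its own VarStore so one can be frozen while the other trains.
    let mut generator_vs = nn::VarStore::new(device);
    let generator = generator(generator_vs.root());
    // Adam with beta1 = 0.5, beta2 = 0.999 and no weight decay.
    let mut opt_g = nn::adam(0.5, 0.999, 0.).build(&generator_vs, LEARNING_RATE)?;

    let mut discriminator_vs = nn::VarStore::new(device);
    // ... build the discriminator and its optimizer, then run the training loop
    // alternating discriminator and generator updates ...
    Ok(())
}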
