Implementation:LaurentMazare Tch rs GAN Example
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Generative Models, Computer Vision |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements a relativistic deep convolutional GAN (a DCGAN trained with a relativistic loss) that generates 64x64 RGB images, using separate generator and discriminator networks trained adversarially.
Description
This example implements a relativistic DCGAN following the approach of the AlexiaJM/RelativisticGAN repository. The architecture consists of two networks:
Generator: a five-layer transposed convolutional network that maps a 128-dimensional latent vector (shaped [N, 128, 1, 1]) to a 64x64 RGB image with values in [-1, 1]. Every layer uses a 4x4 kernel without bias:
- TransConv2D(128 -> 1024, stride=1, padding=0) + BatchNorm + ReLU
- TransConv2D(1024 -> 512, stride=2, padding=1) + BatchNorm + ReLU
- TransConv2D(512 -> 256, stride=2, padding=1) + BatchNorm + ReLU
- TransConv2D(256 -> 128, stride=2, padding=1) + BatchNorm + ReLU
- TransConv2D(128 -> 3, stride=2, padding=1) + Tanh
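The layer list can be checked against the usual transposed-convolution size formula, out = (in - 1) * stride - 2 * padding + kernel (dilation 1, no output padding). This std-only sketch assumes a 4x4 kernel on every layer, matching the `conv_transpose2d(..., 4, ...)` call in the usage example. `tr_out_size`, `GENERATOR_LAYERS` and `generator_sizes` are hypothetical helpers and are not part of tch-rs.

```rust
// Hypothetical std-only sketch; not part of tch-rs.

/// Output height/width of a 2D transposed convolution (dilation 1, no output padding).
fn tr_out_size(input: i64, kernel: i64, stride: i64, padding: i64) -> i64 {
    (input - 1) * stride - 2 * padding + kernel
}

/// (c_in, c_out, stride, padding) for each generator layer, as listed above.
const GENERATOR_LAYERS: [(i64, i64, i64, i64); 5] = [
    (128, 1024, 1, 0),
    (1024, 512, 2, 1),
    (512, 256, 2, 1),
    (256, 128, 2, 1),
    (128, 3, 2, 1),
];

/// Spatial size after each layer, starting from a latent of shape [N, 128, side, side].
fn generator_sizes(side: i64) -> Vec<i64> {
    let kernel = 4; // assumed 4x4 kernels throughout
    let mut size = side;
    GENERATOR_LAYERS
        .iter()
        .map(|&(_c_in, _c_out, stride, padding)| {
            size = tr_out_size(size, kernel, stride, padding);
            size
        })
        .collect()
}
```

Starting from a 1x1 latent, the sizes come out as 4, 8, 16, 32 and 64, so the last layer produces the 64x64 image described above.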
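Discriminator: a five-layer convolutional network that maps a 64x64 RGB image to one unbounded score per image (output shape [N, 1, 1, 1], with no final sigmoid). Every layer uses a 4x4 kernel without bias. Leaky ReLU with negative slope 0.2 follows each hidden layer, and layers 2 to 4 also apply batch normalization: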
- Conv2D(3 -> 128, stride=2, padding=1) + LeakyReLU
- Conv2D(128 -> 256, stride=2, padding=1) + BatchNorm + LeakyReLU
- Conv2D(256 -> 512, stride=2, padding=1) + BatchNorm + LeakyReLU
- Conv2D(512 -> 1024, stride=2, padding=1) + BatchNorm + LeakyReLU
- Conv2D(1024 -> 1, stride=1, padding=0)
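The same check works for the discriminator with the ordinary convolution formula, out = (in + 2 * padding - kernel) / stride + 1, again assuming 4x4 kernels. The sketch also writes the leaky ReLU in scalar form as max(x, 0.2x). For any finite x this equals a leaky ReLU with slope 0.2. All helper names are hypothetical.

```rust
// Hypothetical std-only sketch; not part of tch-rs.

/// Output height/width of a 2D convolution (dilation 1).
fn conv_out_size(input: i64, kernel: i64, stride: i64, padding: i64) -> i64 {
    (input + 2 * padding - kernel) / stride + 1
}

/// (stride, padding) for each discriminator layer, as listed above.
const DISCRIMINATOR_LAYERS: [(i64, i64); 5] = [(2, 1), (2, 1), (2, 1), (2, 1), (1, 0)];

/// Spatial size after each layer for a square input image of the given side.
fn discriminator_sizes(side: i64) -> Vec<i64> {
    let kernel = 4; // assumed 4x4 kernels throughout
    let mut size = side;
    DISCRIMINATOR_LAYERS
        .iter()
        .map(|&(stride, padding)| {
            size = conv_out_size(size, kernel, stride, padding);
            size
        })
        .collect()
}

/// Leaky ReLU with negative slope 0.2: x for x >= 0, 0.2 * x otherwise.
fn leaky_relu(x: f64) -> f64 {
    x.max(0.2 * x)
}
```

A 64x64 input shrinks to 32, 16, 8, 4 and finally 1, which is why the last layer yields a single score per image.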
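Training uses a relativistic average least-squares loss (RaLSGAN). The discriminator loss is the MSE between the real scores and the mean fake score plus 1, added to the MSE between the fake scores and the mean real score minus 1. The generator loss has the same form with the +1 and -1 swapped. Each update samples a fresh random batch of 32 real images and 32 latent vectors. The generator and discriminator use separate VarStores: the generator store is frozen during the discriminator update, and the discriminator store is frozen during the generator update. Both networks use the Adam optimizer with betas (0.5, 0.999), no weight decay, and learning rate 1e-4.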
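The loss described above can be written out on plain score vectors. This std-only sketch assumes that `mse_to` behaves like `mse_loss` against a broadcast scalar target (the mean of squared differences) and that each batch of scores is flattened into a slice. The function names are hypothetical. In the tch-rs code the same arithmetic runs on `Tensor`s so that gradients can flow through it.

```rust
// Hypothetical std-only sketch of the relativistic average least-squares loss.

/// Mean of a non-empty slice.
fn mean(xs: &[f64]) -> f64 {
    xs.iter().sum::<f64>() / xs.len() as f64
}

/// Mean squared error between every score and a single scalar target.
fn mse_to(xs: &[f64], target: f64) -> f64 {
    xs.iter().map(|x| (x - target).powi(2)).sum::<f64>() / xs.len() as f64
}

/// Discriminator loss: real scores should sit 1 above the mean fake score,
/// and fake scores should sit 1 below the mean real score.
fn discriminator_loss(real: &[f64], fake: &[f64]) -> f64 {
    mse_to(real, mean(fake) + 1.0) + mse_to(fake, mean(real) - 1.0)
}

/// Generator loss: the same form with the +1 and -1 targets swapped.
fn generator_loss(real: &[f64], fake: &[f64]) -> f64 {
    mse_to(real, mean(fake) - 1.0) + mse_to(fake, mean(real) + 1.0)
}
```

For real = [1, 1] and fake = [0, 0], every real score already sits 1 above the mean fake score, so the discriminator loss is 0. The generator loss for the same batch is 8, since the generator wants the opposite ordering.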
Every 1000 batches, the generator is run on a fixed batch of noise vectors and 16 of the results are saved as a 4x4 image grid.
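Saving the grid takes two steps: tanh outputs in [-1, 1] are mapped to bytes, and 16 images are tiled into a 4x4 mosaic. This std-only sketch assumes a byte mapping of (x + 1) * 127.5 clamped to [0, 255]. It also assumes row-major tile order and single-channel images to keep things short. `to_byte` and `tile_grid` are hypothetical helpers. The actual `image_matrix` works on `Tensor`s and may order tiles differently.

```rust
// Hypothetical std-only sketch; not the tch-rs `image_matrix` implementation.

/// Map a tanh output in [-1, 1] to a pixel byte; out-of-range values are clamped.
fn to_byte(x: f64) -> u8 {
    ((x + 1.0) * 127.5).clamp(0.0, 255.0) as u8
}

/// Tile `sz * sz` single-channel images of size h x w into one (sz*h) x (sz*w)
/// row-major buffer. Image k goes to tile row k / sz, tile column k % sz.
fn tile_grid(images: &[Vec<u8>], sz: usize, h: usize, w: usize) -> Vec<u8> {
    assert_eq!(images.len(), sz * sz, "need exactly sz * sz images");
    let width = sz * w;
    let mut grid = vec![0u8; sz * h * width];
    for (k, img) in images.iter().enumerate() {
        assert_eq!(img.len(), h * w, "each image must have h * w pixels");
        let (tile_row, tile_col) = (k / sz, k % sz);
        for y in 0..h {
            for x in 0..w {
                let gy = tile_row * h + y;
                let gx = tile_col * w + x;
                grid[gy * width + gx] = img[y * w + x];
            }
        }
    }
    grid
}
```

With sz = 4 and 64x64 images, the result is a 256x256 mosaic. Each tile of that mosaic is one generated sample.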
Usage
Use this example to learn how to implement GANs with tch-rs, including adversarial training with separate variable stores, transposed convolutions for image generation, and relativistic GAN losses.
Code Reference
Source Location
- Repository: LaurentMazare/tch-rs
- File: examples/gan/main.rs
- Lines: 1-148
Signature
```rust
fn tr2d(p: nn::Path, c_in: i64, c_out: i64, padding: i64, stride: i64) -> nn::ConvTranspose2D
fn conv2d(p: nn::Path, c_in: i64, c_out: i64, padding: i64, stride: i64) -> nn::Conv2D
fn generator(p: nn::Path) -> impl nn::ModuleT
fn leaky_relu(xs: &Tensor) -> Tensor
fn discriminator(p: nn::Path) -> impl nn::ModuleT
fn mse_loss(x: &Tensor, y: &Tensor) -> Tensor
fn image_matrix(imgs: &Tensor, sz: i64) -> Result<Tensor>
pub fn main() -> Result<()>
```
Import
```rust
// Standalone binary example. Run with:
// cargo run --example gan -- <image-dataset-dir>
use anyhow::{bail, Result};
use tch::{kind, nn, nn::OptimizerConfig, Device, Kind, Tensor};
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| image-dataset-dir | CLI argument (String) | Yes | Path to a directory of training images, loaded and resized to 64x64. |
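The table does not say what pixel range the networks see. A common choice, and to our reading the one this example makes, is to rescale bytes to [-1, 1] so that real images share the generator's tanh range. This sketch assumes the mapping x / 127.5 - 1. `from_byte` and `normalize` are hypothetical helper names.

```rust
// Hypothetical std-only sketch of the assumed input normalization.

/// Rescale a pixel byte in [0, 255] to a float in [-1, 1].
fn from_byte(p: u8) -> f64 {
    p as f64 / 127.5 - 1.0
}

/// Normalize a flat buffer of pixel bytes.
fn normalize(pixels: &[u8]) -> Vec<f64> {
    pixels.iter().map(|&p| from_byte(p)).collect()
}
```

Black (0) maps to -1 and white (255) maps to 1. These are the extremes of the tanh output the generator produces.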
Outputs
| Name | Type | Description |
|---|---|---|
| reloutN.png | PNG image files | 4x4 grid of generated images, saved every 1000 batches; N is the batch index (0, 1000, 2000, ...). |
| stdout | Text | Batch index printed every 100 batches. |
Usage Examples
```rust
use anyhow::{bail, Result};
use tch::{kind, nn, nn::OptimizerConfig, Device, Kind, Tensor};

const IMG_SIZE: i64 = 64;
const LATENT_DIM: i64 = 128;
const BATCH_SIZE: i64 = 32;
const LEARNING_RATE: f64 = 1e-4;

// Generator: maps [N, LATENT_DIM, 1, 1] latent vectors to [N, 3, 64, 64] images.
fn generator(p: nn::Path) -> impl nn::ModuleT {
    // 4x4 transposed convolutions without bias; stride and padding vary per layer.
    let cfg = |s, pad| nn::ConvTransposeConfig { stride: s, padding: pad, bias: false, ..Default::default() };
    nn::seq_t()
        .add(nn::conv_transpose2d(&p / "tr1", LATENT_DIM, 1024, 4, cfg(1, 0)))
        .add(nn::batch_norm2d(&p / "bn1", 1024, Default::default()))
        .add_fn(|xs| xs.relu())
        // ... additional layers ...
        .add_fn(|xs| xs.tanh())
}

// Adversarial training with separate variable stores.
pub fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    // Each network owns its own VarStore so it can be frozen while the other trains.
    let mut generator_vs = nn::VarStore::new(device);
    let generator = generator(generator_vs.root());
    // Adam with beta1 = 0.5, beta2 = 0.999 and no weight decay.
    let mut opt_g = nn::adam(0.5, 0.999, 0.).build(&generator_vs, LEARNING_RATE)?;
    let mut discriminator_vs = nn::VarStore::new(device);
    // ... training loop alternating generator/discriminator updates ...
    Ok(())
}
```