Implementation:LaurentMazare Tch rs Stable Diffusion Pipeline

From Leeroopedia


Overview

main.rs is a self-contained Rust implementation of the Stable Diffusion v1.4 text-to-image generation pipeline, located at examples/stable-diffusion/main.rs (2512 lines) in the tch-rs repository. It implements the complete inference pipeline from text prompt to generated image, including a BPE tokenizer, CLIP text encoder, Variational Autoencoder (VAE), UNet with cross-attention, and a DDIM noise scheduler. The implementation is inspired by Hugging Face's diffusers Python library and demonstrates how to build a full generative AI pipeline purely in Rust using tch-rs bindings to libtorch.

The file is organized as a single module containing all pipeline components. A TODO comment in the source notes that it should eventually be split into separate modules, similar to the huggingface/diffusers structure.

Code Reference

Global Configuration Constants

Constant                  Value    Purpose
VOCAB_SIZE                49,408   CLIP tokenizer vocabulary size
EMBED_DIM                 768      CLIP hidden size / text embedding dimension
INTERMEDIATE_SIZE         3,072    CLIP feed-forward intermediate dimension
MAX_POSITION_EMBEDDINGS   77       Maximum token sequence length for CLIP
NUM_HIDDEN_LAYERS         12       Number of transformer layers in CLIP
NUM_ATTENTION_HEADS       12       Number of attention heads in CLIP
HEIGHT                    512      Output image height in pixels
WIDTH                     512      Output image width in pixels
GUIDANCE_SCALE            7.5      Classifier-free guidance scale
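
As a rough illustration, the constants above might be declared as follows. The Rust types and the two shape helpers (text_embedding_shape, latent_shape) are assumptions made for this sketch and are not taken from main.rs; only the values come from the table.

```rust
// Values copied from the table above; the integer/float types are assumed.
const VOCAB_SIZE: i64 = 49408;
const EMBED_DIM: i64 = 768;
const INTERMEDIATE_SIZE: i64 = 3072;
const MAX_POSITION_EMBEDDINGS: i64 = 77;
const NUM_HIDDEN_LAYERS: i64 = 12;
const NUM_ATTENTION_HEADS: i64 = 12;
const HEIGHT: i64 = 512;
const WIDTH: i64 = 512;
const GUIDANCE_SCALE: f64 = 7.5;

/// Hypothetical helper: shape of the CLIP output for a given batch size.
fn text_embedding_shape(batch: i64) -> [i64; 3] {
    [batch, MAX_POSITION_EMBEDDINGS, EMBED_DIM]
}

/// Hypothetical helper: shape of the latent tensor, using the documented
/// 4 latent channels and 1/8 spatial resolution.
fn latent_shape(batch: i64) -> [i64; 4] {
    [batch, 4, HEIGHT / 8, WIDTH / 8]
}
```

With these values the feed-forward width is four times the embedding width (3,072 = 4 x 768), the usual transformer ratio.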

BPE Tokenizer

struct Tokenizer {
    re: regex::Regex,
    encoder: HashMap<String, usize>,
    decoder: HashMap<usize, String>,
    bpe_ranks: HashMap<(String, String), usize>,
    start_of_text_token: usize,
    end_of_text_token: usize,
}

A byte-pair encoding tokenizer compatible with OpenAI's CLIP tokenizer. Loads vocabulary from data/bpe_simple_vocab_16e6.txt.

Key methods:

  • fn create(bpe_path: T) -> anyhow::Result<Tokenizer> -- Parses the BPE vocabulary file, constructs encoder/decoder maps and BPE merge ranks. Builds the vocabulary from a byte-to-unicode mapping (256 entries), their end-of-word variants, all BPE merges, and special tokens (<|startoftext|>, <|endoftext|>).
  • fn bpe(&self, token: &str) -> Vec<usize> -- Applies BPE merges to a single token, iteratively merging the highest-priority character pair until no more merges apply.
  • fn encode(&self, s: &str, pad_size_to: Option<usize>) -> anyhow::Result<Vec<usize>> -- Lowercases the input, tokenizes with a regex pattern, applies BPE to each token, and pads/truncates to the specified length with end-of-text tokens.
  • fn decode(&self, tokens: &[usize]) -> String -- Decodes token IDs back to text, replacing </w> markers with spaces.

CLIP Text Encoder

The CLIP text model is implemented as a standard transformer encoder with these components:

ClipTextEmbeddings:

struct ClipTextEmbeddings {
    token_embedding: nn::Embedding,
    position_embedding: nn::Embedding,
    position_ids: Tensor,
}

Combines token embeddings (vocab size 49,408, dim 768) with learned positional embeddings (max length 77).

ClipAttention:

struct ClipAttention {
    k_proj: nn::Linear,
    v_proj: nn::Linear,
    q_proj: nn::Linear,
    out_proj: nn::Linear,
    head_dim: i64,
    scale: f64,
    num_attention_heads: i64,
}

Multi-head self-attention with separate Q, K, V projections and a causal attention mask.
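
To make the head_dim, scale, and causal-mask fields concrete, here is a small plain-Rust sketch. It assumes the conventional 1/sqrt(head_dim) scale and an additive mask filled with negative infinity above the diagonal; main.rs builds its mask as a tensor and may use a different fill value. Both function names are hypothetical.

```rust
/// Per-head width and the conventional 1/sqrt(head_dim) attention scale.
fn head_dim_and_scale(embed_dim: i64, num_heads: i64) -> (i64, f64) {
    let head_dim = embed_dim / num_heads;
    (head_dim, (head_dim as f64).powf(-0.5))
}

/// Additive causal mask of size seq_len x seq_len: 0.0 where query position q
/// may attend to key position k (k <= q), negative infinity for future keys.
fn causal_mask(seq_len: usize) -> Vec<Vec<f32>> {
    (0..seq_len)
        .map(|q| {
            (0..seq_len)
                .map(|k| if k > q { f32::NEG_INFINITY } else { 0.0 })
                .collect()
        })
        .collect()
}
```

For the documented configuration (768 dimensions, 12 heads) this gives a head width of 64 and a scale of 0.125; the mask is added to the attention scores before the softmax so that each token only sees itself and earlier tokens.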

ClipEncoderLayer -- Combines ClipAttention and ClipMlp (using quick_gelu activation) with layer normalization in a pre-norm arrangement.
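
The quick_gelu activation is the sigmoid-based GELU approximation x * sigmoid(1.702 * x). A scalar sketch follows; the source applies the same formula element-wise to tensors.

```rust
/// quick_gelu(x) = x * sigmoid(1.702 * x), written as x / (1 + exp(-1.702 * x)).
fn quick_gelu(x: f64) -> f64 {
    x / (1.0 + (-1.702 * x).exp())
}
```

The function is close to the identity for large positive inputs and close to zero for large negative inputs.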

ClipEncoder -- A stack of NUM_HIDDEN_LAYERS (12) encoder layers.

ClipTextTransformer:

struct ClipTextTransformer {
    embeddings: ClipTextEmbeddings,
    encoder: ClipEncoder,
    final_layer_norm: nn::LayerNorm,
}

The complete CLIP text model. The forward method applies embeddings, runs through all encoder layers, and applies final layer normalization. Outputs shape [batch, 77, 768].

Variational Autoencoder (VAE)

AutoEncoderKL:

struct AutoEncoderKL {
    encoder: Encoder,
    decoder: Decoder,
    quant_conv: nn::Conv2D,
    post_quant_conv: nn::Conv2D,
    config: AutoEncoderKLConfig,
}

The VAE maps between pixel space (512x512x3) and latent space (64x64x4).

Key methods:

  • fn new(vs, in_channels, out_channels, config) -> Self -- Constructs encoder, decoder, and quantization convolutions.
  • fn encode(&self, xs: &Tensor) -> Tensor -- Encodes an image into latent distribution parameters.
  • fn decode(&self, xs: &Tensor) -> Tensor -- Decodes latent vectors back to pixel space via post_quant_conv -> decoder.

AutoEncoderKLConfig:

struct AutoEncoderKLConfig {
    block_out_channels: Vec<i64>,  // default: [128, 256, 512, 512]
    layers_per_block: i64,          // default: 2
    latent_channels: i64,           // default: 4
    norm_num_groups: i64,           // default: 32
}

The Encoder and Decoder are built from DownEncoderBlock2D and UpDecoderBlock2D blocks respectively, each containing ResNet blocks with group normalization. A UNetMidBlock2D sits between the down and up paths.
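
The 512-to-64 spatial reduction follows from the configuration if every encoder block except the last halves the resolution, as in the diffusers design this file is modelled on. That downsampling rule is an assumption about main.rs, and VaeConfigSketch and latent_shape are hypothetical names used only for this sketch.

```rust
/// Hypothetical mirror of the two config fields that determine the latent shape.
struct VaeConfigSketch {
    block_out_channels: Vec<i64>,
    latent_channels: i64,
}

/// Returns (channels, height, width) of the latent for a given image size.
/// Assumption: all encoder blocks but the last downsample by a factor of 2.
fn latent_shape(cfg: &VaeConfigSketch, height: i64, width: i64) -> (i64, i64, i64) {
    let num_downsamples = cfg.block_out_channels.len().saturating_sub(1);
    let factor = 1i64 << num_downsamples;
    (cfg.latent_channels, height / factor, width / factor)
}
```

With the default four blocks the factor is 2^3 = 8, so a 512x512 image corresponds to a 4x64x64 latent.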

UNet with Cross-Attention

UNet2DConditionModel:

struct UNet2DConditionModel {
    conv_in: nn::Conv2D,
    time_proj: Timesteps,
    time_embedding: TimestepEmbedding,
    down_blocks: Vec<UNetDownBlock>,
    mid_block: UNetMidBlock2DCrossAttn,
    up_blocks: Vec<UNetUpBlock>,
    conv_norm_out: nn::GroupNorm,
    conv_out: nn::Conv2D,
    config: UNet2DConditionModelConfig,
}

The noise prediction network conditioned on text embeddings and diffusion timestep.

Architecture details:

  • Input: 4-channel latent tensor of shape [B, 4, 64, 64]
  • Time embedding: Sinusoidal Timesteps projected through a 2-layer MLP (TimestepEmbedding) to dimension 4 * base_channels
  • Down blocks: 4 blocks with channel dimensions [320, 640, 1280, 1280]. The first 3 use cross-attention (CrossAttnDownBlock2D); the last is a basic DownBlock2D
  • Mid block: UNetMidBlock2DCrossAttn at 1280 channels with cross-attention to text embeddings
  • Up blocks: 4 blocks mirroring the down blocks with skip connections
  • Output: 4-channel tensor, same spatial dimensions as input
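
The sinusoidal time projection in the list above can be sketched as below. The ordering (cosine half first, then sine half) and the absence of a frequency shift are assumptions for illustration; the Timesteps type in main.rs may order or shift the terms differently. The function name is hypothetical.

```rust
/// Sinusoidal embedding of a scalar timestep into `dim` values.
/// Frequencies decay geometrically from 1 down towards 1/10000.
fn timestep_embedding(timestep: f64, dim: usize) -> Vec<f64> {
    let half = dim / 2;
    let log_max_period = 10000f64.ln();
    let freqs: Vec<f64> = (0..half)
        .map(|i| (-log_max_period * (i as f64) / (half as f64)).exp())
        .collect();
    let mut out = Vec::with_capacity(2 * half);
    out.extend(freqs.iter().map(|f| (timestep * f).cos())); // first half: cos
    out.extend(freqs.iter().map(|f| (timestep * f).sin())); // second half: sin
    out
}
```

With base_channels = 320 the projection has 320 entries, which the two-layer MLP then expands to 4 * 320 = 1280, matching the widest UNet blocks.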

Key methods:

  • fn forward(&self, xs: &Tensor, timestep: f64, encoder_hidden_states: &Tensor) -> Tensor -- Runs the full UNet pass with timestep conditioning and cross-attention to text embeddings. Manages skip connections between down and up blocks.

Supporting types:

  • CrossAttention -- Implements self-attention when no context is supplied and cross-attention when a separate context (here, the text embeddings) is supplied; the context may have a different feature dimension from the query
  • BasicTransformerBlock -- Combines self-attention, cross-attention, and feed-forward with layer norms
  • SpatialTransformer -- Applies transformer blocks to spatial feature maps (reshape to sequence, transform, reshape back)
  • ResnetBlock2D -- ResNet block with time embedding injection and optional shortcut projection
  • GeGlu -- Gated linear unit with GELU activation, used in feed-forward layers
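
As an illustration of the GeGlu gating, the sketch below takes an already-projected vector of even length, treats the first half as values and the second half as gates, and returns value * gelu(gate). The tanh approximation of GELU is used here only because the Rust standard library has no erf; the source relies on libtorch's GELU. The function names are hypothetical.

```rust
/// Tanh approximation of GELU (an assumption of this sketch).
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

/// GeGLU over a projected vector: split into [values | gates] and gate the values.
fn geglu(projected: &[f64]) -> Vec<f64> {
    assert!(projected.len() % 2 == 0, "projection must have even length");
    let (values, gates) = projected.split_at(projected.len() / 2);
    values
        .iter()
        .zip(gates.iter())
        .map(|(v, g)| v * gelu_tanh(*g))
        .collect()
}
```

A gate of 0 suppresses its value entirely, while a large positive gate passes the value through multiplied by roughly the gate itself.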

DDIM Scheduler

struct DDIMScheduler {
    timesteps: Vec<usize>,
    alphas_cumprod: Vec<f64>,
    step_ratio: usize,
    config: DDIMSchedulerConfig,
}

Implements the Denoising Diffusion Implicit Models (DDIM) scheduler for accelerated sampling.

DDIMSchedulerConfig:

struct DDIMSchedulerConfig {
    beta_start: f64,    // default: 0.00085
    beta_end: f64,      // default: 0.012
    beta_schedule: BetaSchedule,  // default: ScaledLinear
    eta: f64,           // default: 0.0 (deterministic)
}
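
A sketch of how these defaults translate into a noise schedule, assuming the ScaledLinear definition used by diffusers (betas linear in sqrt-space, then squared) and an evenly strided choice of inference timesteps. The function names are hypothetical, and main.rs may differ in details such as a timestep offset.

```rust
/// ScaledLinear schedule: interpolate linearly between sqrt(beta_start) and
/// sqrt(beta_end), then square each value.
fn scaled_linear_betas(beta_start: f64, beta_end: f64, train_timesteps: usize) -> Vec<f64> {
    let (start, end) = (beta_start.sqrt(), beta_end.sqrt());
    let denom = (train_timesteps.max(2) - 1) as f64;
    (0..train_timesteps)
        .map(|i| {
            let b = start + (end - start) * (i as f64) / denom;
            b * b
        })
        .collect()
}

/// Cumulative product of alpha_t = 1 - beta_t.
fn alphas_cumprod(betas: &[f64]) -> Vec<f64> {
    let mut acc = 1.0;
    betas
        .iter()
        .map(|b| {
            acc *= 1.0 - b;
            acc
        })
        .collect()
}

/// Evenly strided inference timesteps in descending order (assumed layout).
fn inference_timesteps(inference_steps: usize, train_timesteps: usize) -> Vec<usize> {
    let step_ratio = train_timesteps / inference_steps;
    (0..inference_steps).rev().map(|i| i * step_ratio).collect()
}
```

With 1000 training timesteps and 30 inference steps, the stride is 33 and the sampled timesteps run from 957 down to 0 under these assumptions.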

Key methods:

  • fn new(inference_steps, train_timesteps, config) -> Self -- Computes the noise schedule. Supports Linear and ScaledLinear beta schedules. Computes cumulative alpha products for all training timesteps.
  • fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor -- Performs one DDIM denoising step. When eta = 0 the process is deterministic; when eta > 0 stochastic noise is injected.
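
For eta = 0, one DDIM update can be written on plain slices as below. This follows the standard deterministic DDIM formula and is a sketch, not the tensor code from main.rs; the function name and the explicit alpha arguments are hypothetical (the real step looks the cumulative alphas up from the timestep).

```rust
/// Deterministic DDIM step (eta = 0).
/// alpha_prod_t / alpha_prod_prev are the cumulative alpha products at the
/// current and the previous (less noisy) timestep.
fn ddim_step_deterministic(
    model_output: &[f64],
    sample: &[f64],
    alpha_prod_t: f64,
    alpha_prod_prev: f64,
) -> Vec<f64> {
    sample
        .iter()
        .zip(model_output.iter())
        .map(|(x_t, eps)| {
            // Estimate of the clean sample implied by the predicted noise.
            let pred_x0 = (x_t - (1.0 - alpha_prod_t).sqrt() * eps) / alpha_prod_t.sqrt();
            // Re-noise the estimate to the previous timestep's noise level.
            alpha_prod_prev.sqrt() * pred_x0 + (1.0 - alpha_prod_prev).sqrt() * eps
        })
        .collect()
}
```

If the predicted noise is exact and alpha_prod_prev is 1.0, the step returns the clean sample, which is a convenient sanity check.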

Pipeline Builder Functions

fn build_clip_transformer(device: Device) -> anyhow::Result<ClipTextTransformer>
fn build_vae(device: Device) -> anyhow::Result<AutoEncoderKL>
fn build_unet(device: Device) -> anyhow::Result<UNet2DConditionModel>

Each function creates a VarStore on the specified device, constructs the model with the Stable Diffusion v1.4 configuration, and loads pre-trained weights from .ot files in the data/ directory.

Main Entry Point

fn main() -> anyhow::Result<()>

Executes the full text-to-image pipeline:

  1. Detects CUDA availability; accepts an optional text prompt and "cpu" flag from command-line arguments
  2. Creates a DDIM scheduler with 30 inference steps and 1000 training timesteps
  3. Tokenizes the prompt (and an empty unconditional prompt) to 77 tokens each
  4. Builds and loads the CLIP text encoder; computes text embeddings for both prompts
  5. Builds and loads the VAE and UNet
  6. Initializes random latents of shape [1, 4, 64, 64] with seed 32
  7. Iterates through scheduler timesteps:
    • Duplicates latents for classifier-free guidance (conditional + unconditional)
    • Predicts noise with the UNet conditioned on text embeddings
    • Applies guidance: noise = noise_uncond + scale * (noise_text - noise_uncond)
    • Updates latents via the DDIM scheduler step
    • Decodes latents through the VAE, scales to [0, 255] uint8, and saves as sd_{step}.png
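
The guidance formula and the final conversion to 8-bit pixels can be illustrated on plain values. The guidance function mirrors the formula in the list above; the pixel mapping assumes the decoder output lies roughly in [-1, 1] and is rescaled as (x / 2 + 0.5), clamped to [0, 1], and multiplied by 255, which is an assumption about main.rs rather than a quotation from it. Both function names are hypothetical.

```rust
/// Classifier-free guidance: noise_uncond + scale * (noise_text - noise_uncond).
fn apply_guidance(noise_uncond: &[f64], noise_text: &[f64], scale: f64) -> Vec<f64> {
    noise_uncond
        .iter()
        .zip(noise_text.iter())
        .map(|(u, t)| u + scale * (t - u))
        .collect()
}

/// Maps a decoder output value (assumed to be in about [-1, 1]) to a u8 pixel.
/// The cast truncates towards zero.
fn to_uint8(pixel: f64) -> u8 {
    ((pixel / 2.0 + 0.5).clamp(0.0, 1.0) * 255.0) as u8
}
```

A scale of 1.0 reproduces the conditional prediction unchanged; larger values such as the default 7.5 push the result further along the direction from the unconditional to the conditional prediction.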

I/O Contract

Inputs

  • Text prompt: A string describing the desired image (default: "A rusty robot holding a fire torch in its hand")
  • Device selection: Pass "cpu" as a command-line argument to force CPU inference; otherwise CUDA is used if available
  • Pre-trained weights: Three .ot weight files must be present in data/:
    • data/pytorch_model.ot -- CLIP text encoder weights
    • data/vae.ot -- VAE encoder/decoder weights
    • data/unet.ot -- UNet weights
  • Vocabulary file: data/bpe_simple_vocab_16e6.txt for the BPE tokenizer

Outputs

  • Generated images: One PNG file per denoising step, named sd_0.png through sd_29.png, each 512x512 pixels RGB
  • Console output: CUDA/cuDNN availability, decoded prompt text, token tensor, and per-step progress

Invariants

  • Token sequences are always padded or truncated to exactly 77 tokens (MAX_POSITION_EMBEDDINGS)
  • Latents operate at 1/8 spatial resolution (64x64 for 512x512 output)
  • The VAE scaling factor of 0.18215 is applied before decoding: latents / 0.18215
  • Classifier-free guidance requires two noise predictions per step (unconditional and conditional); the pipeline obtains both from a single UNet pass over the duplicated, batch-of-2 latents
  • All inference runs inside tch::no_grad_guard() to disable gradient computation
  • Generation is seeded with tch::manual_seed(32), so repeated runs on the same setup start from the same initial latents
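
The first invariant, fixed-length token sequences, can be sketched as a pad-or-truncate helper. The function name is hypothetical, and the sketch does not show how main.rs treats the start-of-text and end-of-text markers when truncating; it only illustrates padding with the end-of-text ID, assumed here to be 49407 (the last entry of the 49,408-token vocabulary).

```rust
/// Forces a token sequence to exactly `len` entries: longer inputs are cut,
/// shorter inputs are padded with `pad_token`.
fn pad_or_truncate(mut tokens: Vec<usize>, len: usize, pad_token: usize) -> Vec<usize> {
    tokens.truncate(len);
    tokens.resize(len, pad_token);
    tokens
}
```

For example, a three-token prompt padded to 77 entries keeps its three IDs at the front and carries 74 copies of the padding ID after them.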

Pipeline Data Flow

Text Prompt
    |
    v
[BPE Tokenizer] --> token IDs (77 integers)
    |
    v
[CLIP Text Encoder] --> text embeddings [1, 77, 768]
    |                    + unconditional embeddings [1, 77, 768]
    |                    --> concatenated [2, 77, 768]
    |
    v
[Random Latents] --> z ~ N(0,1), shape [1, 4, 64, 64]
    |
    v
[Denoising Loop (30 steps)]
    |
    |   For each timestep:
    |   1. Duplicate latents --> [2, 4, 64, 64]
    |   2. UNet(latents, timestep, text_embeddings) --> noise prediction [2, 4, 64, 64]
    |   3. Split into unconditional and conditional noise
    |   4. Apply classifier-free guidance
    |   5. DDIM step to update latents
    |   6. VAE decode --> image [1, 3, 512, 512]
    |   7. Save as PNG
    |
    v
Output images (sd_0.png ... sd_29.png)

Dependencies

  • tch -- Rust bindings to libtorch for tensor operations, neural network modules (nn), convolutions, attention, and model weight loading
  • regex -- Regular expression matching for BPE tokenization
  • anyhow -- Error handling throughout the pipeline
  • std::collections::{HashMap, HashSet} -- Vocabulary and BPE rank storage
  • std::io::BufRead -- Reading the BPE vocabulary file

Weight Preparation

Pre-trained weights must be converted from PyTorch .bin format to tch-rs .ot format:

  1. Download PyTorch weights from Hugging Face (CLIP, VAE, UNet)
  2. Convert to NumPy .npz files using a Python script
  3. Convert to .ot format using the tensor-tools example: cargo run --release --example tensor-tools cp ./data/model.npz ./data/model.ot
