Implementation:LaurentMazare Tch rs Stable Diffusion Pipeline
Overview
main.rs is a self-contained Rust implementation of the Stable Diffusion v1.4 text-to-image generation pipeline, located at examples/stable-diffusion/main.rs (2512 lines) in the tch-rs repository. It implements the complete inference pipeline from text prompt to generated image, including a BPE tokenizer, CLIP text encoder, Variational Autoencoder (VAE), UNet with cross-attention, and a DDIM noise scheduler. The implementation is inspired by Hugging Face's diffusers Python library and demonstrates how to build a full generative AI pipeline purely in Rust using tch-rs bindings to libtorch.
The file is organized as a single module containing all pipeline components. As noted in the source code's TODO comments, the authors acknowledge this should eventually be split into separate modules similar to the huggingface/diffusers structure.
Code Reference
Global Configuration Constants
| Constant | Value | Purpose |
|---|---|---|
| VOCAB_SIZE | 49,408 | CLIP tokenizer vocabulary size |
| EMBED_DIM | 768 | CLIP hidden size / text embedding dimension |
| INTERMEDIATE_SIZE | 3,072 | CLIP feed-forward intermediate dimension |
| MAX_POSITION_EMBEDDINGS | 77 | Maximum token sequence length for CLIP |
| NUM_HIDDEN_LAYERS | 12 | Number of transformer layers in CLIP |
| NUM_ATTENTION_HEADS | 12 | Number of attention heads in CLIP |
| HEIGHT | 512 | Output image height in pixels |
| WIDTH | 512 | Output image width in pixels |
| GUIDANCE_SCALE | 7.5 | Classifier-free guidance scale |
BPE Tokenizer
struct Tokenizer {
re: regex::Regex,
encoder: HashMap<String, usize>,
decoder: HashMap<usize, String>,
bpe_ranks: HashMap<(String, String), usize>,
start_of_text_token: usize,
end_of_text_token: usize,
}
A byte-pair encoding tokenizer compatible with OpenAI's CLIP tokenizer. Loads vocabulary from data/bpe_simple_vocab_16e6.txt.
Key methods:
- fn create(bpe_path: T) -> anyhow::Result<Tokenizer> -- Parses the BPE vocabulary file, constructs encoder/decoder maps and BPE merge ranks. Builds the vocabulary from a byte-to-unicode mapping (256 entries), their end-of-word variants, all BPE merges, and special tokens (<|startoftext|>, <|endoftext|>).
- fn bpe(&self, token: &str) -> Vec<usize> -- Applies BPE merges to a single token, iteratively merging the highest-priority character pair until no more merges apply.
- fn encode(&self, s: &str, pad_size_to: Option<usize>) -> anyhow::Result<Vec<usize>> -- Lowercases the input, tokenizes with a regex pattern, applies BPE to each token, and pads/truncates to the specified length with end-of-text tokens.
- fn decode(&self, tokens: &[usize]) -> String -- Decodes token IDs back to text, replacing </w> markers with spaces.
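The merge loop described for bpe can be sketched with the standard library alone. This is a simplified illustration rather than the code in main.rs: it returns merged symbols instead of token IDs, merges one occurrence per iteration, and the name bpe_merge is hypothetical.

```rust
use std::collections::HashMap;

/// Hypothetical helper: repeatedly merge the adjacent pair with the lowest
/// rank (that is, the highest priority) until no adjacent pair is in `ranks`.
fn bpe_merge(word: &[&str], ranks: &HashMap<(String, String), usize>) -> Vec<String> {
    let mut symbols: Vec<String> = word.iter().map(|s| s.to_string()).collect();
    loop {
        // Scan all adjacent pairs and remember the best-ranked one as (rank, position).
        let mut best: Option<(usize, usize)> = None;
        for i in 0..symbols.len().saturating_sub(1) {
            let pair = (symbols[i].clone(), symbols[i + 1].clone());
            if let Some(&rank) = ranks.get(&pair) {
                if best.map_or(true, |(r, _)| rank < r) {
                    best = Some((rank, i));
                }
            }
        }
        match best {
            Some((_, i)) => {
                // Replace the pair at position i with its concatenation.
                let merged = format!("{}{}", symbols[i], symbols[i + 1]);
                symbols[i] = merged;
                symbols.remove(i + 1);
            }
            None => break, // no mergeable pair left
        }
    }
    symbols
}
```

With toy ranks ("l", "o") -> 0 and ("lo", "w</w>") -> 1, the symbols l, o, w</w> collapse to the single symbol low</w>, which a real tokenizer would then look up in its encoder map.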
CLIP Text Encoder
The CLIP text model is implemented as a standard transformer encoder with these components:
ClipTextEmbeddings:
struct ClipTextEmbeddings {
token_embedding: nn::Embedding,
position_embedding: nn::Embedding,
position_ids: Tensor,
}
Combines token embeddings (vocab size 49,408, dim 768) with learned positional embeddings (max length 77).
ClipAttention:
struct ClipAttention {
k_proj: nn::Linear,
v_proj: nn::Linear,
q_proj: nn::Linear,
out_proj: nn::Linear,
head_dim: i64,
scale: f64,
num_attention_heads: i64,
}
Multi-head self-attention with separate Q, K, V projections and a causal attention mask.
ClipEncoderLayer -- Combines ClipAttention and ClipMlp (using quick_gelu activation) with layer normalization in a pre-norm arrangement.
ClipEncoder -- A stack of NUM_HIDDEN_LAYERS (12) encoder layers.
ClipTextTransformer:
struct ClipTextTransformer {
embeddings: ClipTextEmbeddings,
encoder: ClipEncoder,
final_layer_norm: nn::LayerNorm,
}
The complete CLIP text model. The forward method applies the embeddings, runs the result through all encoder layers, and applies the final layer normalization. The output has shape [batch, 77, 768].
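Three small pieces of the encoder described above can be written out on plain numbers: the attention scale derived from head_dim, an additive causal mask, and the quick_gelu activation. This is a sketch under assumptions: the function names are hypothetical, the source operates on tch tensors rather than Vec<f64>, and it may use a large negative constant instead of negative infinity in the mask.

```rust
/// Scale applied to attention scores: 1 / sqrt(head_dim),
/// where head_dim = embed_dim / num_heads.
fn attention_scale(embed_dim: i64, num_heads: i64) -> f64 {
    let head_dim = embed_dim / num_heads;
    1.0 / (head_dim as f64).sqrt()
}

/// Additive causal mask: 0.0 where position i may attend to position j (j <= i),
/// negative infinity otherwise, so masked scores vanish after softmax.
fn causal_mask(seq_len: usize) -> Vec<Vec<f64>> {
    (0..seq_len)
        .map(|i| {
            (0..seq_len)
                .map(|j| if j <= i { 0.0 } else { f64::NEG_INFINITY })
                .collect()
        })
        .collect()
}

/// quick_gelu(x) = x * sigmoid(1.702 * x)
fn quick_gelu(x: f64) -> f64 {
    x / (1.0 + (-1.702 * x).exp())
}
```

For the constants listed earlier (EMBED_DIM = 768, NUM_ATTENTION_HEADS = 12), head_dim is 64 and the scale is 0.125.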
Variational Autoencoder (VAE)
AutoEncoderKL:
struct AutoEncoderKL {
encoder: Encoder,
decoder: Decoder,
quant_conv: nn::Conv2D,
post_quant_conv: nn::Conv2D,
config: AutoEncoderKLConfig,
}
The VAE maps between pixel space (512x512x3) and latent space (64x64x4).
Key methods:
- fn new(vs, in_channels, out_channels, config) -> Self -- Constructs encoder, decoder, and quantization convolutions.
- fn encode(&self, xs: &Tensor) -> Tensor -- Encodes an image into latent distribution parameters.
- fn decode(&self, xs: &Tensor) -> Tensor -- Decodes latent vectors back to pixel space via post_quant_conv -> decoder.
AutoEncoderKLConfig:
struct AutoEncoderKLConfig {
block_out_channels: Vec<i64>, // default: [128, 256, 512, 512]
layers_per_block: i64, // default: 2
latent_channels: i64, // default: 4
norm_num_groups: i64, // default: 32
}
The Encoder and Decoder are built from DownEncoderBlock2D and UpDecoderBlock2D blocks respectively, each containing ResNet blocks with group normalization. A UNetMidBlock2D follows the down blocks in the Encoder and precedes the up blocks in the Decoder.
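The 512-to-64 spatial reduction follows from the number of encoder blocks. The sketch below assumes, as in the diffusers design this file is modeled on, that every block except the last halves the resolution; that detail is not stated above, and the helper name latent_side is hypothetical.

```rust
/// Side length of the latent grid, assuming each of the `num_blocks`
/// encoder blocks except the last downsamples by a factor of 2.
fn latent_side(image_side: i64, num_blocks: usize) -> i64 {
    let downsamples = num_blocks.saturating_sub(1) as u32;
    image_side / 2_i64.pow(downsamples)
}
```

With the default block_out_channels of four entries, latent_side(512, 4) gives 64, matching the 64x64x4 latent space.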
UNet with Cross-Attention
UNet2DConditionModel:
struct UNet2DConditionModel {
conv_in: nn::Conv2D,
time_proj: Timesteps,
time_embedding: TimestepEmbedding,
down_blocks: Vec<UNetDownBlock>,
mid_block: UNetMidBlock2DCrossAttn,
up_blocks: Vec<UNetUpBlock>,
conv_norm_out: nn::GroupNorm,
conv_out: nn::Conv2D,
config: UNet2DConditionModelConfig,
}
The noise prediction network conditioned on text embeddings and diffusion timestep.
Architecture details:
- Input: 4-channel latent tensor of shape [B, 4, 64, 64]
- Time embedding: Sinusoidal Timesteps projected through a 2-layer MLP (TimestepEmbedding) to dimension 4 * base_channels
- Down blocks: 4 blocks with channel dimensions [320, 640, 1280, 1280]. The first 3 use cross-attention (CrossAttnDownBlock2D); the last is a basic DownBlock2D
- Mid block: UNetMidBlock2DCrossAttn at 1280 channels with cross-attention to text embeddings
- Up blocks: 4 blocks mirroring the down blocks with skip connections
- Output: 4-channel tensor, same spatial dimensions as input
Key methods:
fn forward(&self, xs: &Tensor, timestep: f64, encoder_hidden_states: &Tensor) -> Tensor-- Runs the full UNet pass with timestep conditioning and cross-attention to text embeddings. Manages skip connections between down and up blocks.
Supporting types:
- CrossAttention -- Implements both self-attention and cross-attention (when the context dimension differs from the query dimension)
- BasicTransformerBlock -- Combines self-attention, cross-attention, and feed-forward with layer norms
- SpatialTransformer -- Applies transformer blocks to spatial feature maps (reshape to sequence, transform, reshape back)
- ResnetBlock2D -- ResNet block with time embedding injection and optional shortcut projection
- GeGlu -- Gated linear unit with GELU activation, used in feed-forward layers
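The sinusoidal projection performed by Timesteps can be illustrated on a single scalar timestep. This is a generic sketch: the sine-then-cosine ordering, the max_period parameter, and the function name are assumptions, and the source may order or shift the frequencies differently.

```rust
/// Generic sinusoidal embedding of a scalar timestep `t` into `dim` values:
/// the first half holds sines, the second half cosines, over geometrically
/// spaced frequencies from 1 down to roughly 1 / max_period.
fn timestep_embedding(t: f64, dim: usize, max_period: f64) -> Vec<f64> {
    let half = dim / 2;
    let freqs: Vec<f64> = (0..half)
        .map(|i| (-max_period.ln() * i as f64 / half as f64).exp())
        .collect();
    let mut out = Vec::with_capacity(2 * half);
    out.extend(freqs.iter().map(|f| (t * f).sin()));
    out.extend(freqs.iter().map(|f| (t * f).cos()));
    out
}
```

For a base channel count of 320, the result has 320 entries; TimestepEmbedding would then project it to 4 * 320 = 1280 dimensions.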
DDIM Scheduler
struct DDIMScheduler {
timesteps: Vec<usize>,
alphas_cumprod: Vec<f64>,
step_ratio: usize,
config: DDIMSchedulerConfig,
}
Implements the Denoising Diffusion Implicit Models (DDIM) scheduler for accelerated sampling.
DDIMSchedulerConfig:
struct DDIMSchedulerConfig {
beta_start: f64, // default: 0.00085
beta_end: f64, // default: 0.012
beta_schedule: BetaSchedule, // default: ScaledLinear
eta: f64, // default: 0.0 (deterministic)
}
Key methods:
- fn new(inference_steps, train_timesteps, config) -> Self -- Computes the noise schedule. Supports Linear and ScaledLinear beta schedules. Computes cumulative alpha products for all training timesteps.
- fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor -- Performs one DDIM denoising step. When eta = 0 the process is deterministic; when eta > 0 stochastic noise is injected.
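The schedule and the deterministic update can be sketched on plain f64 values. The code below is an illustration under assumptions, not the scheduler in main.rs: the function names are hypothetical, the timestep spacing uses a plain integer step_ratio with no offset, and only the eta = 0 case is shown.

```rust
/// Cumulative products of (1 - beta_t) for a scaled-linear schedule: betas are
/// the squares of values spaced linearly between sqrt(beta_start) and sqrt(beta_end).
fn alphas_cumprod(beta_start: f64, beta_end: f64, train_timesteps: usize) -> Vec<f64> {
    let (lo, hi) = (beta_start.sqrt(), beta_end.sqrt());
    let denom = (train_timesteps.max(2) - 1) as f64;
    let mut prod = 1.0;
    (0..train_timesteps)
        .map(|i| {
            let beta = (lo + (i as f64 / denom) * (hi - lo)).powi(2);
            prod *= 1.0 - beta;
            prod
        })
        .collect()
}

/// Evenly spaced training timesteps visited during inference, largest first.
fn ddim_timesteps(inference_steps: usize, train_timesteps: usize) -> Vec<usize> {
    if inference_steps == 0 {
        return Vec::new();
    }
    let step_ratio = train_timesteps / inference_steps;
    (0..inference_steps).map(|s| s * step_ratio).rev().collect()
}

/// One deterministic DDIM update (eta = 0), applied element-wise:
/// estimate x0 from the predicted noise, then re-noise to the previous level.
fn ddim_step(sample: &[f64], eps: &[f64], alpha_t: f64, alpha_prev: f64) -> Vec<f64> {
    sample
        .iter()
        .zip(eps)
        .map(|(&x, &e)| {
            let pred_x0 = (x - (1.0 - alpha_t).sqrt() * e) / alpha_t.sqrt();
            alpha_prev.sqrt() * pred_x0 + (1.0 - alpha_prev).sqrt() * e
        })
        .collect()
}
```

With 30 inference steps over 1000 training timesteps, step_ratio is 33 and the visited timesteps run from 957 down to 0. Passing alpha_prev = 1.0 to ddim_step returns the predicted clean sample directly, which is a convenient sanity check.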
Pipeline Builder Functions
fn build_clip_transformer(device: Device) -> anyhow::Result<ClipTextTransformer>
fn build_vae(device: Device) -> anyhow::Result<AutoEncoderKL>
fn build_unet(device: Device) -> anyhow::Result<UNet2DConditionModel>
Each function creates a VarStore on the specified device, constructs the model with the Stable Diffusion v1.4 configuration, and loads pre-trained weights from .ot files in the data/ directory.
Main Entry Point
fn main() -> anyhow::Result<()>
Executes the full text-to-image pipeline:
- Detects CUDA availability; accepts an optional text prompt and "cpu" flag from command-line arguments
- Creates a DDIM scheduler with 30 inference steps and 1000 training timesteps
- Tokenizes the prompt (and an empty unconditional prompt) to 77 tokens each
- Builds and loads the CLIP text encoder; computes text embeddings for both prompts
- Builds and loads the VAE and UNet
- Initializes random latents of shape [1, 4, 64, 64] with seed 32
- Iterates through scheduler timesteps:
  - Duplicates latents for classifier-free guidance (conditional + unconditional)
  - Predicts noise with the UNet conditioned on text embeddings
  - Applies guidance: noise = noise_uncond + scale * (noise_text - noise_uncond)
  - Updates latents via the DDIM scheduler step
  - Decodes latents through the VAE, scales to [0, 255] uint8, and saves as sd_{step}.png
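The guidance formula in the loop above can be shown element-wise on slices. This is a minimal sketch on plain f64 values; the source applies the same arithmetic to tch tensors, and the function name apply_guidance is hypothetical.

```rust
/// Classifier-free guidance: start from the unconditional prediction and move
/// `scale` times the distance toward the text-conditioned prediction.
fn apply_guidance(noise_uncond: &[f64], noise_text: &[f64], scale: f64) -> Vec<f64> {
    noise_uncond
        .iter()
        .zip(noise_text)
        .map(|(&u, &t)| u + scale * (t - u))
        .collect()
}
```

A scale of 1.0 reproduces the conditional prediction and 0.0 the unconditional one; the GUIDANCE_SCALE of 7.5 extrapolates well past the conditional prediction.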
I/O Contract
Inputs
- Text prompt: A string describing the desired image (default: "A rusty robot holding a fire torch in its hand")
- Device selection: Pass "cpu" as a command-line argument to force CPU inference; otherwise CUDA is used if available
- Pre-trained weights: Three .ot weight files must be present in data/:
  - data/pytorch_model.ot -- CLIP text encoder weights
  - data/vae.ot -- VAE encoder/decoder weights
  - data/unet.ot -- UNet weights
- Vocabulary file: data/bpe_simple_vocab_16e6.txt for the BPE tokenizer
Outputs
- Generated images: One PNG file per denoising step, named sd_0.png through sd_29.png, each 512x512 pixels RGB
- Console output: CUDA/cuDNN availability, decoded prompt text, token tensor, and per-step progress
Invariants
- Token sequences are always padded or truncated to exactly 77 tokens (MAX_POSITION_EMBEDDINGS)
- Latents operate at 1/8 spatial resolution (64x64 for 512x512 output)
- The VAE scaling factor of 0.18215 is applied before decoding: latents / 0.18215
- Classifier-free guidance evaluates the UNet on both the unconditional and the conditional embeddings at every step, batched together as a single input of size 2
- All inference runs inside tch::no_grad_guard() to disable gradient computation
- Generation is seeded with tch::manual_seed(32), so the initial latents are reproducible across runs
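The scaling invariant and the final conversion to uint8 can be illustrated per value. The division by 0.18215 comes from the invariant above; the mapping from [-1, 1] to [0, 255] is the usual diffusers post-processing and is an assumption here, as are the helper names.

```rust
/// VAE scaling factor stated in the invariants.
const VAE_SCALE: f64 = 0.18215;

/// Undo the latent scaling before handing a latent value to the VAE decoder.
fn unscale_latent(z: f64) -> f64 {
    z / VAE_SCALE
}

/// Map a decoded pixel from the assumed [-1, 1] range to a u8 in [0, 255],
/// clamping out-of-range values and truncating the fractional part.
fn to_u8(pixel: f64) -> u8 {
    ((pixel / 2.0 + 0.5).clamp(0.0, 1.0) * 255.0) as u8
}
```

Values outside [-1, 1] saturate at 0 or 255 rather than wrapping, which avoids speckle artifacts in the saved PNG.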
Pipeline Data Flow
Text Prompt
|
v
[BPE Tokenizer] --> token IDs (77 integers)
|
v
[CLIP Text Encoder] --> text embeddings [1, 77, 768]
| + unconditional embeddings [1, 77, 768]
| --> concatenated [2, 77, 768]
|
v
[Random Latents] --> z ~ N(0,1), shape [1, 4, 64, 64]
|
v
[Denoising Loop (30 steps)]
|
| For each timestep:
| 1. Duplicate latents --> [2, 4, 64, 64]
| 2. UNet(latents, timestep, text_embeddings) --> noise prediction [2, 4, 64, 64]
| 3. Split into unconditional and conditional noise
| 4. Apply classifier-free guidance
| 5. DDIM step to update latents
| 6. VAE decode --> image [1, 3, 512, 512]
| 7. Save as PNG
|
v
Output images (sd_0.png ... sd_29.png)
Dependencies
- tch -- Rust bindings to libtorch for tensor operations, neural network modules (nn), convolutions, attention, and model weight loading
- regex -- Regular expression matching for BPE tokenization
- anyhow -- Error handling throughout the pipeline
- std::collections::{HashMap, HashSet} -- Vocabulary and BPE rank storage
- std::io::BufRead -- Reading the BPE vocabulary file
Weight Preparation
Pre-trained weights must be converted from PyTorch .bin format to tch-rs .ot format:
- Download PyTorch weights from Hugging Face (CLIP, VAE, UNet)
- Convert to NumPy .npz files using a Python script
- Convert to .ot format using the tensor-tools example: cargo run --release --example tensor-tools cp ./data/model.npz ./data/model.ot
Related Pages
- Principle:LaurentMazare_Tch_rs_Stable_Diffusion -- The guiding principle behind the Stable Diffusion architecture and its text-to-image generation approach