Implementation:LaurentMazare Tch rs Neural Style Transfer

From Leeroopedia


Knowledge Sources
Domains: Deep Learning, Computer Vision, Style Transfer
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements neural style transfer using VGG16 feature extraction, combining the content of one image with the artistic style of another through iterative optimization.

Description

This example is a Rust port of the PyTorch neural style transfer tutorial. It uses a pre-trained VGG16 network as a fixed feature extractor and optimizes a generated image to minimize a combined style and content loss.

Key components:

  • Gram matrix computation: The gram_matrix function reshapes a feature map of shape (a, b, c, d), i.e. (batch, channels, height, width), into a 2D matrix of shape (a*b, c*d), multiplies it by its transpose, and divides by the total number of elements a*b*c*d. The resulting channel-by-channel correlations capture texture/style information independent of spatial layout.
  • Style loss: Computed as the MSE between the Gram matrices of the generated image and the style image at five VGG16 layer outputs (indices 0, 2, 5, 7, 10 in the list returned by forward_all_t), summed and weighted by STYLE_WEIGHT = 1e6.
  • Content loss: Computed as the MSE between feature maps of the generated image and content image at one VGG16 layer (index 7).
  • Optimization: The generated image is initialized as a copy of the content image and stored as a trainable variable. The Adam optimizer (learning rate 0.1) updates the pixel values directly for 3000 steps. The VGG16 network is frozen (net_vs.freeze()) so only the image variable is updated.
  • Output: Every 1000 steps, the current generated image is saved as a JPEG file using imagenet::save_image.
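The components above can be summarized as a single objective. The notation here is an assumption introduced for illustration and does not appear in the source: x is the generated image, s the style image, p the content image, and F_l(·) the output of VGG16 layer l reshaped to an (ab) × (cd) matrix.

```latex
\begin{aligned}
G(F) &= \frac{F F^{\top}}{a\,b\,c\,d}, \qquad F \in \mathbb{R}^{(ab)\times(cd)} \\
\mathcal{L}_{\text{style}}(x) &= \sum_{l \in S} \operatorname{MSE}\bigl(G(F_l(x)),\, G(F_l(s))\bigr), \qquad S = \{0, 2, 5, 7, 10\} \\
\mathcal{L}_{\text{content}}(x) &= \sum_{l \in C} \operatorname{MSE}\bigl(F_l(x),\, F_l(p)\bigr), \qquad C = \{7\} \\
\mathcal{L}(x) &= w_s\,\mathcal{L}_{\text{style}}(x) + \mathcal{L}_{\text{content}}(x), \qquad w_s = 10^{6}
\end{aligned}
```

Only x is optimized; the VGG16 weights inside F_l stay fixed.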

Usage

Use this example to learn about feature-based loss functions, Gram matrix style representations, and optimizing image tensors directly. It requires pre-trained VGG16 weights, which can be downloaded from the tch-rs releases.
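To make "optimizing image tensors directly" concrete, here is a minimal std-only sketch that applies Adam updates to a flat pixel buffer against a plain MSE target. It does not use tch: the Adam struct, mse, mse_grad, and optimize are hypothetical helpers written for illustration, and the hyperparameters beta1 = 0.9, beta2 = 0.999, eps = 1e-8 are the commonly used defaults, assumed here rather than taken from the source.

```rust
/// Adam state for a flat buffer of "pixels" (hypothetical helper, not part of tch).
struct Adam {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    m: Vec<f64>, // running first-moment (mean) estimate of the gradient
    v: Vec<f64>, // running second-moment estimate of the gradient
    t: i32,      // number of updates applied so far
}

impl Adam {
    /// Assumed defaults: beta1 = 0.9, beta2 = 0.999, eps = 1e-8.
    fn new(len: usize, lr: f64) -> Self {
        Adam { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, m: vec![0.0; len], v: vec![0.0; len], t: 0 }
    }

    /// Apply one in-place update to `params` given the loss gradient `grads`.
    fn step(&mut self, params: &mut [f64], grads: &[f64]) {
        self.t += 1;
        // Bias-correction factors for the zero-initialized moment estimates.
        let c1 = 1.0 - self.beta1.powi(self.t);
        let c2 = 1.0 - self.beta2.powi(self.t);
        for i in 0..params.len() {
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grads[i];
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grads[i] * grads[i];
            let m_hat = self.m[i] / c1;
            let v_hat = self.v[i] / c2;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}

/// Mean squared error between two equally sized buffers.
fn mse(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>() / a.len() as f64
}

/// Gradient of `mse(pixels, target)` with respect to `pixels`.
fn mse_grad(pixels: &[f64], target: &[f64]) -> Vec<f64> {
    let n = pixels.len() as f64;
    pixels.iter().zip(target).map(|(x, y)| 2.0 * (x - y) / n).collect()
}

/// Run `steps` Adam updates on `pixels` and return the final loss.
fn optimize(pixels: &mut [f64], target: &[f64], lr: f64, steps: usize) -> f64 {
    let mut opt = Adam::new(pixels.len(), lr);
    for _ in 0..steps {
        let grads = mse_grad(pixels, target);
        opt.step(pixels, &grads);
    }
    mse(pixels, target)
}
```

In the real example the gradient comes from backpropagating through the frozen VGG16 network rather than from a closed-form MSE derivative, but the update applied to the pixels has the same shape.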

Code Reference

Source Location

Signature

fn gram_matrix(m: &Tensor) -> Tensor

fn style_loss(m1: &Tensor, m2: &Tensor) -> Tensor

pub fn main() -> Result<()>

Import

// Standalone binary example. Run with:
// cargo run --example neural-style-transfer -- style.jpg content.jpg vgg16.ot
use anyhow::{bail, Result};
use tch::vision::{imagenet, vgg};
use tch::{nn, nn::OptimizerConfig, Device, Tensor};

I/O Contract

Inputs

Name         Type                      Required  Description
style.jpg    CLI argument (file path)  Yes       The style reference image.
content.jpg  CLI argument (file path)  Yes       The content reference image.
vgg16.ot     CLI argument (file path)  Yes       Pre-trained VGG16 weights file.

Outputs

Name      Type         Description
outN.jpg  Image files  Generated images saved at steps 1000, 2000, and 3000.
stdout    Text         Step index and total loss printed every 1000 steps.
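As a rough illustration of the output schedule, the sketch below lists which files a run would produce. It assumes a 1-based step counter, a checkpoint whenever the step is a multiple of the interval, and that N in outN.jpg is the step index; checkpoint_files is a hypothetical helper, not a function from tch-rs.

```rust
/// List the (step, file name) pairs a run would write, assuming a checkpoint
/// is taken whenever `step % every == 0` and files are named `out<step>.jpg`.
fn checkpoint_files(total_steps: i64, every: i64) -> Vec<(i64, String)> {
    (1..=total_steps)
        .filter(|step| step % every == 0)
        .map(|step| (step, format!("out{}.jpg", step)))
        .collect()
}
```

With 3000 total steps and an interval of 1000, this yields exactly the three checkpoints listed in the table.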

Usage Examples

use anyhow::{bail, Result};
use tch::vision::{imagenet, vgg};
use tch::{nn, nn::OptimizerConfig, Device, Tensor};

const STYLE_WEIGHT: f64 = 1e6;
const LEARNING_RATE: f64 = 1e-1;
const TOTAL_STEPS: i64 = 3000;
const STYLE_INDEXES: [usize; 5] = [0, 2, 5, 7, 10];
const CONTENT_INDEXES: [usize; 1] = [7];

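/// Gram matrix of an (a, b, c, d) feature map: flatten to (a*b, c*d),
/// multiply by the transpose, and normalize by the element count.
fn gram_matrix(m: &Tensor) -> Tensor {
    let (a, b, c, d) = m.size4().unwrap();
    let m = m.view([a * b, c * d]);
    let g = m.matmul(&m.tr());
    g / (a * b * c * d)
}

/// Style loss: MSE between the Gram matrices of two feature maps.
fn style_loss(m1: &Tensor, m2: &Tensor) -> Tensor {
    gram_matrix(m1).mse_loss(&gram_matrix(m2), tch::Reduction::Mean)
}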

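pub fn main() -> Result<()> {
    let device = Device::cuda_if_available();

    // Read the three required paths from the command line
    let args: Vec<String> = std::env::args().collect();
    let (style_path, content_path, weights) = match args.as_slice() {
        [_, s, c, w] => (s.clone(), c.clone(), w.clone()),
        _ => bail!("usage: neural-style-transfer style.jpg content.jpg vgg16.ot"),
    };

    // Load pre-trained VGG16 and freeze weights
    let mut net_vs = tch::nn::VarStore::new(device);
    let net = vgg::vgg16(&net_vs.root(), imagenet::CLASS_COUNT);
    net_vs.load(&weights)?;
    net_vs.freeze();

    // Load both images as (1, 3, H, W) batches on the target device
    let style_img = imagenet::load_image(&style_path)?.unsqueeze(0).to_device(device);
    let content_img = imagenet::load_image(&content_path)?.unsqueeze(0).to_device(device);

    // Extract style and content features.
    // max_layer is assumed to be one past the deepest style layer, so the
    // forward pass stops as soon as every needed activation is available.
    let max_layer = STYLE_INDEXES.iter().max().unwrap() + 1;
    let style_layers = net.forward_all_t(&style_img, false, Some(max_layer));
    let content_layers = net.forward_all_t(&content_img, false, Some(max_layer));

    // Optimize the generated image, starting from a copy of the content image
    let vs = nn::VarStore::new(device);
    let input_var = vs.root().var_copy("img", &content_img);
    let mut opt = nn::Adam::default().build(&vs, LEARNING_RATE)?;

    for step_idx in 1..(1 + TOTAL_STEPS) {
        let input_layers = net.forward_all_t(&input_var, false, Some(max_layer));
        let style_loss: Tensor =
            STYLE_INDEXES.iter().map(|&i| style_loss(&input_layers[i], &style_layers[i])).sum();
        let content_loss: Tensor = CONTENT_INDEXES
            .iter()
            .map(|&i| input_layers[i].mse_loss(&content_layers[i], tch::Reduction::Mean))
            .sum();
        let loss = style_loss * STYLE_WEIGHT + content_loss;
        opt.backward_step(&loss);

        // Every 1000 steps, report the loss and save the current image
        if step_idx % 1000 == 0 {
            println!("{} {}", step_idx, loss.double_value(&[]));
            imagenet::save_image(&input_var, format!("out{}.jpg", step_idx))?;
        }
    }
    Ok(())
}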
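To check the arithmetic of gram_matrix and style_loss without libtorch, the following std-only sketch repeats the same computation on row-major slices. The names gram_matrix_plain, mse_plain, and style_loss_plain are hypothetical and exist only for this illustration; rows corresponds to a*b and cols to c*d after the reshape.

```rust
/// Gram matrix of a row-major `rows x cols` feature matrix, normalized by
/// `rows * cols` (the total element count, matching a*b*c*d after the reshape).
/// Returns a row-major `rows x rows` matrix.
fn gram_matrix_plain(features: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    assert_eq!(features.len(), rows * cols, "shape mismatch");
    let norm = (rows * cols) as f64;
    let mut gram = vec![0.0; rows * rows];
    for i in 0..rows {
        for j in 0..rows {
            // Dot product of feature row i with feature row j.
            let dot: f64 = (0..cols)
                .map(|k| features[i * cols + k] * features[j * cols + k])
                .sum();
            gram[i * rows + j] = dot / norm;
        }
    }
    gram
}

/// Mean squared error between two equally sized buffers.
fn mse_plain(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "length mismatch");
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>() / a.len() as f64
}

/// Style loss: MSE between the Gram matrices of two feature matrices.
fn style_loss_plain(f1: &[f64], f2: &[f64], rows: usize, cols: usize) -> f64 {
    mse_plain(&gram_matrix_plain(f1, rows, cols), &gram_matrix_plain(f2, rows, cols))
}
```

For the 2 x 2 feature matrix [[1, 2], [3, 4]], the unnormalized product is [[5, 11], [11, 25]], so dividing by 4 gives [[1.25, 2.75], [2.75, 6.25]]. Because the Gram matrix has shape rows x rows, two inputs with the same channel count but different spatial sizes can still be compared, which is why the style image does not need to match the content image in resolution.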
