Implementation:LaurentMazare Tch rs Char RNN

From Leeroopedia


Knowledge Sources
Domains: Deep Learning, Natural Language Processing, Recurrent Neural Networks
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements a character-level language model using an LSTM network to generate Shakespeare-like text, inspired by Karpathy's char-rnn.

Description

This example trains an LSTM-based character-level language model on the tinyshakespeare dataset. The architecture consists of:

  • An LSTM layer with a hidden size of 256, taking one-hot encoded character inputs.
  • A linear output layer that maps the LSTM hidden state to logits over the character vocabulary; a softmax turns these logits into a probability distribution when sampling.
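The one-hot input encoding named in the list above can be sketched in plain Rust, without tch. This is only an illustration of the idea: the function name one_hot and the nested Vec layout are assumptions made here, whereas the example itself calls the tensor method onehot.

```rust
/// One-hot encode a sequence of character labels.
/// Returns one row per label; each row has `vocab_size` entries, exactly one of them 1.0.
/// Hypothetical helper for illustration only (not part of tch-rs).
fn one_hot(labels: &[usize], vocab_size: usize) -> Vec<Vec<f32>> {
    labels
        .iter()
        .map(|&label| {
            assert!(label < vocab_size, "label out of range");
            let mut row = vec![0.0_f32; vocab_size];
            row[label] = 1.0;
            row
        })
        .collect()
}
```

Calling one_hot(&[2, 0], 4) yields two rows of length 4, with the 1.0 at index 2 in the first row and at index 0 in the second.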

Training runs for 100 epochs with batch size 256 and sequence length 180. Each batch is drawn by data.iter_shuffle as shuffled windows of 181 consecutive characters (sequence length plus one): the first 180 characters serve as inputs, and the same window shifted by one character supplies the targets. The inputs are one-hot encoded and fed through the LSTM, and the output logits are trained with cross-entropy loss using the Adam optimizer (learning rate 0.01). Gradients are clipped at 0.5 via opt.backward_step_clip before each optimizer step.
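The three ingredients of a training step described above (next-character targets, cross-entropy on logits, gradient clipping) can be sketched in plain Rust. The helper names shifted_pair, cross_entropy and clip_grad_value are hypothetical, and the sketch assumes the clip is applied to each gradient entry rather than to the overall norm.

```rust
/// Split a window of `seq_len + 1` labels into (inputs, targets),
/// where each target is the label that follows the matching input.
fn shifted_pair(window: &[usize]) -> (&[usize], &[usize]) {
    assert!(window.len() >= 2, "need at least two labels");
    let seq_len = window.len() - 1;
    (&window[..seq_len], &window[1..])
}

/// Mean cross-entropy of integer targets, given one row of raw logits per position.
fn cross_entropy(logits: &[Vec<f64>], targets: &[usize]) -> f64 {
    assert_eq!(logits.len(), targets.len());
    let total: f64 = logits
        .iter()
        .zip(targets)
        .map(|(row, &t)| {
            // log-sum-exp with the maximum subtracted for numerical stability
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let lse = max + row.iter().map(|&x| (x - max).exp()).sum::<f64>().ln();
            lse - row[t]
        })
        .sum();
    total / logits.len() as f64
}

/// Clamp every gradient entry into [-max, max]
/// (assumption: per-value clipping, not norm clipping).
fn clip_grad_value(grads: &mut [f64], max: f64) {
    for g in grads.iter_mut() {
        *g = (*g).clamp(-max, max);
    }
}
```

With uniform logits over four classes, cross_entropy returns ln 4, which is a handy sanity check for the loss printed in the first epoch of an untrained model.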

Text generation (sampling) is autoregressive: starting from a zeroed LSTM state and the character with label 0, each step one-hot encodes the previous character, advances the LSTM by one step, applies the linear layer and a softmax, and draws the next character with multinomial. A 1024-character sample is generated at the end of each epoch.
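The sampling loop can be sketched without tch as follows. Everything here is an illustrative assumption: next_logits is a hypothetical closure standing in for "LSTM step plus linear layer", uniform is a caller-supplied source of random numbers in [0, 1), and inverse-CDF sampling is used in place of the tensor multinomial call.

```rust
/// Numerically stable softmax over a slice of logits.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Draw one index from a categorical distribution by inverse-CDF sampling.
/// `u` is a uniform random number in [0, 1) supplied by the caller.
fn sample_index(probs: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    // Guard against rounding error: fall back to the last index.
    probs.len() - 1
}

/// Generate `len` labels autoregressively, feeding each sampled label back in.
fn generate(
    len: usize,
    mut next_logits: impl FnMut(usize) -> Vec<f64>,
    mut uniform: impl FnMut() -> f64,
) -> Vec<usize> {
    let mut last_label = 0; // the example also starts from label 0
    let mut out = Vec::with_capacity(len);
    for _ in 0..len {
        let probs = softmax(&next_logits(last_label));
        last_label = sample_index(&probs, uniform());
        out.push(last_label);
    }
    out
}
```

Passing the random source in as a closure keeps the sketch dependency-free and makes the loop deterministic under test; a real program would plug in a proper random number generator.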

Usage

Use this example to learn how to build recurrent neural networks for sequential text generation tasks with tch-rs. It demonstrates the TextData utility, LSTM cell stepping, one-hot encoding, and autoregressive sampling.
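A character-level model needs a mapping between characters and integer labels, which is the role TextData plays in this example. The sketch below shows that idea in plain Rust; the Vocab type and its methods are hypothetical and make no claim about how TextData is implemented internally.

```rust
/// Hypothetical stand-in for the character <-> label mapping of a text dataset.
struct Vocab {
    chars: Vec<char>, // index = label, value = character
}

impl Vocab {
    /// Collect the distinct characters of `text` in sorted order.
    fn new(text: &str) -> Self {
        let mut chars: Vec<char> = text.chars().collect();
        chars.sort_unstable();
        chars.dedup();
        Vocab { chars }
    }

    /// Number of distinct labels (the vocabulary size).
    fn labels(&self) -> usize {
        self.chars.len()
    }

    /// Map a character to its label, if it occurs in the vocabulary.
    fn char_to_label(&self, c: char) -> Option<usize> {
        self.chars.binary_search(&c).ok()
    }

    /// Map a label back to its character. Panics if the label is out of range.
    fn label_to_char(&self, label: usize) -> char {
        self.chars[label]
    }

    /// Encode a string as labels, skipping characters outside the vocabulary.
    fn encode(&self, text: &str) -> Vec<usize> {
        text.chars().filter_map(|c| self.char_to_label(c)).collect()
    }
}
```

For the text "hello" this produces a vocabulary of four characters (e, h, l, o), so the model's input and output width would be 4.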

Code Reference

Source Location

Signature

fn sample(data: &TextData, lstm: &LSTM, linear: &Linear, device: Device) -> String

pub fn main() -> Result<()>

Import

// Standalone binary example. Run with:
// cargo run --example char-rnn
use anyhow::Result;
use tch::data::TextData;
use tch::nn::{Linear, Module, OptimizerConfig, LSTM, RNN};
use tch::{nn, Device, Kind, Tensor};

I/O Contract

Inputs

Name Type Required Description
data/input.txt File (text) Yes A plain text file (e.g., tinyshakespeare) used as the training corpus.

Outputs

Name Type Description
stdout Text Per-epoch mean training loss followed by a 1024-character text sample.

Usage Examples

use anyhow::Result;
use tch::data::TextData;
use tch::nn::{Linear, Module, OptimizerConfig, LSTM, RNN};
use tch::{nn, Device, Kind, Tensor};

const LEARNING_RATE: f64 = 0.01;
const HIDDEN_SIZE: i64 = 256;
const SEQ_LEN: i64 = 180;
const BATCH_SIZE: i64 = 256;
const EPOCHS: i64 = 100;
const SAMPLING_LEN: i64 = 1024;

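// Generates SAMPLING_LEN characters by feeding each sampled character back into the LSTM.
fn sample(data: &TextData, lstm: &LSTM, linear: &Linear, device: Device) -> String {
    let labels = data.labels();
    // Start from a zeroed LSTM state (batch size 1) and the character with label 0.
    let mut state = lstm.zero_state(1);
    let mut last_label = 0i64;
    let mut result = String::new();
    for _index in 0..SAMPLING_LEN {
        // One-hot encode the previous character: all zeros, then a 1.0 at its label.
        let input = Tensor::zeros([1, labels], (Kind::Float, device));
        let _ = input.narrow(1, last_label, 1).fill_(1.0);
        // Advance the LSTM by a single time step.
        state = lstm.step(&input, &state);
        // Turn the hidden state into a distribution over characters and sample from it.
        let sampled_y = linear
            .forward(&state.h())
            .squeeze_dim(0)
            .softmax(-1, Kind::Float)
            .multinomial(1, false);
        last_label = i64::try_from(sampled_y).unwrap();
        result.push(data.label_to_char(last_label))
    }
    result
}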

pub fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let data = TextData::new("data/input.txt")?;
    // Vocabulary size: the width of the one-hot inputs and of the output logits.
    let labels = data.labels();
    let lstm = nn::lstm(vs.root(), labels, HIDDEN_SIZE, Default::default());
    let linear = nn::linear(vs.root(), HIDDEN_SIZE, labels, Default::default());
    let mut opt = nn::Adam::default().build(&vs, LEARNING_RATE)?;
    for epoch in 1..(1 + EPOCHS) {
        let mut sum_loss = 0.;
        let mut cnt_loss = 0.;
        // Each batch holds BATCH_SIZE windows of SEQ_LEN + 1 characters.
        for batch in data.iter_shuffle(SEQ_LEN + 1, BATCH_SIZE) {
            // Inputs are the first SEQ_LEN characters; targets are the same window shifted by one.
            let xs_onehot = batch.narrow(1, 0, SEQ_LEN).onehot(labels);
            let ys = batch.narrow(1, 1, SEQ_LEN).to_kind(Kind::Int64);
            let (lstm_out, _) = lstm.seq(&xs_onehot.to_device(device));
            let logits = linear.forward(&lstm_out);
            // Flatten batch and time dimensions so every position is one classification example.
            let loss = logits
                .view([BATCH_SIZE * SEQ_LEN, labels])
                .cross_entropy_for_logits(&ys.to_device(device).view([BATCH_SIZE * SEQ_LEN]));
            // Backward pass, gradient clipping at 0.5, and optimizer step.
            opt.backward_step_clip(&loss, 0.5);
            sum_loss += f64::try_from(loss)?;
            cnt_loss += 1.0;
        }
        // Report the mean batch loss and a fresh text sample for this epoch.
        println!("Epoch: {}   loss: {:5.3}", epoch, sum_loss / cnt_loss);
        println!("Sample: {}", sample(&data, &lstm, &linear, device));
    }
    Ok(())
}
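To make the lstm.step call in sample more concrete, here is a minimal plain-Rust sketch of what a single LSTM time step computes. It is an illustration under stated assumptions, not the tch-rs implementation: the LstmCell type is hypothetical, a single combined bias vector is used, and the gate rows are assumed to be stacked in the order input, forget, cell candidate, output.

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Parameters of one LSTM cell (hypothetical layout for illustration).
/// Rows are stacked per gate: input, forget, cell candidate, output.
struct LstmCell {
    hidden: usize,
    w_ih: Vec<Vec<f64>>, // shape [4 * hidden][input]
    w_hh: Vec<Vec<f64>>, // shape [4 * hidden][hidden]
    bias: Vec<f64>,      // shape [4 * hidden]
}

impl LstmCell {
    /// One time step: consume input `x` and state `(h, c)`, return the new `(h, c)`.
    fn step(&self, x: &[f64], h: &[f64], c: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let n = self.hidden;
        // Pre-activations for all four gates: W_ih * x + W_hh * h + b.
        let pre: Vec<f64> = (0..4 * n)
            .map(|r| {
                let from_x: f64 = self.w_ih[r].iter().zip(x).map(|(w, v)| w * v).sum();
                let from_h: f64 = self.w_hh[r].iter().zip(h).map(|(w, v)| w * v).sum();
                from_x + from_h + self.bias[r]
            })
            .collect();
        let mut new_h = vec![0.0; n];
        let mut new_c = vec![0.0; n];
        for j in 0..n {
            let i = sigmoid(pre[j]); // input gate
            let f = sigmoid(pre[n + j]); // forget gate
            let g = pre[2 * n + j].tanh(); // candidate cell value
            let o = sigmoid(pre[3 * n + j]); // output gate
            new_c[j] = f * c[j] + i * g;
            new_h[j] = o * new_c[j].tanh();
        }
        (new_h, new_c)
    }
}
```

The returned pair plays the role of the state value that sample threads through its loop: the hidden vector is what the linear layer consumes, and the cell vector carries longer-term memory to the next step.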
