Implementation:LaurentMazare Tch rs Char RNN
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Natural Language Processing, Recurrent Neural Networks |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements a character-level language model using an LSTM network to generate Shakespeare-like text, inspired by Karpathy's char-rnn.
Description
This example trains an LSTM-based character-level language model on the tinyshakespeare dataset. The architecture consists of:
- An LSTM layer with a hidden size of 256, taking one-hot encoded character inputs.
- A linear output layer that maps the LSTM hidden state to logits over the character vocabulary. During sampling, softmax turns these logits into a probability distribution.
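The sketch below shows, without external crates, what one LSTM step computes for a one-hot input. It follows the gate equations and the input/forget/cell/output gate ordering documented for PyTorch's nn.LSTM, which tch builds on. The struct `LstmParams`, the function `lstm_step`, and the merging of the two bias vectors into one are illustrative assumptions, not tch APIs. Because the input is one-hot, multiplying it by the input weights amounts to selecting a single column.

```rust
// Std-only sketch of one LSTM step (hypothetical helpers, not tch's implementation).

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Parameters of a single-layer LSTM. Rows are stacked in PyTorch's gate
/// order: input (i), forget (f), cell candidate (g), output (o).
struct LstmParams {
    w_ih: Vec<Vec<f32>>, // [4 * hidden][vocab]
    w_hh: Vec<Vec<f32>>, // [4 * hidden][hidden]
    b: Vec<f32>,         // assumption: b_ih + b_hh merged, length 4 * hidden
    hidden: usize,
}

/// One step for a one-hot input with a 1.0 at `label`. W_ih * x reduces to
/// reading column `label` of W_ih.
fn lstm_step(p: &LstmParams, label: usize, h: &[f32], c: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let n = p.hidden;
    // Pre-activations z = W_ih[:, label] + W_hh * h + b for all four gates.
    let z: Vec<f32> = (0..4 * n)
        .map(|r| {
            let rec: f32 = p.w_hh[r].iter().zip(h).map(|(w, hv)| w * hv).sum();
            p.w_ih[r][label] + rec + p.b[r]
        })
        .collect();
    let mut h_new = vec![0.0; n];
    let mut c_new = vec![0.0; n];
    for k in 0..n {
        let i = sigmoid(z[k]);
        let f = sigmoid(z[n + k]);
        let g = z[2 * n + k].tanh();
        let o = sigmoid(z[3 * n + k]);
        c_new[k] = f * c[k] + i * g; // new cell state
        h_new[k] = o * c_new[k].tanh(); // new hidden state, fed to the linear layer
    }
    (h_new, c_new)
}
```

The linear layer then maps `h_new` to logits (W·h + b). Neither layer applies a softmax.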
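Training runs for 100 epochs with a batch size of 256 and a sequence length of 180. In each epoch, data.iter_shuffle(SEQ_LEN + 1, BATCH_SIZE) yields batches of 256 windows, each holding 181 consecutive character labels, in shuffled order. The first 180 labels of each window are the inputs and the last 180 (shifted by one position) are the targets. The inputs are one-hot encoded and fed through the LSTM. The output logits are trained with cross-entropy loss using the Adam optimizer (learning rate 0.01). The call opt.backward_step_clip(&loss, 0.5) backpropagates the loss, clips each gradient element to [-0.5, 0.5], and then applies the Adam update.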
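The batching, loss, and clipping steps can be sketched on plain slices. This illustration uses only the standard library and simplifies the tensor operations. The helpers `split_window`, `mean_cross_entropy`, and `clip_values` are hypothetical and are not part of tch. They mirror, respectively, the two `narrow` calls, `cross_entropy_for_logits` on flattened `[N, labels]` logits, and the element-wise clipping done by `backward_step_clip`.

```rust
/// Splits one window of SEQ_LEN + 1 labels into inputs (first SEQ_LEN) and
/// targets (last SEQ_LEN), so the target at position t is the input at t + 1.
fn split_window(window: &[i64]) -> (&[i64], &[i64]) {
    let seq_len = window.len() - 1;
    (&window[..seq_len], &window[1..])
}

/// Mean cross-entropy over rows of logits: log-softmax each row, take the
/// negative log-probability of the row's target, and average over rows.
fn mean_cross_entropy(logits: &[Vec<f64>], targets: &[usize]) -> f64 {
    let total: f64 = logits
        .iter()
        .zip(targets)
        .map(|(row, &y)| {
            // Subtract the row maximum before exponentiating, for numerical stability.
            let m = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = m + row.iter().map(|v| (v - m).exp()).sum::<f64>().ln();
            log_sum_exp - row[y]
        })
        .sum();
    total / logits.len() as f64
}

/// Element-wise value clipping: every gradient entry is clamped to [-max, max].
fn clip_values(grads: &mut [f64], max: f64) {
    for g in grads.iter_mut() {
        *g = (*g).clamp(-max, max);
    }
}
```

Value clipping bounds each gradient element independently, so it can change the gradient's direction. Norm-based clipping instead rescales the whole gradient vector and preserves its direction.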
Text generation (sampling) is autoregressive. It starts from a zero LSTM state and the one-hot encoding of label 0. Each step feeds the previous character through lstm.step, projects the hidden state with the linear layer, applies softmax, and samples the next character with multinomial. At the end of each epoch, 1024 characters are generated this way and printed.
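The following std-only sketch shows one sampling step, assuming the logits for the current position are already available as a `&[f64]`. The helpers `one_hot`, `softmax`, and `sample_index` are hypothetical. `sample_index` draws from the categorical distribution by inverse-CDF lookup, which is the distribution a single `multinomial(1, false)` draw samples from. The uniform number `u` is passed in so the sketch needs no random-number crate.

```rust
/// One-hot vector of length `labels` with a 1.0 at `label`, like the
/// zeros + narrow(...).fill_(1.0) pattern used in `sample`.
fn one_hot(label: usize, labels: usize) -> Vec<f64> {
    let mut v = vec![0.0; labels];
    v[label] = 1.0;
    v
}

/// Softmax over a logit vector, with the maximum subtracted for stability.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let m = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|v| (v - m).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Picks index i when `u` (uniform in [0, 1)) falls in the i-th slice of the
/// cumulative distribution.
fn sample_index(probs: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0;
    for (i, p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    probs.len() - 1 // guard against rounding when u is close to 1
}
```

In the example's loop, the sampled index becomes `last_label` and is one-hot encoded as the next input.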
Usage
Use this example to learn how to build recurrent neural networks for sequential text generation tasks with tch-rs. It demonstrates the TextData utility, LSTM cell stepping, one-hot encoding, and autoregressive sampling.
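In this example, TextData maps corpus characters to integer labels and back (`labels()`, `label_to_char`). The `Vocab` type below is a hypothetical std-only stand-in that assigns labels in order of first appearance. tch's TextData may order labels differently and may operate on bytes rather than Unicode `char`s, so this only sketches the idea.

```rust
use std::collections::HashMap;

/// Hypothetical character vocabulary: each distinct char gets the next free
/// label the first time it appears, and the reverse mapping is kept in a Vec.
struct Vocab {
    label_for_char: HashMap<char, usize>,
    char_for_label: Vec<char>,
}

impl Vocab {
    fn new(text: &str) -> Self {
        let mut label_for_char = HashMap::new();
        let mut char_for_label = Vec::new();
        for c in text.chars() {
            label_for_char.entry(c).or_insert_with(|| {
                char_for_label.push(c);
                char_for_label.len() - 1
            });
        }
        Vocab { label_for_char, char_for_label }
    }

    /// Vocabulary size (the `labels` dimension of the one-hot inputs).
    fn labels(&self) -> usize {
        self.char_for_label.len()
    }

    /// Converts text to labels; panics on characters not seen in the corpus.
    fn encode(&self, text: &str) -> Vec<usize> {
        text.chars().map(|c| self.label_for_char[&c]).collect()
    }

    fn label_to_char(&self, label: usize) -> char {
        self.char_for_label[label]
    }
}
```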
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: examples/char-rnn/main.rs
- Lines: 1-68
Signature
```rust
fn sample(data: &TextData, lstm: &LSTM, linear: &Linear, device: Device) -> String
pub fn main() -> Result<()>
```
Import
```rust
// Standalone binary example. Run with:
// cargo run --example char-rnn
use anyhow::Result;
use tch::data::TextData;
use tch::nn::{Linear, Module, OptimizerConfig, LSTM, RNN};
use tch::{nn, Device, Kind, Tensor};
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| data/input.txt | File (text) | Yes | A plain text file (e.g., tinyshakespeare) used as the training corpus. |
Outputs
| Name | Type | Description |
|---|---|---|
| stdout | Text | Mean training loss for each epoch, followed by a 1024-character generated sample. |
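The printed loss is the mean of the per-batch cross-entropy values (sum_loss / cnt_loss). A useful reference point is the loss of a model that assigns equal probability to every one of `labels` characters, which is -ln(1/labels) = ln(labels). The helper names below are hypothetical. The format string mirrors the example's `{:5.3}`.

```rust
/// Average of per-batch losses, matching sum_loss / cnt_loss.
fn mean_loss(batch_losses: &[f64]) -> f64 {
    batch_losses.iter().sum::<f64>() / batch_losses.len() as f64
}

/// Cross-entropy of a uniform predictor over `labels` characters: ln(labels).
/// A model that has learned anything is expected to report a lower value.
fn uniform_baseline(labels: usize) -> f64 {
    (labels as f64).ln()
}

/// Builds the same per-epoch line the example prints.
fn epoch_line(epoch: i64, batch_losses: &[f64]) -> String {
    format!("Epoch: {} loss: {:5.3}", epoch, mean_loss(batch_losses))
}
```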
Usage Examples
```rust
use anyhow::Result;
use tch::data::TextData;
use tch::nn::{Linear, Module, OptimizerConfig, LSTM, RNN};
use tch::{nn, Device, Kind, Tensor};

const LEARNING_RATE: f64 = 0.01;
const HIDDEN_SIZE: i64 = 256;
const SEQ_LEN: i64 = 180;
const BATCH_SIZE: i64 = 256;
const EPOCHS: i64 = 100;
const SAMPLING_LEN: i64 = 1024;

/// Generates SAMPLING_LEN characters autoregressively from the LSTM + Linear model.
fn sample(data: &TextData, lstm: &LSTM, linear: &Linear, device: Device) -> String {
    let labels = data.labels();
    let mut state = lstm.zero_state(1);
    let mut last_label = 0i64;
    let mut result = String::new();
    for _index in 0..SAMPLING_LEN {
        // One-hot encode the previously sampled label (label 0 on the first step).
        let input = Tensor::zeros([1, labels], (Kind::Float, device));
        let _ = input.narrow(1, last_label, 1).fill_(1.0);
        // Advance the LSTM by one step, then sample the next label from the softmax.
        state = lstm.step(&input, &state);
        let sampled_y = linear
            .forward(&state.h())
            .squeeze_dim(0)
            .softmax(-1, Kind::Float)
            .multinomial(1, false);
        last_label = i64::try_from(sampled_y).unwrap();
        result.push(data.label_to_char(last_label));
    }
    result
}

pub fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let data = TextData::new("data/input.txt")?;
    let labels = data.labels();
    let lstm = nn::lstm(vs.root(), labels, HIDDEN_SIZE, Default::default());
    let linear = nn::linear(vs.root(), HIDDEN_SIZE, labels, Default::default());
    let mut opt = nn::Adam::default().build(&vs, LEARNING_RATE)?;
    for epoch in 1..(1 + EPOCHS) {
        let mut sum_loss = 0.;
        let mut cnt_loss = 0.;
        // Each batch holds BATCH_SIZE windows of SEQ_LEN + 1 labels; targets are the inputs shifted by one.
        for batch in data.iter_shuffle(SEQ_LEN + 1, BATCH_SIZE) {
            let xs_onehot = batch.narrow(1, 0, SEQ_LEN).onehot(labels);
            let ys = batch.narrow(1, 1, SEQ_LEN).to_kind(Kind::Int64);
            let (lstm_out, _) = lstm.seq(&xs_onehot.to_device(device));
            let logits = linear.forward(&lstm_out);
            // Flatten [BATCH_SIZE, SEQ_LEN, labels] logits to [BATCH_SIZE * SEQ_LEN, labels].
            let loss = logits
                .view([BATCH_SIZE * SEQ_LEN, labels])
                .cross_entropy_for_logits(&ys.to_device(device).view([BATCH_SIZE * SEQ_LEN]));
            // Backpropagate, clip each gradient value to [-0.5, 0.5], then take an Adam step.
            opt.backward_step_clip(&loss, 0.5);
            sum_loss += f64::try_from(loss)?;
            cnt_loss += 1.0;
        }
        println!("Epoch: {} loss: {:5.3}", epoch, sum_loss / cnt_loss);
        println!("Sample: {}", sample(&data, &lstm, &linear, device));
    }
    Ok(())
}
```
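The loss computation relies on `view` reinterpreting the `[BATCH_SIZE, SEQ_LEN, labels]` logits and the `[BATCH_SIZE, SEQ_LEN]` targets in row-major order. As a result, row `b * SEQ_LEN + t` of both flattened tensors refers to sequence b at position t. Here is a std-only sketch of that index mapping on nested `Vec`s. The helpers `flatten` and `row_index` are hypothetical.

```rust
/// Flattens one nesting level in row-major order, as `view` does on a
/// contiguous tensor: [batch][seq][...] becomes [batch * seq][...].
fn flatten<T: Clone>(nested: &[Vec<T>]) -> Vec<T> {
    nested.iter().flat_map(|row| row.iter().cloned()).collect()
}

/// Row in the flattened layout that holds position `t` of sequence `b`.
fn row_index(b: usize, t: usize, seq_len: usize) -> usize {
    b * seq_len + t
}
```

Because logits and targets are flattened the same way, each logit row is scored against the target at the same index.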