Implementation: LaurentMazare tch-rs Llama::new
| Knowledge Sources | |
|---|---|
| Domains | NLP, Model_Architecture |
| Last Updated | 2026-02-08 14:00 GMT |
Overview
Constructor for the LLaMA transformer model architecture in Rust, provided by the tch-rs llama example.
Description
Llama::new builds the complete LLaMA model: token embedding (wte), N transformer blocks (each with RmsNorm, CausalSelfAttention, MLP), final RmsNorm (ln_f), and output linear head (lm_head with no bias). Each block contains multi-head attention with RoPE and SwiGLU feed-forward. The llama wrapper function additionally precomputes rotary embeddings and wraps the model in a closure that applies temperature scaling and softmax.
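The RmsNorm used before attention and in ln_f can be sketched in plain Rust. This is a minimal illustration of the math, not the example's actual tch-based implementation; the function name and the 1e-5 epsilon are assumptions chosen for the sketch.

```rust
// Sketch of RmsNorm: scale x by 1/sqrt(mean(x^2) + eps), then apply a
// learned per-element weight. The eps value here is an assumption.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * scale * w).collect()
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    let w = [1.0_f32; 4];
    // With unit weights, the output has RMS ~= 1 regardless of input scale.
    println!("{:?}", rms_norm(&x, &w, 1e-5));
}
```

Unlike LayerNorm, RmsNorm omits mean subtraction and the bias term, which is why lm_head and the norms in this architecture carry no bias parameters.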
Usage
Use to build a LLaMA model for text generation. After construction, load converted safetensors weights via the mmap loading pattern.
Code Reference
Source Location
- Repository: tch-rs
- File: examples/llama/main.rs
- Lines: 244-259 (Llama::new), 261-271 (forward), 288-296 (wrapper)
Signature
```rust
impl Llama {
    fn new(vs: nn::Path, config: &Config) -> Self
    fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Tensor
}
```
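The `freqs_cis` argument to forward holds the precomputed rotary embeddings mentioned in the Description. A plain-Rust sketch of that precomputation is below; the function name, return shape, and the head_dim/seq_len parameters are assumptions for illustration (the example builds a tch Tensor instead), though the 10000.0 base follows the standard LLaMA RoPE convention.

```rust
// Sketch: for each position, one (cos, sin) pair per frequency. Frequencies
// decay geometrically across the head dimension with base 10000.
fn precompute_freqs_cis(head_dim: usize, seq_len: usize) -> Vec<Vec<(f32, f32)>> {
    let theta: Vec<f32> = (0..head_dim / 2)
        .map(|i| 1.0 / 10000f32.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    (0..seq_len)
        .map(|pos| {
            theta
                .iter()
                .map(|t| {
                    let angle = pos as f32 * t;
                    (angle.cos(), angle.sin())
                })
                .collect()
        })
        .collect()
}

fn main() {
    let freqs = precompute_freqs_cis(8, 4);
    // Position 0 is always (1, 0): zero rotation.
    println!("{:?}", freqs[0][0]);
}
```

During attention, each query/key pair is rotated by these angles, encoding relative position directly into the dot product.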
Import
```rust
// Internal to examples/llama/main.rs
// Uses tch::nn for layer construction
use tch::nn;
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| vs | nn::Path | Yes | VarStore path for parameter registration |
| config | &Config | Yes | Model config (n_layer, n_head, n_embd, vocab_size) |
Outputs
| Name | Type | Description |
|---|---|---|
| Llama | struct | Model with wte (Embedding), blocks (Vec<Block>), ln_f (RmsNorm), lm_head (Linear) |
Usage Examples
```rust
use tch::nn;

let vs = nn::VarStore::new(tch::Device::Cpu);
let config = Config::config_7b();
let llama = Llama::new(vs.root(), &config);
// After loading weights:
let logits = llama.forward(&token_ids, &freqs_cis); // [1, 1, 32000]
```
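The llama wrapper's closure applies temperature scaling and softmax to these logits before sampling. A minimal plain-Rust sketch of that step follows; the function name is an assumption, and the example performs the same math on a tch Tensor rather than a slice.

```rust
// Sketch: divide logits by temperature, then take a numerically stable
// softmax. Lower temperature sharpens the distribution toward the argmax.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|l| l / temperature).collect();
    // Subtract the max before exponentiating to avoid overflow.
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let probs = softmax_with_temperature(&[2.0, 1.0, 0.0], 0.8);
    println!("{:?}", probs);
}
```

The resulting probabilities over the 32000-token vocabulary are what the generation loop samples from.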