
Implementation:LaurentMazare Tch rs Nn Embedding

From Leeroopedia


Knowledge Sources
Domains NLP, Neural_Network_Layers
Last Updated 2026-02-08 14:00 GMT

Overview

Concrete tool, provided by the tch-rs nn module, for creating token-embedding lookup layers.

Description

nn::embedding creates an Embedding struct with a weight matrix of shape [num_embeddings, embedding_dim] registered in the VarStore. Configuration options include sparse gradients, gradient scaling by frequency, and padding index. The Embedding implements the Module trait, using Tensor::embedding for the forward pass.

Usage

Use as the first layer of language models. For LLaMA-7B: num_embeddings=32000 (vocabulary size), embedding_dim=4096 (hidden size).

Code Reference

Source Location

  • Repository: tch-rs
  • File: src/nn/sparse.rs
  • Lines: 35-43

Signature

pub fn embedding<'a, T: Borrow<Path<'a>>>(
    vs: T,
    num_embeddings: i64,
    embedding_dim: i64,
    config: EmbeddingConfig,
) -> Embedding

Import

use tch::nn;

I/O Contract

Inputs

  • vs (T: Borrow<Path<'a>>, required): VarStore path for parameter registration
  • num_embeddings (i64, required): vocabulary size
  • embedding_dim (i64, required): embedding vector dimension
  • config (EmbeddingConfig, required): options sparse, scale_grad_by_freq, padding_idx

Outputs

  • Embedding (nn::Embedding): struct holding ws: Tensor of shape [num_embeddings, embedding_dim]; implements Module

Usage Examples

use tch::{nn, nn::Module};

// Create a variable store and register a 32000 x 4096 embedding table under "wte".
let vs = nn::VarStore::new(tch::Device::Cpu);
let emb = nn::embedding(vs.root() / "wte", 32000, 4096, Default::default());

// Look up three token ids for a batch of one sequence.
let token_ids = tch::Tensor::from_slice(&[1i64, 2, 3]).unsqueeze(0); // [1, 3]
let embeddings = emb.forward(&token_ids); // [1, 3, 4096]

Related Pages

Implements Principle
