

Implementation:Romsto Speculative Decoding NGramStorage

From Leeroopedia
Knowledge Sources
Domains NLP, Data_Structures, Statistical_Language_Modeling
Last Updated 2026-02-14 04:30 GMT

Overview

A concrete tool for storing and querying n-gram frequency counts, serving as a lightweight drafter for n-gram assisted speculative decoding.

Description

The ngram_storage module provides three classes forming an inheritance hierarchy:

  • INgramStorage (ABC, L5-69): Defines the interface with abstract methods: next_token, has_gram, update, initialize, reset.
  • OneLevelNGramStorage (L73-150): Single-level implementation using exact (n-1)-token contexts only. Uses Python dicts for counts and best-token tracking.
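  • NGramStorage (L154-249): Multi-level implementation that stores k-grams for all k in [2, n]. On lookup, it tries the longest context first and falls back to shorter contexts. Recommended for n-gram assisted speculative decoding (NASD) because the fallback gives better coverage: a prediction is still available when the full (n-1)-token context has never been observed.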
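The difference between the two lookup strategies can be sketched in pure Python. This is a minimal illustration, not the module's code: it assumes a hypothetical `best` dict mapping a context tuple of token ids to its most frequent next token, and the function names `lookup_single_level` and `lookup_multi_level` are hypothetical.

```python
from typing import Dict, List, Optional, Tuple

Context = Tuple[int, ...]


def lookup_single_level(best: Dict[Context, int], tokens: List[int], n: int) -> Optional[int]:
    """Exact (n-1)-token context only; no fallback."""
    context = tuple(tokens[-(n - 1):])
    if len(context) != n - 1:
        return None  # not enough history for a full context
    return best.get(context)


def lookup_multi_level(best: Dict[Context, int], tokens: List[int], n: int) -> Optional[int]:
    """Try the longest context (n-1 tokens) first, then back off to shorter ones."""
    for k in range(n - 1, 0, -1):
        context = tuple(tokens[-k:])
        if len(context) == k and context in best:
            return best[context]
    return None


# Toy table covering one 2-token context and one 1-token context.
best = {(5, 6): 7, (6,): 9}
print(lookup_single_level(best, [1, 4, 6], n=3))  # None: (4, 6) was never seen
print(lookup_multi_level(best, [1, 4, 6], n=3))   # 9: backs off to (6,)
```

With the same table, the single-level lookup misses whenever the exact two-token context is new, while the multi-level lookup still answers from the one-token context.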

Both concrete classes maintain two dictionaries: counts (context -> token -> frequency) and ngrams (context -> most frequent token). Updates track the running maximum to avoid recomputing argmax on every query.
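The running-maximum idea can be sketched with the two dictionaries described above. This is a hedged illustration: the class name `RunningBest` and its `add` method are hypothetical, and the tie-breaking rule (ties keep the earlier best) is an assumption that may differ from the source.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, Tuple

Context = Tuple[int, ...]


class RunningBest:
    """Hypothetical helper mirroring the described counts/ngrams pair."""

    def __init__(self) -> None:
        # counts[context][token] -> how often token followed context
        self.counts: DefaultDict[Context, DefaultDict[int, int]] = defaultdict(lambda: defaultdict(int))
        # ngrams[context] -> current most frequent next token
        self.ngrams: Dict[Context, int] = {}

    def add(self, context: Context, token: int) -> None:
        bucket = self.counts[context]
        bucket[token] += 1
        best = self.ngrams.get(context)
        # Only the token just incremented can overtake the current best,
        # so one comparison replaces a full argmax over the bucket.
        # Ties keep the earlier best (an assumption; the source may differ).
        if best is None or bucket[token] > bucket[best]:
            self.ngrams[context] = token


store = RunningBest()
for tok in [3, 4, 4]:
    store.add((1, 2), tok)
print(store.ngrams[(1, 2)])  # 4: seen twice vs. once for 3
```

Each update costs O(1) regardless of how many distinct tokens have followed a context, and a query is a single dict lookup.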

Usage

Import NGramStorage (multi-level, recommended) or OneLevelNGramStorage (single-level) when setting up n-gram assisted speculative decoding. Create an instance with n (n-gram order, must be > 1) and vocab_size (from model.config.vocab_size). The storage is initialized from the prompt via initialize() and updated during generation via update().
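Seeding from the prompt amounts to sliding a window over the prompt tokens and counting every (context, next token) pair. The sketch below assumes the multi-level layout (contexts of length 1 to n-1) and plain Python lists instead of tensors; `seed_from_prompt` is a hypothetical name, not the module's `initialize()`.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Tuple


def seed_from_prompt(prompt: List[int], n: int) -> Dict[Tuple[int, ...], Counter]:
    """Record every k-gram (k = 2..n) in the prompt as context -> Counter of next tokens."""
    if n < 2:
        raise ValueError("n must be > 1")
    counts: Dict[Tuple[int, ...], Counter] = defaultdict(Counter)
    for end in range(1, len(prompt)):  # index of the token being predicted
        nxt = prompt[end]
        for k in range(1, n):  # context lengths 1 .. n-1
            if end - k < 0:
                break  # not enough history for a longer context
            counts[tuple(prompt[end - k:end])][nxt] += 1
    return counts


counts = seed_from_prompt([1, 2, 1, 2, 3], n=3)
print(dict(counts[(1, 2)]))  # {1: 1, 3: 1}
print(counts[(1,)].most_common(1))  # [(2, 2)]
```

A single-level variant would keep only the inner iteration with k = n - 1.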

Code Reference

Source Location

Signature

import abc
from typing import Tuple

import torch


class INgramStorage(abc.ABC):
    """Interface for N-gram storage."""

    def __init__(self, n: int, vocab_size: int):
        """n must be > 1. vocab_size is the model vocabulary size."""

    @abc.abstractmethod
    def next_token(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict next token. Returns (token, known) tensors."""

    @abc.abstractmethod
    def has_gram(self, ngram: torch.Tensor) -> bool:
        """Check if ngram has been observed."""

    @abc.abstractmethod
    def update(self, input_ids: torch.Tensor, next_tokens: torch.Tensor):
        """Record new (context, token) observations."""

    @abc.abstractmethod
    def initialize(self, input_ids: torch.Tensor):
        """Seed storage with all n-grams from input."""

    @abc.abstractmethod
    def reset(self):
        """Clear all stored n-grams."""


class OneLevelNGramStorage(INgramStorage):
    """Single-level: uses only exact (n-1)-token contexts."""
    def __init__(self, n: int, vocab_size: int): ...


class NGramStorage(INgramStorage):
    """Multi-level: tries longest context first, falls back to shorter."""
    def __init__(self, n: int, vocab_size: int): ...

Import

from ngram_assisted import NGramStorage, OneLevelNGramStorage, INgramStorage

I/O Contract

Inputs (Constructor)

| Name | Type | Required | Description |
|---|---|---|---|
| n | int | Yes | N-gram order (must be > 1), e.g. 3 for trigrams. Determines the maximum context length (n-1 tokens). |
| vocab_size | int | Yes | Vocabulary size from model.config.vocab_size. Used for the random fallback when no prediction exists. |

Key Methods

| Method | Inputs | Outputs | Description |
|---|---|---|---|
| initialize(input_ids) | torch.Tensor (batch_size, seq_len) | None | Seeds storage with all n-grams from the input prompt |
| next_token(input_ids) | torch.Tensor (batch_size, seq_len) | Tuple[Tensor, Tensor]: (predicted_token, is_known) | Predicts the next token from the context; returns a random token from the vocabulary if the context is unknown |
| update(input_ids, next_tokens) | input_ids: (batch_size, seq_len); next_tokens: (batch_size, k) with k >= 1 | None | Records new observations and updates most-frequent tracking |
| has_gram(ngram) | torch.Tensor (n,) | bool | Checks whether a specific n-gram has been observed |
| reset() | None | None | Clears all stored counts and predictions |
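One way a drafter could use the (predicted_token, is_known) pair is to chain predictions into a multi-token draft and stop at the first unknown context, rather than proposing a random fallback token. This is a sketch under assumptions: the stopping rule, `draft_tokens`, and the toy predictor `make_trigram_predictor` are hypothetical, and the predictor works on a single Python list instead of batched tensors.

```python
import random
from typing import Callable, Dict, List, Tuple

# Hypothetical per-sequence predictor with next_token's (token, known) contract.
Predictor = Callable[[List[int]], Tuple[int, bool]]


def make_trigram_predictor(best: Dict[Tuple[int, ...], int], vocab_size: int, seed: int = 0) -> Predictor:
    """Toy stand-in: exact 2-token lookup, random token id when the context is unknown."""
    rng = random.Random(seed)

    def predict(seq: List[int]) -> Tuple[int, bool]:
        key = tuple(seq[-2:])
        if len(key) == 2 and key in best:
            return best[key], True
        return rng.randrange(vocab_size), False

    return predict


def draft_tokens(predict: Predictor, context: List[int], max_draft: int) -> List[int]:
    """Chain predictions into a draft, stopping at the first unknown context."""
    seq = list(context)
    draft: List[int] = []
    for _ in range(max_draft):
        token, known = predict(seq)
        if not known:
            break  # don't propose a random guess to the target model
        draft.append(token)
        seq.append(token)  # the prediction becomes part of the next context
    return draft


predict = make_trigram_predictor({(1, 2): 3, (2, 3): 4, (3, 4): 5}, vocab_size=128256)
print(draft_tokens(predict, [0, 1, 2], max_draft=8))  # [3, 4, 5]
```

The target model would then verify the draft in one forward pass; how many tokens to draft is a tuning choice the table above does not fix.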

Usage Examples

Multi-Level NGramStorage (Recommended)

from ngram_assisted import NGramStorage
import torch

# Create storage for trigrams with vocab size 128256 (Llama 3.2)
storage = NGramStorage(n=3, vocab_size=128256)

# Initialize from prompt tokens
prompt_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])  # shape (1, seq_len)
storage.initialize(prompt_ids)

# Predict next token
predicted, known = storage.next_token(prompt_ids)
print(f"Predicted: {predicted}, Known: {known}")

# Update with new observation
new_tokens = torch.tensor([[9]])
storage.update(prompt_ids, new_tokens)

# Reset between generations
storage.reset()
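The `update` call above passes a single new token, while the Key Methods table allows one or more new tokens per row. A plausible reading, stated here as an assumption, is that each new token is recorded against the last n-1 tokens preceding it, with earlier new tokens becoming part of the history. `update_pairs` is a hypothetical helper for a single sequence, not part of the module.

```python
from typing import List, Tuple


def update_pairs(context: List[int], next_tokens: List[int], n: int) -> List[Tuple[Tuple[int, ...], int]]:
    """Hypothetical expansion of one update() row into (context, token) observations."""
    seq = list(context)
    pairs = []
    for tok in next_tokens:
        # Context is the last n-1 tokens seen before tok (may be shorter early on).
        pairs.append((tuple(seq[-(n - 1):]), tok))
        seq.append(tok)  # tok becomes history for the following token
    return pairs


print(update_pairs([1, 2, 3], [4, 5], n=3))  # [((2, 3), 4), ((3, 4), 5)]
```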

Single-Level OneLevelNGramStorage

import torch
from ngram_assisted import OneLevelNGramStorage

# Single-level uses only the exact (n-1)-token context
storage = OneLevelNGramStorage(n=4, vocab_size=128256)
prompt_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])  # shape (1, seq_len)
storage.initialize(prompt_ids)

# Only matches exact 3-token contexts (no fallback to shorter ones)
predicted, known = storage.next_token(prompt_ids)

Related Pages

Implements Principle

Uses Heuristic
