Implementation:Romsto Speculative Decoding NGramStorage
| Knowledge Sources | |
|---|---|
| Domains | NLP, Data_Structures, Statistical_Language_Modeling |
| Last Updated | 2026-02-14 04:30 GMT |
Overview
Concrete tool for storing and querying n-gram frequency counts, used as a lightweight drafter in n-gram assisted speculative decoding (NASD).
Description
The ngram_storage module provides three classes forming an inheritance hierarchy:
- INgramStorage (ABC, L5-69): Defines the interface with abstract methods: next_token, has_gram, update, initialize, reset.
- OneLevelNGramStorage (L73-150): Single-level implementation using exact (n-1)-token contexts only. Uses Python dicts for counts and best-token tracking.
- NGramStorage (L154-249): Multi-level implementation that stores k-grams for all k in [2, n]. On lookup, tries the longest context first and falls back to shorter contexts. Recommended for NASD as it provides better coverage.
Both concrete classes maintain two dictionaries: counts (context -> token -> frequency) and ngrams (context -> most frequent token). Updates track the running maximum to avoid recomputing argmax on every query.
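The running-maximum bookkeeping described above can be illustrated with a minimal sketch. It uses plain Python tuples rather than tensors. The class name `RunningMaxCounts` is hypothetical, and the tie-breaking rule (the first token to reach a count keeps the slot) is an assumption, not something confirmed from the repository.

```python
from collections import defaultdict


class RunningMaxCounts:
    """Hypothetical stand-in for the two-dictionary layout (not the repository's class)."""

    def __init__(self):
        # counts: context -> token -> frequency
        self.counts = defaultdict(lambda: defaultdict(int))
        # ngrams: context -> most frequent token seen so far
        self.ngrams = {}

    def update(self, context, token):
        context = tuple(context)
        self.counts[context][token] += 1
        best = self.ngrams.get(context)
        # Only overwrite the cached best token when the new count strictly
        # exceeds it, so no argmax over the inner dict is needed at query time.
        if best is None or self.counts[context][token] > self.counts[context][best]:
            self.ngrams[context] = token


store = RunningMaxCounts()
for tok in [7, 9, 9]:
    store.update((1, 2), tok)
print(store.ngrams[(1, 2)])  # token 9 has overtaken token 7
```

With this layout a query is a single dictionary lookup, and the cost of keeping the maximum up to date is paid once per update.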
Usage
Import NGramStorage (multi-level, recommended) or OneLevelNGramStorage (single-level) when setting up n-gram assisted speculative decoding. Create an instance with n (n-gram order, must be > 1) and vocab_size (from model.config.vocab_size). The storage is initialized from the prompt via initialize() and updated during generation via update().
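As a rough illustration of what seeding from a prompt involves for the multi-level variant, the sketch below enumerates every (context, next token) pair for k-grams with k in [2, n]. The helper `iter_kgrams` is a hypothetical name that works on a plain list of token ids for one sequence. The actual initialize() takes a batched tensor, and its internals may differ.

```python
def iter_kgrams(tokens, n):
    """Yield (context, next_token) pairs for every k-gram with 2 <= k <= n."""
    for k in range(2, n + 1):
        # Slide a window of length k over the sequence.
        for start in range(len(tokens) - k + 1):
            window = tokens[start:start + k]
            yield tuple(window[:-1]), window[-1]


pairs = list(iter_kgrams([1, 2, 3, 4], n=3))
for context, nxt in pairs:
    print(context, "->", nxt)
```

Each yielded pair would then be recorded in the storage. A single-level store would keep only the pairs whose context has exactly n-1 tokens.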
Code Reference
Source Location
- Repository: Speculative-Decoding
- File: ngram_assisted/ngram_storage.py
- Lines: L5-249
Signature
import abc
from typing import Tuple

import torch


class INgramStorage(abc.ABC):
    """Interface for N-gram storage."""

    def __init__(self, n: int, vocab_size: int):
        """n must be > 1. vocab_size is the model vocabulary size."""

    @abc.abstractmethod
    def next_token(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict next token. Returns (token, known) tensors."""

    @abc.abstractmethod
    def has_gram(self, ngram: torch.Tensor) -> bool:
        """Check if ngram has been observed."""

    @abc.abstractmethod
    def update(self, input_ids: torch.Tensor, next_tokens: torch.Tensor):
        """Record new (context, token) observations."""

    @abc.abstractmethod
    def initialize(self, input_ids: torch.Tensor):
        """Seed storage with all n-grams from input."""

    @abc.abstractmethod
    def reset(self):
        """Clear all stored n-grams."""


class OneLevelNGramStorage(INgramStorage):
    """Single-level: uses only exact (n-1)-token contexts."""

    def __init__(self, n: int, vocab_size: int): ...


class NGramStorage(INgramStorage):
    """Multi-level: tries longest context first, falls back to shorter."""

    def __init__(self, n: int, vocab_size: int): ...
Import
from ngram_assisted import NGramStorage, OneLevelNGramStorage, INgramStorage
I/O Contract
Inputs (Constructor)
| Name | Type | Required | Description |
|---|---|---|---|
| n | int | Yes | N-gram order (must be > 1). E.g., 3 for trigrams. Determines the maximum context length (n-1 tokens). |
| vocab_size | int | Yes | Vocabulary size from model.config.vocab_size. Used for random fallback when no prediction exists. |
Key Methods
| Method | Inputs | Outputs | Description |
|---|---|---|---|
| initialize(input_ids) | torch.Tensor (batch_size, seq_len) | None | Seeds storage with all n-grams from the input prompt |
| next_token(input_ids) | torch.Tensor (batch_size, seq_len) | Tuple[Tensor, Tensor]: (predicted_token, is_known) | Predicts the next token from the context. Returns a random token if the context is unknown. |
| update(input_ids, next_tokens) | input_ids: (batch_size, seq_len), next_tokens: (batch_size, k) with k >= 1 | None | Records new observations and updates most-frequent tracking |
| has_gram(ngram) | torch.Tensor (n,) | bool | Checks whether a specific n-gram has been observed |
| reset() | None | None | Clears all stored counts and predictions |
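The next_token behaviour in the table (longest context first, shorter contexts as fallback, random token when nothing matches) can be sketched as follows. This is an illustration under stated assumptions: `next_token_backoff` is a hypothetical function operating on a plain dict and a list of token ids, whereas the real method works on batched tensors and returns tensors.

```python
import random


def next_token_backoff(ngrams, tokens, n, vocab_size, rng=random):
    """Return (token, known) using the longest matching context of at most n-1 tokens."""
    for ctx_len in range(min(n - 1, len(tokens)), 0, -1):
        context = tuple(tokens[-ctx_len:])
        if context in ngrams:
            return ngrams[context], True
    # No context of any length matched: draw a random token id and flag it as unknown.
    return rng.randrange(vocab_size), False


# ngrams maps context -> most frequent next token (hypothetical contents)
table = {(2, 3): 4, (3,): 9}
hit = next_token_backoff(table, [1, 2, 3], n=3, vocab_size=10)       # matches (2, 3)
fallback = next_token_backoff(table, [1, 5, 3], n=3, vocab_size=10)  # falls back to (3,)
miss = next_token_backoff(table, [1, 5, 6], n=3, vocab_size=10)      # random token
print(hit, fallback, miss[1])
```

A single-level store would correspond to trying only the context of length n-1 and skipping the shorter ones, which is why the multi-level variant tends to produce known predictions more often.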
Usage Examples
Multi-Level NGramStorage (Recommended)
from ngram_assisted import NGramStorage
import torch
# Create storage for trigrams with vocab size 128256 (Llama 3.2)
storage = NGramStorage(n=3, vocab_size=128256)
# Initialize from prompt tokens
prompt_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # shape (1, seq_len)
storage.initialize(prompt_ids)
# Predict next token
predicted, known = storage.next_token(prompt_ids)
print(f"Predicted: {predicted}, Known: {known}")
# Update with new observation
new_tokens = torch.tensor([[9]])
storage.update(prompt_ids, new_tokens)
# Reset between generations
storage.reset()
Single-Level OneLevelNGramStorage
from ngram_assisted import OneLevelNGramStorage
import torch
# Single-level uses only exact (n-1)-token context
storage = OneLevelNGramStorage(n=4, vocab_size=128256)
# Same example prompt as in the multi-level example, shape (1, seq_len)
prompt_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
storage.initialize(prompt_ids)
# Only matches exact 3-token contexts (no fallback)
predicted, known = storage.next_token(prompt_ids)