Principle:Sail sg LongSpec Tree Speculative Decoding
| Knowledge Sources | |
|---|---|
| Domains | Speculative_Decoding, LLM_Inference, Tree_Search |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Inference acceleration technique that generates a tree of candidate token sequences using a lightweight draft model, then verifies the entire tree in a single target LLM forward pass.
Description
Tree Speculative Decoding extends standard speculative decoding by generating multiple candidate continuations organized in a tree structure rather than a single linear sequence. At each level, the draft model proposes multiple tokens (top-k candidates), creating a branching tree of possibilities. The target LLM then verifies all candidates simultaneously using tree-structured attention.
This approach is particularly effective for long-context scenarios where:
- The draft model's uncertainty increases with context length, making single-sequence speculation less effective
- Tree expansion covers exponentially more candidate paths as depth grows, while verification cost grows only linearly with the number of tree nodes
- The tree structure allows accepting the longest valid prefix from any branch
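The tree-structured attention mentioned above can be illustrated with a small mask builder. This is a minimal sketch, not the project's implementation: the function name `tree_attention_mask` and the `parents` encoding are hypothetical, and the mask covers only the tree nodes (a real system would also let every node attend to the full verified prefix).

```python
def tree_attention_mask(parents):
    """Return mask[i][j] == True if tree node i may attend to tree node j.

    parents[i] is the index of node i's parent, or -1 when the parent is the
    root (the already-verified prefix). Nodes must be listed so that every
    parent appears before its children (an assumption of this sketch).
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        mask[i][i] = True  # every node attends to itself
        p = parents[i]
        if p >= 0:
            # Inherit everything the parent can see, i.e. all ancestors
            for j in range(n):
                if mask[p][j]:
                    mask[i][j] = True
    return mask


# Toy tree: nodes 0 and 1 hang off the root; 2 and 3 are children of 0; 4 is a child of 1
parents = [-1, -1, 0, 0, 1]
mask = tree_attention_mask(parents)
```

Each row marks only the node itself and its ancestors, so sibling branches never see each other and all branches can share one forward pass.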
The method supports three generation modes:
- Tree mode (tree_spec_generate): Full tree expansion with configurable branching factors per level
- Sequential mode (spec_generate): Linear chain of gamma speculative tokens (standard speculative decoding)
- Vanilla mode (vanilla_generate): Baseline autoregressive decoding without speculation
Usage
Use tree speculative decoding when serving LLMs and latency is the primary concern. It is most beneficial when:
- The target LLM is large (7B+ parameters) and inference-bound
- Input contexts are long (4k+ tokens), where draft accuracy may degrade
- The GLIDE draft model is available and trained for the target LLM
Tree mode with shape [4, 16, 16, 16, 16] is the default configuration, generating up to 4 candidates at the first level and 16 at each subsequent level.
Theoretical Basis
Tree Construction:
The tree is built level by level. At level l, the draft model generates k_l candidate tokens for each node at level l-1. The tree shape is defined as a list of branching factors:
tree_shape = [4, 16, 16, 16, 16]
# Level 0: 4 candidates from root
# Level 1: 16 candidates per node (64 total)
# Level 2: 16 candidates per node (top-16 by score)
# ...
Candidate scoring combines the parent's log-probability with the child's log-probability (beam-tree strategy).
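One common way to combine the two log-probabilities is to add them, so that a node's score is the cumulative log-probability of its path. The sketch below expands one tree level under that assumption. The names `expand_level`, `draft_logprobs`, `k`, and `keep` are hypothetical, the draft model is replaced by a fixed toy distribution, and the exact pruning rule used with `tree_shape` may differ.

```python
import heapq
import math


def expand_level(frontier, draft_logprobs, k, keep):
    """Expand one tree level and keep the best-scoring candidates.

    frontier: list of (score, tokens), score being the cumulative
              log-probability of the path `tokens` (a tuple).
    draft_logprobs: callable mapping a token tuple to {token: log-prob};
                    a stand-in for the draft model (hypothetical interface).
    k: children proposed per node. keep: nodes retained at this level.
    """
    candidates = []
    for parent_score, tokens in frontier:
        dist = draft_logprobs(tokens)
        top_k = heapq.nlargest(k, dist.items(), key=lambda kv: kv[1])
        for tok, child_lp in top_k:
            # Assumed combination rule: parent log-prob + child log-prob
            candidates.append((parent_score + child_lp, tokens + (tok,)))
    return heapq.nlargest(keep, candidates, key=lambda c: c[0])


def toy_draft(tokens):
    # Same distribution for every prefix; for illustration only
    return {"a": math.log(0.6), "b": math.log(0.3), "c": math.log(0.1)}


frontier = [(0.0, ())]                                   # root: empty path, log-prob 0
level0 = expand_level(frontier, toy_draft, k=2, keep=2)  # paths of length 1
level1 = expand_level(level0, toy_draft, k=2, keep=3)    # paths of length 2
```

Summing log-probabilities ranks candidates by the draft model's joint probability of the whole path, so a weak child of a strong parent can still outrank a strong child of a weak parent.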
Verification:
The target LLM processes all tree nodes in a single forward pass using a tree-structured attention mask. For greedy decoding (temperature=0):
- A candidate token is accepted if it matches the target LLM's greedy (argmax) prediction at its parent node and every ancestor on its path has also been accepted
- The longest accepted path from the root determines how many tokens are accepted
- At the first rejected position, the token is taken from the target distribution instead (its argmax under greedy decoding)
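The greedy acceptance rules above can be sketched as a single pass over the tree nodes. This is an illustrative sketch under stated assumptions, not the project's code: `longest_accepted_path` and its arguments are hypothetical names, and the target LLM's outputs are passed in as precomputed argmax tokens rather than produced by a real forward pass.

```python
def longest_accepted_path(parents, tokens, root_pred, node_preds):
    """Greedy verification of a draft tree.

    parents[i]: parent index of node i, or -1 for children of the root;
                parents must precede children in the node order.
    tokens[i]: draft token stored at node i.
    root_pred: target argmax token following the verified prefix.
    node_preds[i]: target argmax token following node i.
    Returns (accepted_tokens, next_token), where next_token is the target's
    own prediction after the last accepted node.
    """
    n = len(parents)
    accepted = [False] * n
    depth = [0] * n
    best = -1  # deepest accepted node so far
    for i in range(n):
        p = parents[i]
        parent_ok = p < 0 or accepted[p]
        expected = root_pred if p < 0 else node_preds[p]
        if parent_ok and tokens[i] == expected:
            accepted[i] = True
            depth[i] = 1 if p < 0 else depth[p] + 1
            if best < 0 or depth[i] > depth[best]:
                best = i
    # Walk back from the deepest accepted node to the root
    path = []
    node = best
    while node >= 0:
        path.append(tokens[node])
        node = parents[node]
    path.reverse()
    next_token = root_pred if best < 0 else node_preds[best]
    return path, next_token


parents = [-1, -1, 0, 0, 1]
tokens = ["the", "a", "cat", "dog", "bird"]
result = longest_accepted_path(parents, tokens, "the", ["dog", "x", "sat", "ran", "y"])
```

In this toy tree the branch "the" then "dog" matches the target at both levels, so two draft tokens are accepted and the target's prediction after "dog" is emitted as the next token.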
For stochastic decoding (temperature>0), rejection sampling is used:
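For reference, the standard speculative-sampling rule for a single drafted token accepts the token with probability min(1, p/q), where p and q are its probabilities under the target and draft models, and otherwise resamples from the renormalized residual max(0, p - q). The sketch below shows only that single-token rule. The name `accept_or_resample` is hypothetical, and the tree variant used here, which must handle several sibling candidates, may differ.

```python
import random


def accept_or_resample(token, p_target, q_draft, rng):
    """Apply the standard speculative-sampling rule to one drafted token.

    p_target, q_draft: dicts mapping token -> probability under the target
    and draft models. `token` must have been sampled from q_draft, so
    q_draft[token] > 0. Returns (accepted, token_out).
    """
    p = p_target.get(token, 0.0)
    q = q_draft[token]
    if rng.random() < min(1.0, p / q):
        return True, token
    # Rejected: resample from the residual distribution max(0, p - q).
    # rng.choices normalizes the weights internally.
    support = list(p_target)
    weights = [max(0.0, p_target[t] - q_draft.get(t, 0.0)) for t in support]
    return False, rng.choices(support, weights=weights)[0]


rng = random.Random(0)
same = accept_or_resample("a", {"a": 0.5, "b": 0.5}, {"a": 0.5, "b": 0.5}, rng)
opposite = accept_or_resample("a", {"a": 0.0, "b": 1.0}, {"a": 1.0, "b": 0.0}, rng)
```

When the two distributions agree, the draft token is always kept. When the target assigns it zero probability, it is always rejected and replaced by a token drawn from the target's remaining mass.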