Implementation:Sail sg LongSpec Tree Verification Accept
| Knowledge Sources | |
|---|---|
| Domains | Speculative_Decoding, LLM_Inference |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Concrete tool for deciding which draft tokens to accept during tree speculative decoding, with greedy and stochastic verification modes and KV cache rearrangement.
Description
Two methods, defined on both LlamaGlide and Qwen2Glide:
- tree_verification: Greedy acceptance. Compares the target model's argmax with the draft tokens along tree paths, finds the deepest accepted node, and rearranges the KV cache.
- verify_stochastic: Stochastic acceptance. Uses rejection sampling, accepting a draft token with probability min(1, P_target/P_draft). When rejected positions are resampled from the normalized residual max(0, P_target - P_draft), this rule leaves the output distribution identical to the target model's.
Both methods are decorated with @torch.compile, so they are compiled on first invocation for performance.
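The greedy rule in tree_verification can be illustrated on a single sequence. The following is a minimal plain-Python sketch, not LongSpec's tensorized code. It assumes a hypothetical layout: node 0 is the last committed token, nodes are stored parents-before-children, and tree_mask[i][j] == 1 iff j is i or an ancestor of i. The helpers parents_from_mask and greedy_tree_accept are illustrative names; batching and the KV cache rearrangement are omitted.

```python
from typing import List, Tuple


def parents_from_mask(tree_mask: List[List[int]]) -> List[int]:
    """Recover each node's parent from an ancestor mask (hypothetical layout:
    tree_mask[i][j] == 1 iff j == i or j is an ancestor of i)."""
    depth = [sum(row) - 1 for row in tree_mask]
    parents = []
    for i, row in enumerate(tree_mask):
        ancestors = [j for j, v in enumerate(row) if v and j != i]
        # The parent is the deepest ancestor; the root (node 0) has none.
        parents.append(max(ancestors, key=lambda j: depth[j]) if ancestors else -1)
    return parents


def greedy_tree_accept(
    input_ids: List[int], output_ids: List[int], tree_mask: List[List[int]]
) -> Tuple[List[int], int]:
    """Greedy tree acceptance for one sequence.

    input_ids[i]: draft token at node i (node 0 = last committed token).
    output_ids[i]: target argmax for the position after node i.
    Returns (draft tokens on the deepest accepted path, bonus target token).
    """
    parents = parents_from_mask(tree_mask)
    depth = [sum(row) - 1 for row in tree_mask]
    accepted = [False] * len(input_ids)
    accepted[0] = True
    # Assumes parents precede children in index order, so one pass suffices.
    for i in range(1, len(input_ids)):
        p = parents[i]
        accepted[i] = accepted[p] and input_ids[i] == output_ids[p]
    best = max((i for i, ok in enumerate(accepted) if ok), key=lambda i: depth[i])
    path = []
    node = best
    while node != 0:  # walk back to the root to recover the accepted path
        path.append(input_ids[node])
        node = parents[node]
    path.reverse()
    # The target's prediction at the deepest accepted node is also usable.
    return path, output_ids[best]


mask = [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 0, 1, 0, 0],
        [1, 1, 0, 1, 0], [1, 0, 1, 0, 1]]
print(greedy_tree_accept([10, 11, 12, 13, 14], [12, 13, 14, 0, 99], mask))
# ([12, 14], 99)
```

Whether LongSpec counts the bonus token in acc_num, or stores it in acc_ids, is not stated on this page; the sketch simply returns it separately.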
Usage
Called internally by tree_spec_generate and spec_generate after the target model's forward pass over the draft tokens. Not intended to be called directly by users.
Code Reference
Source Location
- Repository: LongSpec
- File (Llama): longspec/test/llama_glide.py
- Lines (tree_verification): L1128-1175
- Lines (verify_stochastic): L1177-1245
- File (Qwen2): longspec/test/qwen2_glide.py
- Lines (tree_verification): L957-1004
- Lines (verify_stochastic): L1006-1074
Signature
```python
# Excerpt: method signatures on LlamaGlide / Qwen2Glide (bodies omitted).
from typing import Tuple

import torch


@torch.compile
def tree_verification(
    self,
    input_ids: torch.Tensor,
    output_ids: torch.Tensor,
    tree_mask: torch.Tensor,
    cache_lens: torch.Tensor,
    non_leaf_len: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Greedy tree verification: accept tokens where the target agrees with the draft.

    Args:
        input_ids: Draft token predictions (batch_size, tree_seq_len)
        output_ids: Target model predictions (batch_size, tree_seq_len)
        tree_mask: Tree connectivity (batch_size, tree_seq_len, tree_seq_len)
        cache_lens: Current KV cache lengths (batch_size,)
        non_leaf_len: Number of non-leaf tree nodes

    Returns:
        acc_ids: Accepted token sequence (batch_size, max_depth)
        acc_num: Number of accepted tokens per batch element (batch_size,)
        double_input: Flag indicating all candidates accepted (for double buffering)
    """


@torch.compile
def verify_stochastic(
    self,
    input_ids: torch.Tensor,
    tree_mask: torch.Tensor,
    p_llm: torch.Tensor,
    p_ssm: torch.Tensor,
    temperature: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Stochastic tree verification using rejection sampling.

    Args:
        input_ids: Draft token candidates (batch_size, tree_seq_len)
        tree_mask: Tree connectivity mask (batch_size, tree_seq_len, tree_seq_len)
        p_llm: Target model logits (batch_size, tree_seq_len, vocab_size)
        p_ssm: Draft model logits (batch_size, tree_seq_len, vocab_size)
        temperature: Sampling temperature used to turn logits into probabilities

    Returns:
        acc_ids: Accepted token sequence (batch_size, max_depth)
        acc_num: Number of accepted tokens per batch element (batch_size,)
    """
```
Import
```python
# These methods are bound to LlamaGlide / Qwen2Glide instances:
from longspec.test.llama_glide import LlamaGlide
from longspec.test.qwen2_glide import Qwen2Glide

# Access via model.tree_verification(...) or model.verify_stochastic(...)
```
I/O Contract
Inputs (tree_verification)
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Draft predictions (B, tree_len) |
| output_ids | torch.Tensor | Yes | Target predictions (B, tree_len) — argmax of target logits |
| tree_mask | torch.Tensor | Yes | Tree connectivity (B, tree_len, tree_len) |
| cache_lens | torch.Tensor | Yes | KV cache lengths per batch (B,) |
| non_leaf_len | int | Yes | Count of non-leaf tree nodes |
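The relationship between tree_mask and non_leaf_len can be made concrete with a small sketch. This page does not specify the mask layout, so the sketch assumes a common tree-attention convention: mask[i][j] == 1 iff j is i or an ancestor of i, built from a hypothetical parent list. Under this counting a node is non-leaf when any other node lists it as an ancestor, so the root counts as non-leaf. The helpers ancestor_mask and count_non_leaf are illustrative, not LongSpec functions.

```python
from typing import List


def ancestor_mask(parents: List[int]) -> List[List[int]]:
    """Build an ancestor (tree-attention) mask from a parent list.

    Hypothetical layout: parents[i] is node i's parent (-1 for the root), and
    mask[i][j] == 1 iff j == i or j is an ancestor of i.
    """
    n = len(parents)
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up to the root, marking i itself and each ancestor
            mask[i][j] = 1
            j = parents[j]
    return mask


def count_non_leaf(mask: List[List[int]]) -> int:
    """A node is non-leaf iff some other node has it as an ancestor."""
    n = len(mask)
    return sum(1 for j in range(n) if any(mask[i][j] for i in range(n) if i != j))


# Root 0 with children 1 and 2; node 3 under 1; node 4 under 2.
mask = ancestor_mask([-1, 0, 0, 1, 2])
print(count_non_leaf(mask))  # 3 (nodes 0, 1 and 2 have children)
```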
Inputs (verify_stochastic)
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Draft token candidates (B, tree_len) |
| tree_mask | torch.Tensor | Yes | Tree connectivity mask (B, tree_len, tree_len) |
| p_llm | torch.Tensor | Yes | Target model logits (B, tree_len, vocab_size) |
| p_ssm | torch.Tensor | Yes | Draft model logits (B, tree_len, vocab_size) |
| temperature | float | Yes | Temperature for probability computation |
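Because p_llm and p_ssm are logits, verify_stochastic has to turn them into probabilities before applying the acceptance ratio. A common way to do that is a temperature-scaled softmax, softmax(logits / T); in tensor code this would typically be torch.softmax(p_llm / temperature, dim=-1). The exact computation inside LongSpec is not shown on this page, so the plain-Python sketch below is an assumption; probs_from_logits is a hypothetical helper.

```python
import math
from typing import List


def probs_from_logits(logits: List[float], temperature: float) -> List[float]:
    """Temperature-scaled softmax: softmax(logits / temperature).

    Assumes temperature > 0; temperature == 0 corresponds to the greedy
    (argmax) path handled by tree_verification instead.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use argmax for greedy decoding")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Lower temperatures sharpen the distribution toward the top logit.
sharp = probs_from_logits([2.0, 1.0, 0.0], 0.5)
flat = probs_from_logits([2.0, 1.0, 0.0], 2.0)
```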
Outputs
| Name | Type | Description |
|---|---|---|
| acc_ids | torch.Tensor | Accepted token sequence (B, max_depth) |
| acc_num | torch.Tensor | Number of accepted tokens per batch element (B,) |
| double_input | torch.Tensor | (greedy only) Flag for all-accepted case enabling double buffering |
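Since acc_ids has a fixed width of max_depth while acc_num varies per batch element, callers need to trim each row to its valid prefix. The sketch below assumes each row of acc_ids is left-aligned, with entries past acc_num[b] being padding; that layout is not confirmed on this page, and unpad_accepted is a hypothetical helper operating on plain lists.

```python
from typing import List


def unpad_accepted(acc_ids: List[List[int]], acc_num: List[int]) -> List[List[int]]:
    """Keep only the first acc_num[b] tokens of each padded acc_ids row.

    Assumption: rows of acc_ids (B, max_depth) are left-aligned and any entry
    at index >= acc_num[b] is padding.
    """
    return [row[:n] for row, n in zip(acc_ids, acc_num)]


print(unpad_accepted([[5, 6, 0], [7, 0, 0]], [2, 1]))  # [[5, 6], [7]]
```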
Usage Examples
Greedy Verification (Inside tree_spec_generate)
```python
# Called internally when temperature == 0.0:
acc_ids, acc_num, double_input = self.tree_verification(
    input_ids=draft_tree_tokens,    # Draft model's tree candidates
    output_ids=target_predictions,  # Target model's argmax per node
    tree_mask=tree_mask,            # Tree connectivity
    cache_lens=cache_lens,          # Current KV cache lengths
    non_leaf_len=non_leaf_count,    # Number of non-leaf nodes
)
# acc_num: draft tokens accepted this step, i.e. tokens gained without an
# extra sequential forward pass of the target model
```
Stochastic Verification (Inside tree_spec_generate)
```python
# Called internally when temperature > 0.0:
acc_ids, acc_num = self.verify_stochastic(
    input_ids=draft_tree_tokens,
    tree_mask=tree_mask,
    p_llm=target_logits,  # Full logit distribution from target
    p_ssm=draft_logits,   # Full logit distribution from draft
    temperature=0.7,      # Sampling temperature
)
```
Related Pages
Implements Principle
Requires Environment