
Implementation:Sail sg LongSpec Tree Verification Accept

Revision as of 13:49, 16 February 2026 by Admin (auto-imported from implementations/Sail_sg_LongSpec_Tree_Verification_Accept.md)
Knowledge Sources
Domains Speculative_Decoding, LLM_Inference
Last Updated 2026-02-14 05:00 GMT

Overview

Concrete tool for deciding which draft tokens to accept during tree speculative decoding, with greedy and stochastic verification modes and KV cache rearrangement.

Description

Two @torch.compile-decorated methods on LlamaGlide and Qwen2Glide:

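  • tree_verification: greedy acceptance. Compares the target model's argmax with the draft tokens along each tree path, selects the deepest node whose entire path was accepted, and rearranges the KV cache to match that accepted path.
  • verify_stochastic: stochastic acceptance. Uses rejection sampling, accepting each draft token with probability min(1, P_target/P_draft). This keeps the output distribution identical to the target model's, provided rejected positions are resampled from the normalized residual max(0, P_target - P_draft), as in standard speculative sampling.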
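A plain-Python sketch of greedy acceptance over a draft tree may help make the first bullet concrete. It assumes a hypothetical layout not given on this page: node 0 is the root (the last committed token), `parent` lists each node's parent with parents ordered before children, and `target_next[i]` is the target's argmax after node i. It is not the LongSpec code, which works on batched tensors under torch.compile.

```python
from typing import List, Tuple


def greedy_tree_accept(
    tokens: List[int],       # draft token at each node; node 0 is the root (already committed)
    target_next: List[int],  # target argmax predicted after each node
    parent: List[int],       # parent index per node; parent[0] == -1
) -> Tuple[List[int], int]:
    """Return (accepted tokens plus one bonus token, number of accepted draft tokens)."""
    n = len(tokens)
    ok = [False] * n
    ok[0] = True  # the root is already committed
    depth = [0] * n
    # Assumes parents precede children, so ok[p] is final before it is read.
    for i in range(1, n):
        p = parent[i]
        # A node is accepted only if its parent is accepted and the target,
        # given the path up to the parent, predicts exactly this draft token.
        ok[i] = ok[p] and target_next[p] == tokens[i]
        depth[i] = depth[p] + 1
    # Deepest accepted node; max() keeps the first one on ties.
    best = max((i for i in range(n) if ok[i]), key=lambda i: depth[i])
    # Walk back to the root to recover the accepted path.
    path = []
    node = best
    while node != 0:
        path.append(tokens[node])
        node = parent[node]
    path.reverse()
    # The target's own prediction at the deepest accepted node comes for free.
    return path + [target_next[best]], len(path)


# Root 0 has children 1 and 2; node 3 hangs off 1, node 4 off 2.
print(greedy_tree_accept([10, 5, 7, 8, 9], [7, 8, 9, 1, 3], [-1, 0, 0, 1, 2]))
# → ([7, 9, 3], 2)
```

Here the branch through node 2 matches the target twice, so two draft tokens are accepted and the target's prediction after node 4 is appended as the bonus token.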

Both methods are JIT-compiled via @torch.compile for performance.

Usage

Called internally by tree_spec_generate and spec_generate after the target model's forward pass over the draft tokens. Users do not call these methods directly.

Code Reference

Source Location

  • Repository: LongSpec
  • File (Llama): longspec/test/llama_glide.py
  • Lines (tree_verification): L1128-1175
  • Lines (verify_stochastic): L1177-1245
  • File (Qwen2): longspec/test/qwen2_glide.py
  • Lines (tree_verification): L957-1004
  • Lines (verify_stochastic): L1006-1074

Signature

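import torch
from typing import Tuple

# Both are methods of LlamaGlide / Qwen2Glide (bodies omitted).
@torch.compile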
def tree_verification(
    self,
    input_ids: torch.Tensor,
    output_ids: torch.Tensor,
    tree_mask: torch.Tensor,
    cache_lens: torch.Tensor,
    non_leaf_len: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Greedy tree verification: accept tokens where target agrees with draft.

    Args:
        input_ids: Draft token predictions (batch_size, tree_seq_len)
        output_ids: Target model predictions (batch_size, tree_seq_len)
        tree_mask: Tree connectivity (batch_size, tree_seq_len, tree_seq_len)
        cache_lens: Current KV cache lengths (batch_size,)
        non_leaf_len: Number of non-leaf tree nodes

    Returns:
        acc_ids: Accepted token sequence (batch_size, max_depth)
        acc_num: Number of accepted tokens per batch element
        double_input: Flag indicating all candidates accepted (for double buffering)
    """

@torch.compile
def verify_stochastic(
    self,
    input_ids: torch.Tensor,
    tree_mask: torch.Tensor,
    p_llm: torch.Tensor,
    p_ssm: torch.Tensor,
    temperature: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Stochastic tree verification using rejection sampling.

    Args:
        input_ids: Draft token candidates (batch_size, tree_seq_len)
        tree_mask: Tree connectivity mask
        p_llm: Target model logits (batch_size, tree_seq_len, vocab_size)
        p_ssm: Draft model logits (batch_size, tree_seq_len, vocab_size)
        temperature: Sampling temperature for probability computation

    Returns:
        acc_ids: Accepted token sequence
        acc_num: Number of accepted tokens
    """

Import

# These methods are bound to LlamaGlide / Qwen2Glide instances:
from longspec.test.llama_glide import LlamaGlide
from longspec.test.qwen2_glide import Qwen2Glide
# Access via model.tree_verification(...) or model.verify_stochastic(...)

I/O Contract

Inputs (tree_verification)

Name          Type          Required  Description
input_ids     torch.Tensor  Yes       Draft predictions (B, tree_len)
output_ids    torch.Tensor  Yes       Target predictions (B, tree_len); argmax of the target logits
tree_mask     torch.Tensor  Yes       Tree connectivity (B, tree_len, tree_len)
cache_lens    torch.Tensor  Yes       KV cache length per batch element (B,)
non_leaf_len  int           Yes       Count of non-leaf tree nodes
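This page describes tree_mask only as "tree connectivity". A common convention in tree attention, assumed here and not confirmed by the source, is an ancestor-or-self mask: `mask[i][j] == 1` iff node j is node i or one of its ancestors. Under that assumption, the sketch below recovers each node's parent and depth, and counts non-leaf nodes in the sense of the non_leaf_len input. The helper names are hypothetical.

```python
from typing import List, Tuple


def parents_from_mask(mask: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Recover (parent, depth) per node from an ancestor-or-self mask."""
    n = len(mask)
    # Each row marks the node itself plus all of its ancestors.
    depth = [sum(row) - 1 for row in mask]
    parent = []
    for i in range(n):
        ancestors = [j for j in range(n) if mask[i][j] and j != i]
        # The parent is the deepest proper ancestor; the root has none.
        parent.append(max(ancestors, key=lambda j: depth[j]) if ancestors else -1)
    return parent, depth


def count_non_leaf(parent: List[int]) -> int:
    """A node is non-leaf if it is some other node's parent."""
    return len({p for p in parent if p >= 0})


mask = [
    [1, 0, 0, 0, 0],  # 0: root
    [1, 1, 0, 0, 0],  # 1: child of 0
    [1, 0, 1, 0, 0],  # 2: child of 0
    [1, 1, 0, 1, 0],  # 3: child of 1
    [1, 0, 1, 0, 1],  # 4: child of 2
]
parent, depth = parents_from_mask(mask)
print(parent, depth, count_non_leaf(parent))
# → [-1, 0, 0, 1, 2] [0, 1, 1, 2, 2] 3
```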

Inputs (verify_stochastic)

Name          Type          Required  Description
input_ids     torch.Tensor  Yes       Draft token candidates (B, tree_len)
tree_mask     torch.Tensor  Yes       Tree connectivity mask (B, tree_len, tree_len)
p_llm         torch.Tensor  Yes       Target model logits (B, tree_len, vocab_size)
p_ssm         torch.Tensor  Yes       Draft model logits (B, tree_len, vocab_size)
temperature   float         Yes       Temperature for probability computation
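Because p_llm and p_ssm are logits, the temperature is presumably applied through a temperature-scaled softmax before the acceptance ratio is computed. The sketch below shows that standard conversion for a single vocabulary row in plain Python. Whether verify_stochastic applies it to both inputs in exactly this way is an assumption.

```python
import math
from typing import List


def probs_from_logits(logits: List[float], temperature: float) -> List[float]:
    """Temperature-scaled softmax over one vocabulary row."""
    if temperature <= 0:
        # temperature == 0 corresponds to the greedy path (tree_verification).
        raise ValueError("temperature must be > 0")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


print(probs_from_logits([0.0, 0.0], 1.0))  # → [0.5, 0.5]
```

Lower temperatures sharpen the distribution toward the argmax, which is why temperature 0 is routed to the greedy method instead.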

Outputs

Name          Type          Description
acc_ids       torch.Tensor  Accepted token sequence (B, max_depth)
acc_num       torch.Tensor  Number of accepted tokens per batch element (B,)
double_input  torch.Tensor  (greedy only) Flag for the all-accepted case, enabling double buffering
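After tree_verification picks the deepest accepted node, it rearranges the KV cache. A minimal sketch of that idea in plain Python follows. It assumes, hypothetically, that the draft tree's entries were appended to the cache in node order starting at position cache_len, and it uses strings as stand-ins for per-position key/value tensors. In tensor code the same step would typically be a gather along the sequence dimension.

```python
from typing import List, Sequence, Tuple


def compact_cache(
    cache: List[str],                # one entry per position; tree nodes start at cache_len
    cache_len: int,                  # committed length before this step
    accepted_nodes: Sequence[int],   # tree-node indices to keep, in path order
) -> Tuple[List[str], int]:
    """Move accepted tree entries into contiguous slots and drop the rest."""
    # Gather first, so no entry is overwritten before it is read.
    kept = [cache[cache_len + i] for i in accepted_nodes]
    new_cache = cache[:cache_len] + kept
    return new_cache, cache_len + len(kept)


cache = ["a", "b", "n0", "n1", "n2", "n3", "n4"]
print(compact_cache(cache, 2, [0, 2, 4]))
# → (['a', 'b', 'n0', 'n2', 'n4'], 5)
```

Entries for rejected branches (n1, n3) are discarded, and the next decoding step can append new entries from the updated length.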

Usage Examples

Greedy Verification (Inside tree_spec_generate)

# Called internally when temperature == 0.0:
acc_ids, acc_num, double_input = self.tree_verification(
    input_ids=draft_tree_tokens,     # Draft model's tree candidates
    output_ids=target_predictions,    # Target model's argmax per node
    tree_mask=tree_mask,              # Tree connectivity
    cache_lens=cache_lens,            # Per-sequence KV cache lengths
    non_leaf_len=non_leaf_count,      # Non-leaf node count
)
# acc_num: draft tokens accepted this step, i.e. tokens gained without extra sequential target decoding steps

Stochastic Verification (Inside tree_spec_generate)

# Called internally when temperature > 0.0:
acc_ids, acc_num = self.verify_stochastic(
    input_ids=draft_tree_tokens,
    tree_mask=tree_mask,
    p_llm=target_logits,     # Full logit distribution from target
    p_ssm=draft_logits,      # Full logit distribution from draft
    temperature=0.7,          # Sampling temperature
)
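The min(1, P_target/P_draft) rule used by verify_stochastic is standard speculative sampling. The plain-Python sketch below applies it to a single chain of draft tokens and includes the residual resampling step that the distribution-matching guarantee relies on. It is not the LongSpec code: verify_stochastic operates on batched tree tensors, and how it handles sibling branches is not described on this page. The function names are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def residual(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """norm(max(0, p - q)): the distribution to resample from after a rejection."""
    r = [max(0.0, a - b) for a, b in zip(p, q)]
    s = sum(r)
    return [x / s for x in r] if s > 0 else list(p)


def verify_chain(
    draft: Sequence[int],                  # draft tokens x_1..x_k
    p_target: Sequence[Sequence[float]],   # k + 1 target distributions (last one for the bonus token)
    p_draft: Sequence[Sequence[float]],    # k draft distributions the tokens were sampled from
    rng: random.Random,
) -> Tuple[List[int], int]:
    """Return (emitted tokens, number of accepted draft tokens)."""
    out: List[int] = []
    for i, x in enumerate(draft):
        p, q = p_target[i][x], p_draft[i][x]
        # Accept x with probability min(1, p / q).
        if q > 0 and rng.random() < min(1.0, p / q):
            out.append(x)
            continue
        # Rejected: emit one token from the residual distribution and stop.
        dist = residual(p_target[i], p_draft[i])
        out.append(rng.choices(range(len(dist)), weights=dist)[0])
        return out, i
    # Every draft token accepted: sample a bonus token from the next target distribution.
    last = p_target[len(draft)]
    out.append(rng.choices(range(len(last)), weights=last)[0])
    return out, len(draft)
```

Without the residual step, rejected positions would be biased toward tokens the draft model already over-weights; with it, each emitted token is distributed exactly as the target model would have sampled it.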

Related Pages

Implements Principle

Requires Environment
