
Implementation:Sail sg LongSpec Tree Spec Generate

From Leeroopedia
Domains Speculative_Decoding, LLM_Inference
Last Updated 2026-02-14 05:00 GMT

Overview

Concrete implementation of tree-structured speculative decoding for accelerated LLM inference: a GLIDE draft model generates candidate tokens, and the target LLM verifies them in parallel.
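Verifying a whole tree of candidates in parallel usually relies on a tree attention mask: each draft node may attend only to itself and its ancestors, so every root-to-leaf path is scored as if it were an ordinary sequence. The sketch below builds that mask from a list of parent indices. The function name `tree_attention_mask` and the parent-list encoding are illustrative assumptions, not LongSpec's code. In a real forward pass every node also attends to the full committed prefix in the KV cache; only the tree-to-tree block is shown here.

```python
from typing import List


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """Ancestor-only attention mask for a flattened tree of draft tokens.

    parents[i] is the index of node i's parent, or -1 if node i is a direct
    child of the last committed token. mask[i][j] is True when node i may
    attend to node j, i.e. when j is i itself or one of i's ancestors, so
    each node sees exactly the draft path that leads to it.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up the tree until we pass the root level
            mask[i][j] = True
            j = parents[j]
    return mask


# Two first-level candidates (0, 1); nodes 2 and 3 extend 0, node 4 extends 1.
mask = tree_attention_mask([-1, -1, 0, 0, 1])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

Because siblings never see each other, one forward pass over all nodes yields the same per-node predictions as running each path separately.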

Description

tree_spec_generate is the core inference method on LlamaGlide and Qwen2Glide models. It generates text by:

  1. Building a tree of candidate tokens using the GLIDE draft model
  2. Verifying the entire tree in one target LLM forward pass
  3. Accepting the longest valid path and continuing from there
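For greedy decoding (temperature=0.0), steps 2 and 3 can be sketched as follows: after the single tree forward pass, the target model's argmax prediction is known at every node, and the accepted path is found by repeatedly following the child whose draft token equals the target's prediction. The function `greedy_accept` and its argument layout are hypothetical and only illustrate the idea; the repository's implementation may differ.

```python
from typing import Dict, List, Tuple


def greedy_accept(
    tokens: List[int],
    parents: List[int],
    target_next: List[int],
    root_next: int,
) -> Tuple[List[int], List[int]]:
    """Return (accepted_node_indices, new_tokens) for one greedy verification step.

    tokens[i]: draft token at tree node i
    parents[i]: parent node index, or -1 if node i hangs off the last committed token
    target_next[i]: target model's argmax prediction for the position after node i
    root_next: target model's argmax prediction after the last committed token
    """
    children: Dict[int, List[int]] = {}
    for i, p in enumerate(parents):
        children.setdefault(p, []).append(i)

    path: List[int] = []
    current, expected = -1, root_next
    while True:
        # Follow the child whose draft token matches the target's greedy choice.
        match = next((c for c in children.get(current, []) if tokens[c] == expected), None)
        if match is None:
            break
        path.append(match)
        current, expected = match, target_next[match]
    # Accepted draft tokens plus one extra token predicted by the target itself.
    new_tokens = [tokens[i] for i in path] + [expected]
    return path, new_tokens


print(greedy_accept([10, 11, 20, 21, 30], [-1, -1, 0, 0, 1], [21, 99, 5, 7, 8], 10))
```

Even when no draft token is accepted, each step still commits one token (the target's own prediction), so speculation never produces fewer tokens per target pass than plain decoding.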

The model classes also provide spec_generate (sequential speculation) and vanilla_generate (non-speculative baseline) as alternative generation modes.
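The sequential mode can be illustrated with a toy greedy loop: draft gamma tokens with a cheap model, keep the longest prefix the target agrees with, then append the target's own token. The sketch below is a minimal illustration under stated assumptions: `greedy_spec_generate`, `draft_next`, and `target_next` are hypothetical names, and the two "models" are plain functions returning an argmax token. It is not the repository's spec_generate.

```python
from typing import Callable, List

# Maps a token sequence to a model's argmax next token (toy stand-in for a model).
NextToken = Callable[[List[int]], int]


def greedy_spec_generate(
    prefix: List[int],
    draft_next: NextToken,
    target_next: NextToken,
    gamma: int = 4,
    max_gen_len: int = 64,
    eos_id: int = -1,
) -> List[int]:
    seq = list(prefix)
    generated = 0
    while generated < max_gen_len:
        # 1. Draft gamma tokens autoregressively with the cheap model.
        draft: List[int] = []
        for _ in range(gamma):
            draft.append(draft_next(seq + draft))
        # 2. Verify: keep the longest draft prefix the target agrees with.
        #    (A real implementation scores all gamma positions in ONE target
        #    forward pass; here the target is queried per position for clarity.)
        accepted: List[int] = []
        for tok in draft:
            want = target_next(seq + accepted)
            if tok != want:
                break
            accepted.append(tok)
        else:
            # All gamma drafts matched: the target still yields one more token.
            want = target_next(seq + accepted)
        accepted.append(want)  # correction or bonus token from the target
        # 3. Commit, stopping at EOS or max_gen_len.
        for tok in accepted:
            seq.append(tok)
            generated += 1
            if tok == eos_id or generated >= max_gen_len:
                return seq[len(prefix):]
    return seq[len(prefix):]


def target(seq: List[int]) -> int:
    return (seq[-1] + 1) % 10


def draft(seq: List[int]) -> int:
    # Agrees with the target except after multiples of 3.
    return target(seq) if seq[-1] % 3 else 7


print(greedy_spec_generate([0], draft, target, gamma=4, max_gen_len=12))
```

With greedy verification the output is identical to decoding with the target alone; the draft model only changes how many target passes are needed.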

Usage

Use when performing inference with a trained GLIDE model. Select the method based on performance requirements:

  • tree_spec_generate: Best throughput for long contexts (the default choice)
  • spec_generate: Simpler sequential speculation for shorter contexts
  • vanilla_generate: Baseline for comparison benchmarks

Code Reference

Source Location

  • Repository: LongSpec
  • File (Llama): longspec/test/llama_glide.py
  • Lines (tree_spec_generate): L915-1126
  • Lines (spec_generate): L621-774
  • Lines (vanilla_generate): L552-585
  • File (Qwen2): longspec/test/qwen2_glide.py
  • Lines (tree_spec_generate): L744-955
  • Lines (spec_generate): L589-742

Signature

def tree_spec_generate(
    self,
    input_ids: torch.Tensor,
    prompt_length: int,
    tree_shape: Optional[List[int]] = None,
    max_gen_len: int = 64,
    eos_id: int = 151645,
    temperature: float = 0.0,
) -> Tuple[torch.Tensor, int, int, float, torch.Tensor]:
    """
    Generate text using tree-structured speculative decoding.

    Args:
        input_ids: Tokenized input (batch_size, seq_len)
        prompt_length: Length of input prompt
        tree_shape: Branching factors per level (default: [4, 16, 16, 16, 16])
        max_gen_len: Maximum tokens to generate (default: 64)
        eos_id: End-of-sequence token ID
        temperature: 0.0 for greedy, >0 for stochastic verification

    Returns:
        output_ids: Generated token sequence (batch_size, max_gen_len)
        count: Number of draft tokens accepted
        num: Total target model forward tokens processed
        elapsed_time: Wall-clock generation time (seconds)
        spec_mask: Speculative decoding mask for analysis
    """

def spec_generate(
    self,
    input_ids: torch.Tensor,
    prompt_length: int,
    gamma: int = 4,
    max_gen_len: int = 64,
    eos_id: int = 151645,
    temperature: float = 0.0,
) -> Tuple[torch.Tensor, int, int, float, torch.Tensor]:
    """
    Sequential speculative decoding (gamma tokens per step).

    Args:
        input_ids: Tokenized input
        prompt_length: Prompt length
        gamma: Number of speculative tokens per step (default: 4)
        max_gen_len: Max generation length
        eos_id: EOS token ID
        temperature: Sampling temperature

    Returns:
        Same tuple as tree_spec_generate
    """

def vanilla_generate(
    self,
    input_ids: torch.Tensor,
    prompt_length: int,
    max_gen_len: int = 64,
    eos_id: int = 151645,
) -> Tuple[torch.Tensor, int, float]:
    """
    Baseline autoregressive generation (no speculation).

    Args:
        input_ids: Tokenized input
        prompt_length: Prompt length
        max_gen_len: Max generation length
        eos_id: EOS token ID

    Returns:
        output_ids: Generated tokens
        num_tokens: Total tokens processed
        elapsed_time: Generation time
    """

Import

# Llama variant:
from longspec.test.llama_glide import LlamaGlide

# Qwen2 variant:
from longspec.test.qwen2_glide import Qwen2Glide

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Tokenized prompt (batch_size, seq_len) on CUDA |
| prompt_length | int | Yes | Number of prompt tokens (generation starts after this position) |
| tree_shape | List[int] | No | Branching factors per tree level (default: [4, 16, 16, 16, 16]) |
| max_gen_len | int | No | Maximum tokens to generate (default: 64) |
| eos_id | int | No | End-of-sequence token ID (default: 151645, a Qwen2 token ID; for Llama/Vicuna models pass the tokenizer's EOS ID, e.g. tokenizer.eos_token_id) |
| temperature | float | No | 0.0 = greedy, >0.0 = stochastic verification (default: 0.0) |
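For temperature > 0.0 the page only states that verification becomes stochastic. A common rule for this is the standard speculative-sampling test: accept a draft token x with probability min(1, p(x)/q(x)), where p and q are the temperature-scaled target and draft distributions, and on rejection resample from the normalized residual max(0, p - q). The sketch below shows that test for a single draft token, assuming LongSpec follows this rule; tree variants need an extension for several sibling candidates, which this sketch does not cover. `verify_one` is a hypothetical name.

```python
import random
from typing import List, Tuple


def verify_one(x: int, p: List[float], q: List[float], rng: random.Random) -> Tuple[bool, int]:
    """Speculative-sampling check for a single draft token.

    x: token drawn from the draft distribution q (so q[x] > 0)
    p, q: target and draft next-token probabilities, already temperature-scaled
    Returns (accepted, token_to_emit).
    """
    # Accept with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p[x] / q[x]):
        return True, x
    # On rejection, resample from the residual distribution norm(max(0, p - q)).
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return False, rng.choices(range(len(p)), weights=residual, k=1)[0]


# Emitted tokens follow the target distribution p even though x comes from q.
rng = random.Random(0)
p, q = [0.5, 0.3, 0.2], [0.2, 0.2, 0.6]
counts = [0, 0, 0]
for _ in range(100_000):
    x = rng.choices(range(3), weights=q, k=1)[0]
    counts[verify_one(x, p, q, rng)[1]] += 1
print([c / 100_000 for c in counts])
```

The acceptance step contributes min(p, q) probability mass per token and the resampling step contributes max(0, p - q), which sum to p, so the output distribution matches the target model exactly.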

Outputs

| Name | Type | Description |
|---|---|---|
| output_ids | torch.Tensor | Generated token IDs (batch_size, max_gen_len) |
| count | int | Number of draft tokens accepted by target model |
| num | int | Total tokens processed by target model (for throughput calculation) |
| elapsed_time | float | Wall-clock time in seconds |
| spec_mask | torch.Tensor | Binary mask showing which positions used speculation vs. target-only |

Usage Examples

Tree Speculative Decoding

from transformers import AutoConfig, AutoTokenizer
from longspec.test.llama_glide import LlamaGlide

# Load model
config = AutoConfig.from_pretrained("lmsys/vicuna-7b-v1.5")
model = LlamaGlide(config, "lmsys/vicuna-7b-v1.5", "sail/longspec-vicuna-7b-v1.5")
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

# Tokenize input
prompt = "Summarize the following document: ..."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

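# Tree speculative decoding. eos_id must match the tokenizer: the default
# (151645) is a Qwen2 token ID, so pass the Vicuna/Llama EOS ID explicitly.
output_ids, count, num, elapsed, mask = model.tree_spec_generate(
    input_ids=input_ids,
    prompt_length=input_ids.shape[1],
    tree_shape=[4, 16, 16, 16, 16],
    max_gen_len=1024,
    eos_id=tokenizer.eos_token_id,
    temperature=0.0,
)

# Decode output
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(text)

# num counts tokens processed by the target model (see Outputs), so this is
# target-side throughput rather than generated tokens per second.
throughput = num / elapsed
print(f"Target model processed {num} tokens in {elapsed:.2f}s ({throughput:.1f} tok/s)")
print(f"Accepted draft tokens: {count}/{num} = {count / num:.2%}")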

Sequential Speculative Decoding

# Sequential mode with gamma=4 draft tokens per step
# (reuses model, tokenizer and input_ids from the example above)
output_ids, count, num, elapsed, mask = model.spec_generate(
    input_ids=input_ids,
    prompt_length=input_ids.shape[1],
    gamma=4,
    max_gen_len=1024,
    eos_id=tokenizer.eos_token_id,
)
