Implementation:Sail sg LongSpec Tree Spec Generate
| Knowledge Sources | |
|---|---|
| Domains | Speculative_Decoding, LLM_Inference |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
A concrete tool for accelerating LLM inference with tree-structured speculative decoding: a GLIDE draft model generates candidate tokens, and the target LLM verifies them in parallel.
Description
tree_spec_generate is the core inference method on LlamaGlide and Qwen2Glide models. It generates text by:
- Building a tree of candidate tokens using the GLIDE draft model
- Verifying the entire tree in one target LLM forward pass
- Accepting the longest valid path and continuing from there
The same model classes also provide spec_generate (sequential speculation) and vanilla_generate (baseline autoregressive decoding) as alternative generation modes.
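The three steps above can be illustrated with a toy sketch of greedy tree verification. This is not the LongSpec implementation: the function name `longest_accepted_path`, the flat `tokens`/`parents` tree encoding, and the `target_next` lookup table are all hypothetical, and the target model's predictions are supplied as a plain dictionary rather than computed in a batched forward pass.

```python
from typing import Dict, List, Optional, Tuple


def longest_accepted_path(
    tokens: List[int],
    parents: List[Optional[int]],
    target_next: Dict[Optional[int], int],
) -> Tuple[List[int], int]:
    """Return the accepted draft tokens plus the target's next token.

    tokens[i]      : draft token stored at tree node i
    parents[i]     : index of node i's parent, or None if it hangs off the prompt
    target_next[p] : token the target picks greedily after node p
                     (key None means "right after the prompt"); a real system
                     would obtain all of these from one pass over the tree
    """
    accepted: List[int] = []
    current: Optional[int] = None  # start at the prompt (virtual root)
    while True:
        wanted = target_next[current]
        # Look for a child of `current` whose draft token matches the target.
        match = next(
            (
                i
                for i, (tok, par) in enumerate(zip(tokens, parents))
                if par == current and tok == wanted
            ),
            None,
        )
        if match is None:
            # No draft child agrees: stop and emit the target's own token.
            return accepted, wanted
        accepted.append(tokens[match])
        current = match


# Toy tree: two candidates after the prompt, two under node 0, one under node 3.
toy_tokens = [5, 7, 9, 3, 8]
toy_parents = [None, None, 0, 0, 3]
toy_target = {None: 5, 0: 3, 1: 1, 2: 2, 3: 8, 4: 6}
print(longest_accepted_path(toy_tokens, toy_parents, toy_target))
```

In this toy case the path through nodes 0, 3, and 4 is accepted, and the target's own prediction after node 4 is returned as the next token to continue from.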
Usage
Use when performing inference with a trained GLIDE model. Select the method based on performance requirements:
- tree_spec_generate: Best throughput for long contexts (default)
- spec_generate: Simpler sequential speculation for shorter contexts
- vanilla_generate: Baseline for comparison benchmarks
Code Reference
Source Location
- Repository: LongSpec
- File (Llama): longspec/test/llama_glide.py
- Lines (tree_spec_generate): L915-1126
- Lines (spec_generate): L621-774
- Lines (vanilla_generate): L552-585
- File (Qwen2): longspec/test/qwen2_glide.py
- Lines (tree_spec_generate): L744-955
- Lines (spec_generate): L589-742
Signature
def tree_spec_generate(
    self,
    input_ids: torch.Tensor,
    prompt_length: int,
    tree_shape: Optional[List[int]] = None,
    max_gen_len: int = 64,
    eos_id: int = 151645,
    temperature: float = 0.0,
) -> Tuple[torch.Tensor, int, int, float, torch.Tensor]:
    """
    Generate text using tree-structured speculative decoding.

    Args:
        input_ids: Tokenized input (batch_size, seq_len)
        prompt_length: Length of input prompt
        tree_shape: Branching factors per level (default: [4, 16, 16, 16, 16])
        max_gen_len: Maximum tokens to generate (default: 64)
        eos_id: End-of-sequence token ID
        temperature: 0.0 for greedy, >0 for stochastic verification

    Returns:
        output_ids: Generated token sequence (batch_size, max_gen_len)
        count: Number of draft tokens accepted
        num: Total target model forward tokens processed
        elapsed_time: Wall-clock generation time (seconds)
        spec_mask: Speculative decoding mask for analysis
    """

def spec_generate(
    self,
    input_ids: torch.Tensor,
    prompt_length: int,
    gamma: int = 4,
    max_gen_len: int = 64,
    eos_id: int = 151645,
    temperature: float = 0.0,
) -> Tuple[torch.Tensor, int, int, float, torch.Tensor]:
    """
    Sequential speculative decoding (gamma tokens per step).

    Args:
        input_ids: Tokenized input
        prompt_length: Prompt length
        gamma: Number of speculative tokens per step (default: 4)
        max_gen_len: Max generation length
        eos_id: EOS token ID
        temperature: Sampling temperature

    Returns:
        Same tuple as tree_spec_generate
    """

def vanilla_generate(
    self,
    input_ids: torch.Tensor,
    prompt_length: int,
    max_gen_len: int = 64,
    eos_id: int = 151645,
) -> Tuple[torch.Tensor, int, float]:
    """
    Baseline autoregressive generation (no speculation).

    Args:
        input_ids: Tokenized input
        prompt_length: Prompt length
        max_gen_len: Max generation length
        eos_id: EOS token ID

    Returns:
        output_ids: Generated tokens
        num_tokens: Total tokens processed
        elapsed_time: Generation time
    """
Import
# Llama variant:
from longspec.test.llama_glide import LlamaGlide
# Qwen2 variant:
from longspec.test.qwen2_glide import Qwen2Glide
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Tokenized prompt (batch_size, seq_len) on CUDA |
| prompt_length | int | Yes | Number of prompt tokens (generation starts after this position) |
| tree_shape | List[int] | No | Branching factors per tree level (default: [4, 16, 16, 16, 16]) |
| max_gen_len | int | No | Maximum tokens to generate (default: 64) |
| eos_id | int | No | End-of-sequence token ID (default: 151645 for Qwen2) |
| temperature | float | No | 0.0 = greedy, >0.0 = stochastic verification (default: 0.0) |
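The difference between greedy and stochastic verification in the temperature row can be sketched for a single draft token. The source does not state which acceptance rule LongSpec applies when temperature > 0, so the rule below is an assumption: it is the widely used speculative sampling test that accepts a draft token with probability min(1, p/q). The function `verify_token` and its arguments are hypothetical, and the probability lists are assumed to be already temperature-scaled.

```python
import random
from typing import List, Optional


def verify_token(
    draft_token: int,
    p_target: List[float],
    q_draft: List[float],
    temperature: float,
    rng: Optional[random.Random] = None,
) -> bool:
    """Decide whether the target model accepts one draft token.

    p_target / q_draft are probability distributions over the vocabulary
    from the target and draft models at the same position.
    """
    if temperature == 0.0:
        # Greedy: accept only if the draft token is the target's argmax.
        target_choice = max(range(len(p_target)), key=p_target.__getitem__)
        return draft_token == target_choice
    # Stochastic (assumed rule): accept with probability min(1, p / q).
    q = q_draft[draft_token]
    if q <= 0.0:
        return False  # the draft could not have proposed this token
    rng = rng or random.Random()
    accept_prob = min(1.0, p_target[draft_token] / q)
    return rng.random() < accept_prob


p = [0.2, 0.7, 0.1]  # target distribution
q = [0.5, 0.4, 0.1]  # draft distribution
print(verify_token(1, p, q, temperature=0.0))  # token 1 is the target argmax
print(verify_token(0, p, q, temperature=0.0))  # token 0 is not
```

Under the assumed stochastic rule, a token the target rates at least as likely as the draft does (p >= q) is always accepted, while an over-proposed token is accepted only some of the time.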
Outputs
| Name | Type | Description |
|---|---|---|
| output_ids | torch.Tensor | Generated token IDs (batch_size, max_gen_len) |
| count | int | Number of draft tokens accepted by target model |
| num | int | Total tokens processed by target model (for throughput calculation) |
| elapsed_time | float | Wall-clock time in seconds |
| spec_mask | torch.Tensor | Binary mask showing which positions used speculation vs. target-only |
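The scalar outputs above are typically combined into throughput and acceptance figures. The helper below is a small sketch of that arithmetic; `summarize_run` is a hypothetical name, and defining the acceptance rate as count / num follows the usage example on this page rather than any definition in the LongSpec code.

```python
from typing import Dict


def summarize_run(count: int, num: int, elapsed_time: float) -> Dict[str, float]:
    """Derive throughput and acceptance rate from the returned scalars.

    Guards against division by zero when nothing was generated or timed.
    """
    throughput = num / elapsed_time if elapsed_time > 0 else 0.0
    acceptance = count / num if num > 0 else 0.0
    return {"tokens_per_second": throughput, "acceptance_rate": acceptance}


stats = summarize_run(count=30, num=60, elapsed_time=2.0)
print(f"{stats['tokens_per_second']:.1f} tok/s, "
      f"acceptance {stats['acceptance_rate']:.2%}")
```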
Usage Examples
Tree Speculative Decoding
from transformers import AutoConfig, AutoTokenizer
from longspec.test.llama_glide import LlamaGlide

# Load the target model config, the GLIDE-augmented model, and the tokenizer
config = AutoConfig.from_pretrained("lmsys/vicuna-7b-v1.5")
model = LlamaGlide(config, "lmsys/vicuna-7b-v1.5", "sail/longspec-vicuna-7b-v1.5")
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

# Tokenize input and move it to the GPU
prompt = "Summarize the following document: ..."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

# Tree speculative decoding
output_ids, count, num, elapsed, mask = model.tree_spec_generate(
    input_ids=input_ids,
    prompt_length=input_ids.shape[1],
    tree_shape=[4, 16, 16, 16, 16],
    max_gen_len=1024,
    # The documented default (151645) is the Qwen2 EOS ID, so pass the
    # tokenizer's own EOS ID explicitly when using a Llama-family model.
    eos_id=tokenizer.eos_token_id,
    temperature=0.0,
)

# Decode output and report throughput and acceptance
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
throughput = num / elapsed
print(f"Generated {num} tokens at {throughput:.1f} tok/s")
print(f"Acceptance rate: {count}/{num} = {count/num:.2%}")
Sequential Speculative Decoding
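# Sequential mode with gamma=4 draft tokens per step
# (reuses model, tokenizer, and input_ids from the example above)
output_ids, count, num, elapsed, mask = model.spec_generate(
    input_ids=input_ids,
    prompt_length=input_ids.shape[1],
    gamma=4,
    max_gen_len=1024,
    eos_id=tokenizer.eos_token_id,  # default 151645 is the Qwen2 EOS ID
)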
Related Pages
Implements Principle
Requires Environment
Uses Heuristic