Workflow:Sail sg LongSpec Speculative Decoding Inference
| Knowledge Sources | |
|---|---|
| Domains | LLMs, Speculative_Decoding, Inference_Optimization |
| Last Updated | 2025-02-01 00:00 GMT |
Overview
End-to-end inference pipeline for accelerating large language model generation using LongSpec's GLIDE draft model with tree-structured speculative decoding and hybrid attention verification.
Description
This workflow performs accelerated inference by pairing a target LLM (Llama or Qwen2 family) with a trained GLIDE draft model. The draft model proposes candidate tokens using tree-structured beam search, and the target model verifies them in parallel via a custom Triton tree attention kernel. The process is lossless: the output distribution matches standard autoregressive decoding exactly.
The system supports multiple inference methods:
- Vanilla: Standard autoregressive decoding (baseline)
- Sequential (seq): GLIDE draft with greedy sequential speculation
- Tree: GLIDE draft with tree-structured beam search expansion (recommended)
- MagicDec: Integration with the MagicDec framework
The tree method delivers the highest throughput by generating multiple candidate continuations per step, then verifying the entire tree in a single target model forward pass using the Triton tree attention kernel combined with prefix attention merging.
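The draft, verify and accept cycle, and why it is lossless, can be shown with a minimal sketch of the sequential (seq) method under greedy decoding. It assumes toy "models" that are plain Python functions returning the greedy next token for a prefix. The names `vanilla_decode` and `speculative_decode` are hypothetical and not part of the LongSpec code. A real implementation scores all drafted positions in one batched target forward pass instead of calling the target once per position.

```python
from typing import Callable, List

# Toy model interface (hypothetical): token prefix -> greedy next token.
NextToken = Callable[[List[int]], int]


def vanilla_decode(target: NextToken, prompt: List[int], max_new: int) -> List[int]:
    """Baseline autoregressive decoding: one target call per generated token."""
    out = list(prompt)
    for _ in range(max_new):
        out.append(target(out))
    return out[len(prompt):]


def speculative_decode(target: NextToken, draft: NextToken, prompt: List[int],
                       max_new: int, gamma: int = 4) -> List[int]:
    """Greedy draft-then-verify loop (sequential speculation with gamma drafts)."""
    out = list(prompt)
    generated = 0
    while generated < max_new:
        # 1) Draft gamma tokens autoregressively with the cheap model.
        proposal: List[int] = []
        ctx = list(out)
        for _ in range(gamma):
            tok = draft(ctx)
            proposal.append(tok)
            ctx.append(tok)
        # 2) Verify each drafted position against the target's greedy choice.
        accepted: List[int] = []
        for tok in proposal:
            expected = target(out + accepted)
            if tok != expected:
                accepted.append(expected)  # target's correction ends the round
                break
            accepted.append(tok)
        else:
            accepted.append(target(out + accepted))  # bonus token: all drafts matched
        # 3) Commit the round without overshooting max_new.
        accepted = accepted[: max_new - generated]
        out.extend(accepted)
        generated += len(accepted)
    return out[len(prompt):]
```

Every committed token equals what the target would have chosen for that prefix, so both functions return identical sequences; the draft only changes how many target passes a real system needs.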
Usage
Execute this workflow when you have a trained LongSpec GLIDE draft model and need to generate text from the target LLM with reduced latency. This is particularly effective for long-context scenarios (16k-262k tokens) where the constant-size KV cache of the GLIDE draft model provides significant memory savings over traditional draft models. An 80GB GPU is recommended for optimal performance.
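To make the memory argument concrete, here is a back-of-the-envelope KV-cache calculation. The target shape is the standard Llama-3-8B attention configuration (32 layers, 8 KV heads, head dimension 128, 2-byte fp16/bf16 elements). The one-layer draft models and the 512-token window are hypothetical values chosen only for comparison; they are not LongSpec's actual draft configuration. The GLIDE draft reads the target's cache through cross-attention rather than holding a copy, so only its sliding-window cache adds memory.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for keys plus values (hence the factor 2) across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


GIB = 1024 ** 3

# Llama-3-8B attention shape: 32 layers, 8 KV heads (GQA), head_dim 128.
target_kv = kv_cache_bytes(32, 8, 128, seq_len=262_144)
print(f"target KV cache @262k tokens: {target_kv / GIB:.1f} GiB")  # 32.0 GiB

# Hypothetical 1-layer draft models with the same head layout:
# a conventional draft caches the whole context ...
full_draft_kv = kv_cache_bytes(1, 8, 128, seq_len=262_144)
# ... while a sliding-window draft only caches the last WINDOW tokens.
WINDOW = 512  # hypothetical window size
window_draft_kv = kv_cache_bytes(1, 8, 128, seq_len=WINDOW)
print(f"full-context draft: {full_draft_kv / GIB:.2f} GiB, "
      f"sliding-window draft: {window_draft_kv / 1024 ** 2:.2f} MiB")
# full-context draft: 1.00 GiB, sliding-window draft: 2.00 MiB
```

The conventional draft's cache grows linearly with context length, while the sliding-window cache stays fixed no matter how long the prompt becomes.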
Execution Steps
Step 1: Model_Loading
Load both the target LLM and its corresponding GLIDE draft model. The target model configuration is fetched from Hugging Face (e.g., Qwen/QwQ-32B-Preview or gradientai/Llama-3-8B-Instruct-262k). The GLIDE model class (LlamaGlide or Qwen2Glide) initializes the combined architecture, loading the target model weights and the trained draft model weights from the LongSpec Hugging Face repository (e.g., sail/longspec-QwQ-32B-Preview).
Key considerations:
- Configure pad_token_id and eos_token_id appropriately for each model family
- Both models are loaded onto the GPU; a single 80GB GPU is recommended
- The tokenizer is loaded from the target model path
- The draft model weights are loaded from the corresponding sail/longspec-* checkpoint
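The loading decisions above (which GLIDE class, which draft checkpoint, which tokenizer, how to set special token ids) can be captured in a small planning helper. This is a sketch under assumptions: `make_load_plan`, `LoadPlan` and `resolve_special_ids` are hypothetical names, the `sail/longspec-<model name>` naming rule is inferred from the single example in the text, and reusing EOS as the pad token is a common convention rather than a documented LongSpec setting. The actual GLIDE constructors are not shown because their signatures are not given here.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class LoadPlan:
    target_path: str     # Hugging Face id of the target LLM
    draft_path: str      # Hugging Face id of the trained GLIDE draft weights
    glide_class: str     # which GLIDE wrapper to instantiate
    tokenizer_path: str  # the tokenizer comes from the target model


def make_load_plan(target_path: str, draft_path: Optional[str] = None) -> LoadPlan:
    name = target_path.lower()
    if "qwen" in name or "qwq" in name:
        glide_class = "Qwen2Glide"
    elif "llama" in name:
        glide_class = "LlamaGlide"
    else:
        raise ValueError(f"unsupported model family: {target_path}")
    if draft_path is None:
        # Hypothetical naming rule; pass draft_path explicitly if it differs.
        draft_path = "sail/longspec-" + target_path.split("/")[-1]
    return LoadPlan(target_path, draft_path, glide_class, tokenizer_path=target_path)


def resolve_special_ids(pad_token_id: Optional[int],
                        eos_token_id: Optional[int]) -> Tuple[int, int]:
    if eos_token_id is None:
        raise ValueError("eos_token_id must be set so generation can stop")
    # Assumed convention: reuse EOS as the pad token when none is defined.
    pad = eos_token_id if pad_token_id is None else pad_token_id
    return pad, eos_token_id
```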
Step 2: Input_Preparation
Format the input prompt according to the model's expected chat template. For Llama models, prompts use the system/user/assistant format. For Qwen2 models, prompts use the <|im_start|>/<|im_end|> format. The input is tokenized and transferred to the GPU. An attention mask is computed to determine the effective prompt length.
Key considerations:
- Different model families require different prompt templates
- Input token length must be within the model's context window (minus generation headroom)
- Prompts shorter than 1200 tokens are typically filtered out for benchmarking purposes
- The tokenizer padding side should be set to right
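A sketch of the two templates and the length checks follows. The Qwen2 branch is the ChatML layout; the Llama branch assumes the Llama-3-Instruct header format, which matches the `gradientai/Llama-3-8B-Instruct-262k` example. The function names and the `min_len=1200` default are illustrative. In practice, `tokenizer.apply_chat_template` from Hugging Face transformers renders the template shipped with the checkpoint and is less error-prone than hand-built strings.

```python
from typing import List, Sequence


def format_prompt(family: str, user_msg: str,
                  system_msg: str = "You are a helpful assistant.") -> str:
    if family == "qwen2":
        # ChatML-style template used by Qwen2-family chat models.
        return (f"<|im_start|>system\n{system_msg}<|im_end|>\n"
                f"<|im_start|>user\n{user_msg}<|im_end|>\n"
                "<|im_start|>assistant\n")
    if family == "llama":
        # Llama-3-Instruct headers. The string already contains BOS, so tokenize
        # it with add_special_tokens=False to avoid a duplicated BOS token.
        return ("<|begin_of_text|>"
                f"<|start_header_id|>system<|end_header_id|>\n\n{system_msg}<|eot_id|>"
                f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n")
    raise ValueError(f"unknown model family: {family}")


def effective_length(attention_mask: Sequence[int]) -> int:
    # With right padding, real tokens are the 1s at the start of the mask.
    return sum(attention_mask)


def keep_for_benchmark(attention_mask: Sequence[int], context_window: int,
                       max_gen_len: int, min_len: int = 1200) -> bool:
    """Drop short prompts and prompts that leave no room for generation."""
    n = effective_length(attention_mask)
    return min_len <= n <= context_window - max_gen_len
```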
Step 3: Inference_Method_Selection
Choose the speculative decoding method based on the use case. The tree method is the default and recommended choice, offering the best throughput. Configure the tree shape parameter (e.g., [4, 16, 16, 16, 16]), which controls the branching factor at each depth level of the speculation tree. The gamma parameter controls the draft sequence length for the sequential method.
Key considerations:
- Tree method: best throughput, uses tree_shape to define branch structure (default [4, 16, 16, 16, 16])
- Sequential method: simpler, uses gamma parameter for draft length
- Temperature=0 for greedy decoding; temperature>0 for sampling
- A warm-up generation pass is performed before benchmarking to initialize CUDA kernels
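The options above can be gathered into a validated configuration object, with a timing helper that runs warm-up passes before measuring. This is a sketch: `SpecConfig`, `tokens_per_second`, the method strings (mirroring the method list in the Description) and the `gamma=4` default are assumptions rather than LongSpec's actual argument names. On a GPU, CUDA work is asynchronous, so `torch.cuda.synchronize()` should be called before each clock read.

```python
import time
from dataclasses import dataclass, field
from typing import Callable, List, Sequence

# Hypothetical method identifiers mirroring the Description's method list.
METHODS = ("vanilla", "seq", "tree", "magicdec")


@dataclass
class SpecConfig:
    method: str = "tree"
    tree_shape: List[int] = field(default_factory=lambda: [4, 16, 16, 16, 16])
    gamma: int = 4            # draft length for "seq" (hypothetical default)
    temperature: float = 0.0  # 0 means greedy decoding

    def __post_init__(self) -> None:
        if self.method not in METHODS:
            raise ValueError(f"unknown method {self.method!r}")
        if self.method == "tree" and (not self.tree_shape or min(self.tree_shape) < 1):
            raise ValueError("tree_shape needs at least one level, each >= 1")
        if self.method == "seq" and self.gamma < 1:
            raise ValueError("gamma must be >= 1")
        if self.temperature < 0:
            raise ValueError("temperature must be >= 0")

    @property
    def greedy(self) -> bool:
        return self.temperature == 0.0


def tokens_per_second(generate: Callable[[Sequence[int]], List[int]],
                      prompt: Sequence[int], warmup: int = 1) -> float:
    for _ in range(warmup):
        generate(prompt)  # early calls pay one-off kernel compilation costs
    # On a GPU, synchronize here and after generate() so queued kernels count.
    start = time.perf_counter()
    out = generate(prompt)
    elapsed = time.perf_counter() - start
    return len(out) / max(elapsed, 1e-9)
```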
Step 4: Draft_Generation
The GLIDE draft model generates candidate tokens using its cross-attention mechanism over the target LLM's KV cache. For tree speculation, beam search expands multiple paths simultaneously, producing a tree of candidate token sequences. The draft model maintains only a constant-size sliding-window KV cache regardless of context length, enabling efficient drafting even for very long inputs.
Key considerations:
- Cross-attention accesses the target LLM's full KV cache without storing a copy
- Sliding-window self-attention provides local context awareness
- Tree expansion uses top-k sampling at each level to create candidate branches
- The tree structure is defined by the tree_shape parameter
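A minimal beam-style tree expansion is sketched below. It assumes a hypothetical draft interface returning log-probabilities for a prefix, and it reads `tree_shape[d]` as the number of nodes kept at depth `d`: each frontier node proposes its top-k children, and the level is pruned back by cumulative log-probability. This interpretation of `tree_shape`, and the deterministic top-k selection (rather than sampling the children), are assumptions, not LongSpec's exact procedure. The `parents` array it returns is the representation needed to build the tree attention mask.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical draft interface: token prefix -> {token: log-probability}.
DraftFn = Callable[[List[int]], Dict[int, float]]


def expand_tree(draft: DraftFn, prefix: List[int],
                tree_shape: List[int]) -> Tuple[List[int], List[int], List[int]]:
    """Beam-style tree drafting.

    Returns parallel lists (tokens, parents, depths); parent -1 means the node
    hangs directly off the committed prefix.
    """
    tokens: List[int] = []
    parents: List[int] = []
    depths: List[int] = []
    frontier = [(-1, [], 0.0)]  # (node index, path tokens, cumulative log-prob)
    for depth, width in enumerate(tree_shape):
        candidates = []
        for node, path, score in frontier:
            dist = draft(prefix + path)
            # top-k children of this node, with k = width of this level
            top = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)[:width]
            for tok, logp in top:
                candidates.append((score + logp, node, path + [tok]))
        # keep only the `width` most likely paths across the whole level
        candidates.sort(key=lambda c: c[0], reverse=True)
        frontier = []
        for score, parent, path in candidates[:width]:
            tokens.append(path[-1])
            parents.append(parent)
            depths.append(depth)
            frontier.append((len(tokens) - 1, path, score))
    return tokens, parents, depths
```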
Step 5: Tree_Verification
The target LLM verifies the entire candidate tree in a single forward pass. This uses a custom Triton tree attention kernel that handles the non-causal attention masks required by the tree structure. The kernel combines prefix attention (over the cached context, i.e. the prompt plus all previously accepted tokens) with tree attention (over the candidate tokens) using sigmoid-weighted log-sum-exp merging, enabling efficient parallel verification.
Key considerations:
- The Triton kernel handles tree-shaped attention masks that Flash Attention cannot support
- Prefix attention and tree attention are merged via an exact sigmoid-weighted combination of their outputs, with the weights derived from each part's log-sum-exp rather than learned
- GQA (Grouped Query Attention) is supported natively in the kernel
- Online softmax (a running max and normalizer) lets the kernel process keys block by block while staying numerically stable
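The math behind the hybrid attention can be checked with a scalar, single-query sketch: build the ancestor mask from a `parents` array, compute each part with online softmax, then merge the two parts using the sigmoid of their log-sum-exp difference. This is an illustration of the technique in plain Python, not the Triton kernel; GQA, batching and blocking are omitted, and all function names are hypothetical. The merge equals softmax attention over the concatenated keys because `sigmoid(lse_a - lse_b) = exp(lse_a) / (exp(lse_a) + exp(lse_b))`.

```python
import math
from typing import List, Sequence, Tuple

Vec = List[float]


def ancestor_mask(parents: Sequence[int]) -> List[List[bool]]:
    """mask[i][j] is True if tree node i may attend to node j (j is i or an ancestor)."""
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:
            mask[i][j] = True
            j = parents[j]
    return mask


def partial_attention(q: Vec, keys: Sequence[Vec], values: Sequence[Vec],
                      allowed: Sequence[bool]) -> Tuple[Vec, float]:
    """Online-softmax attention over allowed keys; returns (output, log-sum-exp).

    Requires at least one allowed key.
    """
    scale = 1.0 / math.sqrt(len(q))
    m, l = -math.inf, 0.0         # running max and running normalizer
    acc = [0.0] * len(values[0])  # running unnormalized output
    for k, v, ok in zip(keys, values, allowed):
        if not ok:
            continue
        s = scale * sum(a * b for a, b in zip(q, k))
        m_new = max(m, s)
        corr = math.exp(m - m_new)  # rescales everything accumulated so far
        p = math.exp(s - m_new)
        l = l * corr + p
        acc = [a * corr + p * x for a, x in zip(acc, v)]
        m = m_new
    return [a / l for a in acc], m + math.log(l)


def sigmoid(x: float) -> float:
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def merge(out_prefix: Vec, lse_prefix: float, out_tree: Vec, lse_tree: float) -> Vec:
    """Exact combination of softmax attention over two disjoint key sets."""
    w = sigmoid(lse_prefix - lse_tree)
    return [w * a + (1.0 - w) * b for a, b in zip(out_prefix, out_tree)]
```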
Step 6: Token_Acceptance
Compare the target model's verified probabilities against the draft model's proposals. For greedy decoding (temperature=0), a draft token is accepted if it matches the target model's argmax token at that position. For sampling, the standard speculative sampling rejection criterion is applied. Accepted tokens are appended to the output, and the process resumes from the last accepted position. The acceptance rate (the average number of tokens accepted per verification round) measures drafting quality.
Key considerations:
- Lossless guarantee: output distribution is identical to standard autoregressive decoding
- Higher acceptance rates indicate better draft model quality
- The process loops Steps 4-6 until max_gen_len tokens are produced or EOS is reached
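- Throughput is measured as total generated tokens per second, counting both the accepted draft tokens and the token the target model emits at the end of each verification round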
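Both acceptance rules can be sketched in a few lines. For greedy decoding over a tree, the longest path whose tokens match the target's argmax choices is accepted, then the target's own next token is appended. For sampling, the standard speculative-sampling test accepts a draft token `x ~ q` with probability `min(1, p[x] / q[x])` and otherwise resamples from the normalized residual `max(0, p - q)`. The function names and the `(tokens, parents)` tree encoding are hypothetical, and the sampling function covers a single position rather than a whole tree.

```python
import random
from typing import Dict, List, Sequence, Tuple


def accept_greedy_tree(tokens: Sequence[int], parents: Sequence[int],
                       root_next: int, target_next: Sequence[int]) -> List[int]:
    """Longest tree path matching the target's greedy choices, plus one target token.

    root_next: target argmax right after the committed prefix.
    target_next[i]: target argmax after the path ending at tree node i
    (all of these come from the single verification forward pass).
    """
    children: Dict[int, List[int]] = {}
    for i, p in enumerate(parents):
        children.setdefault(p, []).append(i)
    accepted: List[int] = []
    node, expected = -1, root_next
    while True:
        match = next((c for c in children.get(node, []) if tokens[c] == expected), None)
        if match is None:
            break
        accepted.append(expected)
        node, expected = match, target_next[match]
    accepted.append(expected)  # the target's own next token is always valid
    return accepted


def speculative_sample(p: Sequence[float], q: Sequence[float], x: int,
                       rng: random.Random) -> Tuple[int, bool]:
    """Speculative-sampling test for one draft token x drawn from q."""
    if rng.random() < min(1.0, p[x] / q[x]):
        return x, True
    # Rejected: resample from the residual distribution max(0, p - q).
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(range(len(p)), weights=residual, k=1)[0], False
```

Each round emits `len(accepted)` tokens, always at least one, and averaging that count over rounds gives the per-round acceptance figure described above. The rejection-and-resample rule is what keeps the sampled output distributed exactly as the target's `p`.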