Implementation: sail-sg/LongSpec Qwen2Glide Init
| Knowledge Sources | |
|---|---|
| Domains | Speculative_Decoding, Model_Architecture |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Concrete tool for constructing GLIDE draft models by attaching a trainable cross-attention decoder layer to a frozen Qwen2 or Llama target LLM.
Description
The Qwen2Glide and LlamaGlide classes extend their respective base model classes (Qwen2ForCausalLM / LlamaForCausalLM) with a single GLIDE decoder layer. During initialization, the target LLM is loaded from a HuggingFace checkpoint in float16, its parameters are frozen, and the GLIDE draft layer is either freshly initialized or loaded from a pre-trained checkpoint.
The classes share the same architecture pattern:
- Inherit from the corresponding ForCausalLM class (Qwen2Glide additionally mixes in PretrainedModelParallelPreSplitMixin; LlamaGlide inherits from LlamaForCausalLM directly)
- Load target model, extract backbone (model) and language modeling head (lm_head)
- Initialize GlideDecoderLayer with cross-attention, self-attention, and FFN sub-layers
- Freeze target parameters, enable gradients only for the draft layer
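The four steps above can be sketched as follows. This is a minimal toy illustration of the freeze-then-unfreeze pattern, not the repository's code: nn.Linear modules stand in for the target backbone, lm_head, and the GlideDecoderLayer, and the checkpoint-loading branch is a simplified assumption.

```python
import torch.nn as nn

class TinyGlideModel(nn.Module):
    """Toy stand-in for Qwen2Glide: frozen target + trainable draft layer."""
    def __init__(self, glide_state_dict=None):
        super().__init__()
        # Stand-ins for the target backbone and LM head (in the real class
        # these are extracted from a HuggingFace checkpoint loaded in float16).
        self.model = nn.Linear(8, 8)
        self.lm_head = nn.Linear(8, 16)
        # Stand-in for the single trainable GlideDecoderLayer
        # (cross-attention + self-attention + FFN in the real layer).
        self.draft_model = nn.Linear(8, 8)
        if glide_state_dict is not None:
            # Resume the draft layer from a previous training stage.
            self.draft_model.load_state_dict(glide_state_dict)
        # Freeze everything, then re-enable gradients for the draft layer only.
        for p in self.parameters():
            p.requires_grad = False
        for p in self.draft_model.parameters():
            p.requires_grad = True

m = TinyGlideModel()
trainable = [n for n, p in m.named_parameters() if p.requires_grad]
print(trainable)  # only draft_model.* parameters remain trainable
```

The order matters: freezing the whole module first and then unfreezing the draft layer keeps the logic correct even if sub-layers are added to the draft layer later.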
Usage
Import and instantiate this class when setting up GLIDE draft-model training, or when loading a trained draft model for inference. The model is typically instantiated via Hydra configuration:
model = hydra.utils.call(cfg.model, cfg.model_name_or_path, state_dict=pretrain_state_dict)
Code Reference
Source Location
- Repository: LongSpec
- File (Qwen2): longspec/train/models/qwen2_glide.py
- Lines (Qwen2): L476-514
- File (Llama): longspec/train/models/llama_glide.py
- Lines (Llama): L473-510
Signature
class Qwen2Glide(PretrainedModelParallelPreSplitMixin, Qwen2ForCausalLM):
def __init__(
self,
config: Qwen2Config,
target_model_path: str,
glide_path: Optional[str] = None,
) -> None:
"""
Args:
config: Qwen2Config with model architecture parameters.
target_model_path: HuggingFace path or local path to frozen target LLM.
glide_path: Optional path to pre-trained GLIDE draft layer weights.
If None, initializes a fresh GlideDecoderLayer.
"""
class LlamaGlide(LlamaForCausalLM):
def __init__(
self,
config: LlamaConfig,
target_model_path: str,
glide_path: Optional[str] = None,
) -> None:
"""Same signature and behavior as Qwen2Glide for Llama architecture."""
Import
# For Qwen2-based models:
from longspec.train.models.qwen2_glide import Qwen2Glide
# For Llama-based models:
from longspec.train.models.llama_glide import LlamaGlide
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | Qwen2Config / LlamaConfig | Yes | Model architecture config loaded via AutoConfig.from_pretrained() |
| target_model_path | str | Yes | HuggingFace model ID or local path to the target LLM (e.g., "Qwen/QwQ-32B-Preview") |
| glide_path | Optional[str] | No | Path to pre-trained draft layer weights (.pth file from previous training stage) |
Outputs
| Name | Type | Description |
|---|---|---|
| model | Qwen2Glide / LlamaGlide | Initialized model with frozen target LLM and trainable GLIDE draft layer |
| model.model | Qwen2Model / LlamaModel | Frozen target LLM backbone (all parameters have requires_grad=False) |
| model.lm_head | nn.Linear | Frozen language modeling head shared between target and draft |
| model.draft_model | Qwen2GlideDecoderLayer / LlamaGlideDecoderLayer | Trainable single decoder layer (cross-attn + self-attn + FFN) |
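The frozen/trainable split in the table above can be verified against any instantiated model. The helper below is a sketch (the function name is an illustration, not repository code); it relies only on the attribute names listed in the Outputs table.

```python
import torch.nn as nn

def check_glide_contract(model: nn.Module) -> None:
    """Verify the frozen-target / trainable-draft split described above."""
    assert all(not p.requires_grad for p in model.model.parameters()), \
        "target backbone must be frozen"
    assert all(not p.requires_grad for p in model.lm_head.parameters()), \
        "lm_head must be frozen"
    assert all(p.requires_grad for p in model.draft_model.parameters()), \
        "draft layer must be trainable"
```

Running such a check once after construction catches the common failure mode where a resumed checkpoint silently re-enables gradients on the target.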
Usage Examples
Direct Instantiation
from transformers import AutoConfig
from longspec.train.models.qwen2_glide import Qwen2Glide
# Load config from target model
config = AutoConfig.from_pretrained("Qwen/QwQ-32B-Preview")
# Fresh initialization (Stage 1 training)
model = Qwen2Glide(
config=config,
target_model_path="Qwen/QwQ-32B-Preview",
glide_path=None, # Fresh GLIDE layer
)
# Resuming from checkpoint (Stage 2+ training)
model = Qwen2Glide(
config=config,
target_model_path="Qwen/QwQ-32B-Preview",
glide_path="/path/to/stage1/draft_model_weights.pth",
)
Via Hydra Configuration
import hydra
from omegaconf import DictConfig
# In trainer_base_ds_mul_fs_tp.py:
model = hydra.utils.call(
cfg.model, # Points to Qwen2Glide.from_pretrained
cfg.model_name_or_path, # Target model path
state_dict=pretrain_state_dict # Optional state dict
)