Implementation:Sail sg LongSpec Qwen2Glide Init

From Leeroopedia
Knowledge Sources
Domains Speculative_Decoding, Model_Architecture
Last Updated 2026-02-14 05:00 GMT

Overview

Concrete tool for constructing GLIDE draft models by attaching a trainable cross-attention decoder layer to a frozen Qwen2 or Llama target LLM.

Description

The Qwen2Glide and LlamaGlide classes extend their respective base model classes (Qwen2ForCausalLM / LlamaForCausalLM) with a single GLIDE decoder layer. During initialization, the target LLM is loaded from a HuggingFace checkpoint in float16, its parameters are frozen, and the GLIDE draft layer is either freshly initialized or loaded from a pre-trained checkpoint.

The classes share the same architecture pattern:

  • Inherit from the corresponding ForCausalLM class (Qwen2Glide additionally mixes in PretrainedModelParallelPreSplitMixin; see the signatures below)
  • Load target model, extract backbone (model) and language modeling head (lm_head)
  • Initialize GlideDecoderLayer with cross-attention, self-attention, and FFN sub-layers
  • Freeze target parameters, enable gradients only for the draft layer
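The four steps above can be sketched as follows. This is a minimal stand-in, not the actual LongSpec classes: `GlideDraftSketch` is hypothetical, and a generic `nn.TransformerDecoderLayer` stands in for the real `GlideDecoderLayer`.

```python
import torch
import torch.nn as nn

class GlideDraftSketch(nn.Module):
    """Sketch of the GLIDE init pattern: frozen target, single trainable draft layer."""

    def __init__(self, target: nn.Module, hidden: int, glide_path=None):
        super().__init__()
        # Step 2: take the target backbone and LM head as given (loaded elsewhere).
        self.model = target
        self.lm_head = nn.Linear(hidden, 32, bias=False)
        # Step 3: one trainable draft layer (stand-in for GlideDecoderLayer).
        self.draft_model = nn.TransformerDecoderLayer(
            d_model=hidden, nhead=2, batch_first=True
        )
        if glide_path is not None:
            # Optionally warm-start the draft layer from a prior checkpoint.
            self.draft_model.load_state_dict(torch.load(glide_path, map_location="cpu"))
        # Step 4: freeze everything, then re-enable gradients for the draft layer only.
        for p in self.parameters():
            p.requires_grad_(False)
        for p in self.draft_model.parameters():
            p.requires_grad_(True)

# Tiny stand-in target "LLM" just to exercise the pattern.
target = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16))
sketch = GlideDraftSketch(target, hidden=16)
trainable = [n for n, p in sketch.named_parameters() if p.requires_grad]
assert all(n.startswith("draft_model.") for n in trainable)
```

After construction, only `draft_model.*` parameters receive gradients, mirroring the contract the real classes enforce.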

Usage

Import and instantiate when setting up GLIDE draft model training or loading a trained draft model for inference. The model is typically instantiated via Hydra configuration:

model = hydra.utils.call(cfg.model, cfg.model_name_or_path, state_dict=pretrain_state_dict)

Code Reference

Source Location

  • Repository: LongSpec
  • File (Qwen2): longspec/train/models/qwen2_glide.py
  • Lines (Qwen2): L476-514
  • File (Llama): longspec/train/models/llama_glide.py
  • Lines (Llama): L473-510

Signature

class Qwen2Glide(PretrainedModelParallelPreSplitMixin, Qwen2ForCausalLM):
    def __init__(
        self,
        config: Qwen2Config,
        target_model_path: str,
        glide_path: Optional[str] = None,
    ) -> None:
        """
        Args:
            config: Qwen2Config with model architecture parameters.
            target_model_path: HuggingFace path or local path to frozen target LLM.
            glide_path: Optional path to pre-trained GLIDE draft layer weights.
                        If None, initializes a fresh GlideDecoderLayer.
        """

class LlamaGlide(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        target_model_path: str,
        glide_path: Optional[str] = None,
    ) -> None:
        """Same signature and behavior as Qwen2Glide for Llama architecture."""

Import

# For Qwen2-based models:
from longspec.train.models.qwen2_glide import Qwen2Glide

# For Llama-based models:
from longspec.train.models.llama_glide import LlamaGlide

I/O Contract

Inputs

Name Type Required Description
config Qwen2Config / LlamaConfig Yes Model architecture config loaded via AutoConfig.from_pretrained()
target_model_path str Yes HuggingFace model ID or local path to the target LLM (e.g., "Qwen/QwQ-32B-Preview")
glide_path Optional[str] No Path to pre-trained draft layer weights (.pth file from previous training stage)

Outputs

Name Type Description
model Qwen2Glide / LlamaGlide Initialized model with frozen target LLM and trainable GLIDE draft layer
model.model Qwen2Model / LlamaModel Frozen target LLM backbone (all parameters have requires_grad=False)
model.lm_head nn.Linear Frozen language modeling head shared between target and draft
model.draft_model Qwen2GlideDecoderLayer / LlamaGlideDecoderLayer Trainable single decoder layer (cross-attn + self-attn + FFN)
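The frozen/trainable split in the table above can be verified with a small audit helper. `audit_glide_params` and the `_Stub` module are hypothetical illustrations shaped like the contract, not part of the LongSpec API.

```python
import torch.nn as nn

def audit_glide_params(model: nn.Module) -> dict:
    """Count parameters by trainability; per the contract above,
    only model.draft_model should be trainable."""
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {"frozen": frozen, "trainable": trainable}

class _Stub(nn.Module):
    """Stand-in with the same attribute layout as Qwen2Glide/LlamaGlide."""
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(8, 8)        # frozen target backbone stand-in
        self.lm_head = nn.Linear(8, 4)      # frozen LM head stand-in
        self.draft_model = nn.Linear(8, 8)  # trainable draft layer stand-in
        for p in self.parameters():
            p.requires_grad_(False)
        for p in self.draft_model.parameters():
            p.requires_grad_(True)

stats = audit_glide_params(_Stub())
# Only the draft layer's 72 parameters (64 weights + 8 biases) are trainable.
```

On a real Qwen2Glide instance the trainable count should equal the parameter count of the single GLIDE decoder layer, a small fraction of the frozen target.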

Usage Examples

Direct Instantiation

from transformers import AutoConfig
from longspec.train.models.qwen2_glide import Qwen2Glide

# Load config from target model
config = AutoConfig.from_pretrained("Qwen/QwQ-32B-Preview")

# Fresh initialization (Stage 1 training)
model = Qwen2Glide(
    config=config,
    target_model_path="Qwen/QwQ-32B-Preview",
    glide_path=None,  # Fresh GLIDE layer
)

# Resuming from checkpoint (Stage 2+ training)
model = Qwen2Glide(
    config=config,
    target_model_path="Qwen/QwQ-32B-Preview",
    glide_path="/path/to/stage1/draft_model_weights.pth",
)
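Between stages, only the draft layer's weights typically need to be persisted; the frozen target is re-loaded from its own checkpoint. The sketch below illustrates, with a stand-in `nn.Linear` in place of the real draft layer, how the `.pth` file that `glide_path` later consumes might be produced and restored. The file name here is illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained draft layer (the real one is a GlideDecoderLayer).
draft = nn.Linear(16, 16)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "draft_model_weights.pth")
    # Save only the draft layer's state dict; the target LLM is frozen and
    # does not need to be duplicated into this checkpoint.
    torch.save(draft.state_dict(), path)

    # Later (e.g. Stage 2): restore into a freshly initialized layer,
    # which is what passing glide_path is meant to accomplish.
    fresh = nn.Linear(16, 16)
    fresh.load_state_dict(torch.load(path, map_location="cpu"))
```

This keeps stage checkpoints small: only the single decoder layer travels between training stages.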

Via Hydra Configuration

import hydra
from omegaconf import DictConfig

# In trainer_base_ds_mul_fs_tp.py:
model = hydra.utils.call(
    cfg.model,                    # Points to Qwen2Glide.from_pretrained
    cfg.model_name_or_path,       # Target model path
    state_dict=pretrain_state_dict  # Optional state dict
)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
