
Principle:Sail sg LongSpec Draft Checkpoint Extraction

From Leeroopedia
Knowledge Sources
Domains Checkpointing, Training, Distributed_Computing
Last Updated 2026-02-14 05:00 GMT

Overview

Principle for extracting and isolating trained draft model weights from a combined target-plus-draft model checkpoint, handling DeepSpeed ZeRO state consolidation.

Description

During GLIDE training, the model contains both the frozen target LLM and the trainable draft layer. The full model checkpoint (managed by DeepSpeed) includes weights for both components plus optimizer states. However, for inference and multi-stage training, only the draft model weights are needed as a standalone file.

Draft Checkpoint Extraction handles:

  • ZeRO-3 state consolidation: When using ZeRO Stage 3, parameters are sharded across GPUs and must be consolidated before extraction
  • Prefix-based filtering: Draft model parameters are identified by the "draft_model." key prefix in the state dict
  • Key renaming: The "draft_model." prefix is stripped from keys so the weights can be loaded directly into a GlideDecoderLayer
  • Rotating checkpoint management: DeepSpeed checkpoints are maintained in a rotating window (last/last_2) to save disk space

This extraction enables the multi-stage training pipeline where Stage N+1 loads the draft weights from Stage N's output.
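The rotating-window behavior described above can be sketched in a few lines. This is a hypothetical sketch, not the actual implementation: the function name is invented, and the only details taken from this page are the two directory names (last, last_2) and the goal of keeping just the two most recent checkpoints.

```python
import shutil
from pathlib import Path

def rotate_checkpoints(ckpt_root: str) -> Path:
    """Keep only the two most recent DeepSpeed checkpoints.

    Hypothetical sketch: "last" holds the newest checkpoint and
    "last_2" the previous one; anything older is deleted to save disk.
    """
    root = Path(ckpt_root)
    last, last_2 = root / "last", root / "last_2"

    # Drop the oldest checkpoint before rotating.
    if last_2.exists():
        shutil.rmtree(last_2)
    # The previous "last" becomes "last_2".
    if last.exists():
        last.rename(last_2)
    # The caller writes the new checkpoint into the fresh "last" directory.
    last.mkdir(parents=True)
    return last
```

Each training stage would call this once per save, so disk usage stays bounded at two full DeepSpeed checkpoints regardless of how many saves occur.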

Usage

Apply this principle at the end of each training stage to produce a portable draft_model_weights.pth file. This file is referenced by the next stage's config as draft_model_name_or_path.
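For illustration, a Stage N+1 config might reference the previous stage's output like this. Only draft_model_name_or_path is named by this page; every other key below is an assumption, not the real config schema.

```python
# Hypothetical Stage N+1 config fragment. Only draft_model_name_or_path
# comes from this page; the other keys are illustrative placeholders.
stage_config = {
    "target_model_name_or_path": "path/to/frozen/target",  # assumed key
    "draft_model_name_or_path": "stage_n_output/draft_model_weights.pth",
    "num_train_epochs": 1,                                 # assumed key
}
```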

Theoretical Basis

State Dict Filtering:

# Abstract extraction logic (not the actual implementation)
import torch

full_state = model.state_dict()  # Parameters for both target and draft components

# Keep only draft-model entries and strip the "draft_model." prefix
draft_state = {}
for key, value in full_state.items():
    if key.startswith("draft_model."):
        draft_state[key[len("draft_model."):]] = value

torch.save(draft_state, "draft_model_weights.pth")
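The filter can be checked on a toy state dict in plain Python, with no framework required. The key names below are invented for illustration and do not come from the real model:

```python
def extract_draft_state(full_state: dict, prefix: str = "draft_model.") -> dict:
    """Return only the entries under `prefix`, with the prefix removed."""
    return {k[len(prefix):]: v for k, v in full_state.items()
            if k.startswith(prefix)}

# Toy state dict mixing target and draft keys (names are illustrative).
toy_state = {
    "model.layers.0.self_attn.q_proj.weight": "target-weight",
    "draft_model.self_attn.q_proj.weight": "draft-weight",
    "draft_model.mlp.up_proj.weight": "draft-mlp-weight",
}

print(extract_draft_state(toy_state))
# → {'self_attn.q_proj.weight': 'draft-weight', 'mlp.up_proj.weight': 'draft-mlp-weight'}
```

Slicing by prefix length (rather than a blanket string replace) ensures only the leading "draft_model." is removed, even if the substring happened to appear elsewhere in a key.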

ZeRO-3 Consolidation:

Under ZeRO Stage 3, each GPU holds only a shard of each parameter. Before extraction, the shards must be gathered:

  1. All ranks communicate to reconstruct full parameters
  2. Only rank 0 performs the actual save
  3. A distributed barrier synchronizes all ranks after saving
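The three steps above can be sketched as one function. This is an assumed implementation, not the page's actual code: it uses DeepSpeed's GatheredParameters context manager to reconstruct full parameters on every rank, but the real code may consolidate differently (e.g. via a saved ZeRO checkpoint and zero_to_fp32 conversion).

```python
def consolidate_and_save_draft(model, out_path="draft_model_weights.pth"):
    """Sketch of ZeRO-3 consolidation and rank-0 save (assumed, not actual).

    Assumes `model` is wrapped by DeepSpeed with ZeRO Stage 3 and that
    torch.distributed is initialized.
    """
    import torch
    import torch.distributed as dist
    import deepspeed

    # 1. All ranks participate in gathering the sharded parameters.
    with deepspeed.zero.GatheredParameters(list(model.parameters())):
        # 2. Only rank 0 materializes and saves the draft weights.
        if dist.get_rank() == 0:
            full_state = model.state_dict()
            draft_state = {k[len("draft_model."):]: v.cpu().clone()
                           for k, v in full_state.items()
                           if k.startswith("draft_model.")}
            torch.save(draft_state, out_path)

    # 3. Synchronize so no rank proceeds before the file exists.
    dist.barrier()
```

Note that the gather happens on all ranks even though only rank 0 writes: ZeRO-3 collectives require every rank to participate, which is also why the final barrier is needed before any rank reads the saved file.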
