Principle:Sail sg LongSpec Draft Checkpoint Extraction
| Knowledge Sources | |
|---|---|
| Domains | Checkpointing, Training, Distributed_Computing |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Principle for extracting and isolating trained draft model weights from a combined target-plus-draft model checkpoint, handling DeepSpeed ZeRO state consolidation.
Description
During GLIDE training, the model contains both the frozen target LLM and the trainable draft layer. The full model checkpoint (managed by DeepSpeed) includes weights for both components plus optimizer states. However, for inference and multi-stage training, only the draft model weights are needed as a standalone file.
Draft Checkpoint Extraction handles:
- ZeRO-3 state consolidation: When using ZeRO Stage 3, parameters are sharded across GPUs and must be consolidated before extraction
- Prefix-based filtering: Draft model parameters are identified by the "draft_model." key prefix in the state dict
- Key renaming: The "draft_model." prefix is stripped from keys so the weights can be loaded directly into a GlideDecoderLayer
- Rotating checkpoint management: DeepSpeed checkpoints are kept in a two-slot rotating window (last and last_2) to limit disk usage
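The rotating window can be sketched with plain filesystem operations. This sketch assumes, hypothetically, that the two slots are sibling directories named `last` and `last_2` under the output directory, and that the newest checkpoint is shifted into `last_2` right before the next save. `rotate_checkpoints` is an illustrative helper, not part of DeepSpeed.

```python
import os
import shutil

def rotate_checkpoints(output_dir, latest="last", previous="last_2"):
    """Free the `latest` slot before a new save, keeping at most two checkpoints.

    Hypothetical layout: the current `latest` checkpoint becomes `previous`,
    and the old `previous` (the oldest checkpoint) is deleted.
    """
    latest_dir = os.path.join(output_dir, latest)
    previous_dir = os.path.join(output_dir, previous)
    if os.path.isdir(latest_dir):
        if os.path.isdir(previous_dir):
            shutil.rmtree(previous_dir)      # drop the oldest checkpoint
        os.rename(latest_dir, previous_dir)  # newest becomes the backup
    return latest_dir                        # the new checkpoint goes here
```

A caller would run this just before `engine.save_checkpoint(output_dir, tag="last")`. Moving the newest checkpoint aside instead of deleting it means that a crash during the save still leaves a usable `last_2`.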
This extraction enables the multi-stage training pipeline where Stage N+1 loads the draft weights from Stage N's output.
Usage
Apply this principle at the end of each training stage to produce a portable draft_model_weights.pth file. This file is referenced by the next stage's config as draft_model_name_or_path.
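The hand-off between stages can be sketched as a small helper that points the next stage at the previous stage's output. `next_stage_config` is hypothetical. It assumes the extracted file sits directly in Stage N's output directory, and any config key other than `draft_model_name_or_path` is illustrative only.

```python
import os

def next_stage_config(prev_output_dir, base_config):
    """Build Stage N+1's config so it starts from Stage N's extracted draft weights."""
    weights = os.path.join(prev_output_dir, "draft_model_weights.pth")
    if not os.path.isfile(weights):
        # Fail early instead of silently starting the draft layer from scratch
        raise FileNotFoundError(f"Stage output has no extracted draft weights: {weights}")
    cfg = dict(base_config)  # shallow copy; the caller's config is left untouched
    cfg["draft_model_name_or_path"] = weights
    return cfg
```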
Theoretical Basis
State Dict Filtering:
```python
# Abstract extraction logic (not the actual implementation)
import torch

# Target + draft parameters/buffers; optimizer states are stored separately.
# Under ZeRO-3 this must be the consolidated (gathered) state dict.
full_state = model.state_dict()
# Keep draft-model keys and strip only the leading prefix
prefix = "draft_model."
draft_state = {
    key[len(prefix):]: value
    for key, value in full_state.items()
    if key.startswith(prefix)
}
torch.save(draft_state, "draft_model_weights.pth")
```
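Before handing the file to the next stage, it can help to confirm that the stripped keys match what the draft layer expects. The sketch below is pure Python so it runs without a model. `extract_draft_state` and `check_draft_keys` are hypothetical helpers, and the example key names are illustrative, not LongSpec's actual parameter names. In PyTorch, `load_state_dict(..., strict=True)` performs the same missing/unexpected-key check and raises on a mismatch.

```python
PREFIX = "draft_model."

def extract_draft_state(full_state, prefix=PREFIX):
    """Keep draft-model entries and strip only the leading prefix."""
    # Slicing (rather than str.replace) cannot alter a later occurrence of the prefix
    return {k[len(prefix):]: v for k, v in full_state.items() if k.startswith(prefix)}

def check_draft_keys(draft_state, expected_keys):
    """Return (missing, unexpected) sorted key lists, as a strict load would report."""
    present, expected = set(draft_state), set(expected_keys)
    return sorted(expected - present), sorted(present - expected)

# Hypothetical keys: one target-model entry and two draft-layer entries
full_state = {
    "model.layers.0.mlp.up_proj.weight": "target tensor",
    "draft_model.self_attn.q_proj.weight": "draft q_proj",
    "draft_model.mlp.up_proj.weight": "draft up_proj",
}
draft_state = extract_draft_state(full_state)
missing, unexpected = check_draft_keys(
    draft_state, ["self_attn.q_proj.weight", "mlp.up_proj.weight"]
)
print(sorted(draft_state))  # ['mlp.up_proj.weight', 'self_attn.q_proj.weight']
print(missing, unexpected)  # [] []
```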
ZeRO-3 Consolidation:
Under ZeRO Stage 3, each GPU holds only a shard of each parameter. Before extraction, the shards must be gathered:
- All ranks take part in a collective all-gather to reconstruct the full parameters
- Only rank 0 performs the actual save
- A distributed barrier synchronizes all ranks after saving
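The three steps above can be illustrated with a single-process simulation in which threads stand in for ranks and `threading.Barrier` stands in for the distributed barrier. This is a conceptual sketch, not DeepSpeed's API (real code would gather with something like `deepspeed.zero.GatheredParameters`). The partitioning scheme, where each flattened parameter is split into equal shards and zero-padded to a multiple of the world size, is a simplifying assumption.

```python
import threading

def partition(flat, world_size):
    """Split a flattened parameter into equal per-rank shards, zero-padding the tail."""
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]

def simulate_zero3_save(full_params, world_size):
    """Threads act as ranks: gather shards, save on rank 0 only, then barrier."""
    shards = {name: partition(vals, world_size) for name, vals in full_params.items()}
    numel = {name: len(vals) for name, vals in full_params.items()}
    gathered = {name: [None] * world_size for name in full_params}  # all-gather buffer
    saved = {}  # stands in for the file that rank 0 writes
    gather_done = threading.Barrier(world_size)
    save_done = threading.Barrier(world_size)

    def rank_fn(rank):
        # 1. Each rank contributes only the shard it owns
        for name in full_params:
            gathered[name][rank] = shards[name][rank]
        gather_done.wait()  # every shard is now available
        # 2. Only rank 0 reassembles the parameters and "saves" them
        if rank == 0:
            for name, parts in gathered.items():
                flat = [x for part in parts for x in part]
                saved[name] = flat[:numel[name]]  # drop the padding
        # 3. No rank moves on until the save has finished
        save_done.wait()

    threads = [threading.Thread(target=rank_fn, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return saved

params = {"draft_model.w": [1.0, 2.0, 3.0, 4.0, 5.0], "draft_model.b": [0.5]}
print(simulate_zero3_save(params, world_size=4) == params)  # True
```

The second barrier matters in the real setting. Without it, non-zero ranks could start the next stage, or tear down the process group, while rank 0 is still writing the file.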