Implementation:Sail sg LongSpec Save Model Extract
| Knowledge Sources | |
|---|---|
| Domains | Checkpointing, Training |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Concrete tool for saving trained GLIDE model checkpoints, extracting draft model weights with prefix stripping, and managing rotating DeepSpeed checkpoint directories.
Description
Two functions in trainer_base_ds_mul_fs_tp.py handle checkpoint extraction:
- extract_and_rename(): Filters a state dict for keys matching a prefix, strips the prefix, and returns an OrderedDict of draft-only weights
- save_model(): Orchestrates the full save pipeline — DeepSpeed checkpoint saving, ZeRO-3 state consolidation, draft weight extraction, and tokenizer/config saving
Usage
Called automatically by the training loop at save_steps intervals and at training completion. Users do not call these directly.
Code Reference
Source Location
- Repository: LongSpec
- File: longspec/train/trainer_base_ds_mul_fs_tp.py
- Lines: L49-113
Signature
def extract_and_rename(
state_dict: Dict,
prefix: str = "draft_model.",
) -> OrderedDict:
"""
Extract and rename state dict entries by removing prefix.
Args:
state_dict: Complete model state dictionary
prefix: Key prefix to match and strip (default: "draft_model.")
Returns:
OrderedDict with matching keys, prefix removed
"""
def save_model(
model: Union[deepspeed.DeepSpeedEngine, deepspeed.PipelineEngine],
cfg: DictConfig,
output_dir: str,
tokenizer: PreTrainedTokenizer = None,
state_dict: Dict = None,
) -> None:
"""
Save model checkpoint with draft weight extraction.
Args:
model: DeepSpeed engine wrapping the GLIDE model
cfg: Hydra config with output_dir, ZeRO stage, save_ds_state flag
output_dir: Directory to save checkpoint files
tokenizer: Optional tokenizer to save alongside
state_dict: Optional pre-consolidated state dict
Side Effects:
- Saves draft_model_weights.pth to output_dir
- Saves DeepSpeed checkpoint to last_ds/last_2ds (rotating)
- Saves tokenizer and config to output_dir
"""
Import
from longspec.train.trainer_base_ds_mul_fs_tp import save_model, extract_and_rename
I/O Contract
Inputs (save_model)
| Name | Type | Required | Description |
|---|---|---|---|
| model | DeepSpeedEngine or PipelineEngine | Yes | DeepSpeed-wrapped GLIDE model with target + draft parameters |
| cfg | DictConfig | Yes | Config with output_dir, ZeRO stage info, save_ds_state flag |
| output_dir | str | Yes | Directory path for saving checkpoint files |
| tokenizer | PreTrainedTokenizer | No | Tokenizer to save (for model deployment) |
| state_dict | Dict | No | Pre-consolidated state dict (if None, extracted from model) |
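The optional state_dict parameter exists because under ZeRO-3 each rank holds only a shard of the parameters, so a consolidated dict must be gathered before draft extraction. The branch can be sketched as below; get_full_state_dict is a hypothetical helper, and _zero3_consolidated_16bit_state_dict is assumed from DeepSpeed's engine API (check your DeepSpeed version):

```python
def get_full_state_dict(engine, zero_stage, state_dict=None):
    # Hypothetical helper mirroring the consolidation step of save_model().
    if state_dict is not None:
        return state_dict  # caller already supplied a consolidated dict
    if zero_stage == 3:
        # Gathers ZeRO-3 parameter shards onto rank 0 (assumed DeepSpeed API).
        return engine._zero3_consolidated_16bit_state_dict()
    # Stages 0-2 keep full parameters on every rank.
    return engine.module.state_dict()
```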
Inputs (extract_and_rename)
| Name | Type | Required | Description |
|---|---|---|---|
| state_dict | Dict | Yes | Full model state dict with prefixed keys |
| prefix | str | No | Key prefix to match (default: "draft_model.") |
Outputs
| Name | Type | Description |
|---|---|---|
| draft_model_weights.pth | File | PyTorch state dict containing only draft layer parameters (prefix stripped) |
| last_ds/ | Directory | Most recent DeepSpeed checkpoint (full engine state) |
| last_2ds/ | Directory | Previous DeepSpeed checkpoint (for rollback) |
| tokenizer files | Files | Tokenizer config and vocabulary saved to output_dir |
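The two-slot rotation implied by the last_ds/last_2ds outputs can be sketched as follows. rotate_checkpoint_dirs is a hypothetical helper for illustration; the real save_model may interleave this differently with DeepSpeed's save_checkpoint call:

```python
import os
import shutil

def rotate_checkpoint_dirs(output_dir):
    # Hypothetical helper: demote the previous checkpoint before a new save.
    #   last_ds/  -> most recent DeepSpeed checkpoint
    #   last_2ds/ -> the one before it (kept for rollback)
    last_ds = os.path.join(output_dir, "last_ds")
    last_2ds = os.path.join(output_dir, "last_2ds")
    if os.path.isdir(last_2ds):
        shutil.rmtree(last_2ds)       # drop the oldest checkpoint
    if os.path.isdir(last_ds):
        os.rename(last_ds, last_2ds)  # demote the previous one
    os.makedirs(last_ds)              # fresh slot for the new checkpoint
    return last_ds
```

This keeps disk usage bounded at two full engine checkpoints while always leaving one known-good state for rollback.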
Usage Examples
Extract Draft Weights from State Dict
from collections import OrderedDict
from longspec.train.trainer_base_ds_mul_fs_tp import extract_and_rename
# Given a full model state dict:
full_state = {
"model.layers.0.self_attn.q_proj.weight": ..., # Target (ignored)
"draft_model.self_attn.q_proj.weight": ..., # Draft (extracted)
"draft_model.cross_attn.q_proj.weight": ..., # Draft (extracted)
"draft_model.mlp.gate_proj.weight": ..., # Draft (extracted)
}
draft_state = extract_and_rename(full_state, prefix="draft_model.")
# Result: OrderedDict({
# "self_attn.q_proj.weight": ...,
# "cross_attn.q_proj.weight": ...,
# "mlp.gate_proj.weight": ...,
# })
Save Pipeline in Training Loop
# Called inside train() at save_steps intervals:
if global_step % cfg.save_steps == 0:
save_model(
model=model, # DeepSpeed engine
cfg=cfg, # Full Hydra config
output_dir=f"{cfg.output_dir}/step_{global_step}",
tokenizer=tokenizer,
)
# Produces:
# {output_dir}/step_{global_step}/draft_model_weights.pth
# {output_dir}/last_ds/ (rotating DS checkpoint)