

Implementation:Sail sg LongSpec Save Model Extract

From Leeroopedia
Knowledge Sources
Domains Checkpointing, Training
Last Updated 2026-02-14 05:00 GMT

Overview

Concrete tool for saving trained GLIDE model checkpoints, extracting draft model weights with prefix stripping, and managing rotating DeepSpeed checkpoint directories.

Description

Two functions in trainer_base_ds_mul_fs_tp.py handle checkpoint extraction:

  • extract_and_rename(): Filters a state dict for keys matching a prefix, strips the prefix, and returns an OrderedDict of draft-only weights
  • save_model(): Orchestrates the full save pipeline — DeepSpeed checkpoint saving, ZeRO-3 state consolidation, draft weight extraction, and tokenizer/config saving
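The filter-and-strip behavior described above can be sketched as follows. This is an illustrative re-implementation of the documented behavior, not the LongSpec source; the actual function lives in trainer_base_ds_mul_fs_tp.py:

```python
from collections import OrderedDict

def extract_and_rename_sketch(state_dict, prefix="draft_model."):
    """Keep only entries whose key starts with `prefix`, with the prefix removed.

    Illustrative sketch of the documented extract_and_rename() behavior:
    insertion order of the surviving keys is preserved by OrderedDict.
    """
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )
```

Keys that do not carry the prefix (i.e., target-model parameters) are silently dropped, so the result contains draft weights only.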

Usage

Called automatically by the training loop at save_steps intervals and at training completion. Users do not call these directly.

Code Reference

Source Location

  • Repository: LongSpec
  • File: longspec/train/trainer_base_ds_mul_fs_tp.py
  • Lines: L49-113

Signature

def extract_and_rename(
    state_dict: Dict,
    prefix: str = "draft_model.",
) -> OrderedDict:
    """
    Extract and rename state dict entries by removing prefix.

    Args:
        state_dict: Complete model state dictionary
        prefix: Key prefix to match and strip (default: "draft_model.")

    Returns:
        OrderedDict with matching keys, prefix removed
    """

def save_model(
    model: Union[deepspeed.DeepSpeedEngine, deepspeed.PipelineEngine],
    cfg: DictConfig,
    output_dir: str,
    tokenizer: PreTrainedTokenizer = None,
    state_dict: Dict = None,
) -> None:
    """
    Save model checkpoint with draft weight extraction.

    Args:
        model: DeepSpeed engine wrapping the GLIDE model
        cfg: Hydra config with output_dir, ZeRO stage, save_ds_state flag
        output_dir: Directory to save checkpoint files
        tokenizer: Optional tokenizer to save alongside
        state_dict: Optional pre-consolidated state dict

    Side Effects:
        - Saves draft_model_weights.pth to output_dir
        - Saves DeepSpeed checkpoint to last_ds/last_2ds (rotating)
        - Saves tokenizer and config to output_dir
    """

Import

from longspec.train.trainer_base_ds_mul_fs_tp import save_model, extract_and_rename

I/O Contract

Inputs (save_model)

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| model | DeepSpeedEngine | Yes | DeepSpeed-wrapped GLIDE model with target + draft parameters |
| cfg | DictConfig | Yes | Config with output_dir, ZeRO stage info, save_ds_state flag |
| output_dir | str | Yes | Directory path for saving checkpoint files |
| tokenizer | PreTrainedTokenizer | No | Tokenizer to save (for model deployment) |
| state_dict | Dict | No | Pre-consolidated state dict (if None, extracted from model) |

Inputs (extract_and_rename)

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| state_dict | Dict | Yes | Full model state dict with prefixed keys |
| prefix | str | No | Key prefix to match (default: "draft_model.") |

Outputs

| Name | Type | Description |
| --- | --- | --- |
| draft_model_weights.pth | File | PyTorch state dict containing only draft layer parameters (prefix stripped) |
| last_ds/ | Directory | Most recent DeepSpeed checkpoint (full engine state) |
| last_2ds/ | Directory | Previous DeepSpeed checkpoint (for rollback) |
| tokenizer files | Files | Tokenizer config and vocabulary saved to output_dir |
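The last_ds/last_2ds scheme keeps the two most recent DeepSpeed checkpoints so one prior state is always available for rollback. A minimal sketch of such a two-slot rotation, assuming the directory names in the table above (this is illustrative, not the save_model() source):

```python
import os
import shutil

def rotate_checkpoint_dirs(output_dir, newest="last_ds", previous="last_2ds"):
    """Shift the most recent checkpoint dir to the rollback slot before a new save.

    Sketch of a two-slot rotation consistent with the outputs table:
    the oldest checkpoint is deleted, `newest` is renamed to `previous`,
    and a fresh `newest` directory is created for the incoming save.
    """
    newest_path = os.path.join(output_dir, newest)
    previous_path = os.path.join(output_dir, previous)
    if os.path.isdir(newest_path):
        if os.path.isdir(previous_path):
            shutil.rmtree(previous_path)  # drop the oldest checkpoint
        os.rename(newest_path, previous_path)  # promote newest -> rollback slot
    os.makedirs(newest_path, exist_ok=True)
    return newest_path  # the engine's checkpoint save would target this dir
```

In the real pipeline the DeepSpeed engine writes its full state into the freshly created directory after rotation.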

Usage Examples

Extract Draft Weights from State Dict

from longspec.train.trainer_base_ds_mul_fs_tp import extract_and_rename

# Given a full model state dict:
full_state = {
    "model.layers.0.self_attn.q_proj.weight": ...,  # Target (ignored)
    "draft_model.self_attn.q_proj.weight": ...,     # Draft (extracted)
    "draft_model.cross_attn.q_proj.weight": ...,    # Draft (extracted)
    "draft_model.mlp.gate_proj.weight": ...,        # Draft (extracted)
}

draft_state = extract_and_rename(full_state, prefix="draft_model.")
# Result: OrderedDict({
#     "self_attn.q_proj.weight": ...,
#     "cross_attn.q_proj.weight": ...,
#     "mlp.gate_proj.weight": ...,
# })

Save Pipeline in Training Loop

# Called inside train() at save_steps intervals:
if global_step % cfg.save_steps == 0:
    save_model(
        model=model,                           # DeepSpeed engine
        cfg=cfg,                               # Full Hydra config
        output_dir=f"{cfg.output_dir}/step_{global_step}",
        tokenizer=tokenizer,
    )
    # Produces:
    #   {output_dir}/step_{global_step}/draft_model_weights.pth
    #   {output_dir}/last_ds/          (rotating DS checkpoint)

Related Pages

Implements Principle

Requires Environment
