
Principle:Sail sg LongSpec Multi Stage Training

From Leeroopedia
Domains Training, Curriculum_Learning, Long_Context
Last Updated 2026-02-14 05:00 GMT

Overview

A training strategy that progressively builds draft-model capability across multiple stages, increasing the context length at each stage and finally specializing for long chain-of-thought reasoning.

Description

Multi-Stage Training for GLIDE draft models follows a curriculum learning approach where each stage builds upon the previous one:

  • Stage 1 (Base Training): Train the GLIDE draft layer from scratch on general-purpose text (SlimPajama-6B) with short contexts (1024 tokens), using ZeRO-1 optimization and a relatively high learning rate (5e-4). This establishes basic next-token prediction capability.
  • Stage 2 (Long-Context Fine-Tuning): Load the Stage 1 draft weights and fine-tune on long-context data (32768 tokens), using ZeRO-3 to accommodate the larger memory footprint and a lower learning rate (5e-6). This extends the draft model's effective context window.
  • Stage 3 (Long Chain-of-Thought): Load the Stage 2 draft weights and fine-tune on long chain-of-thought data, using the same infrastructure as Stage 2 but specialized data. This optimizes the draft model for reasoning tasks such as AIME mathematical problems.

Each stage produces a draft_model_weights.pth file that serves as input for the next stage. The target LLM remains frozen throughout all stages.
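The three stages and their checkpoint chaining can be sketched as a small configuration structure. This is a minimal illustration, not the actual LongSpec config schema: the field names (context_length, learning_rate, zero_stage, draft_model_name_or_path) and the stage1/stage2 output paths are assumptions chosen to mirror the description above.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class StageConfig:
    """Hypothetical per-stage config; field names are illustrative,
    not the real LongSpec configuration keys."""
    name: str
    context_length: int
    learning_rate: float
    zero_stage: int            # DeepSpeed ZeRO stage
    data: str
    # Checkpoint produced by the previous stage; None means train from scratch.
    draft_model_name_or_path: Optional[str]

# Stage 1 starts fresh; each later stage loads the previous stage's
# draft_model_weights.pth. The target LLM stays frozen throughout.
STAGE1 = StageConfig("base", 1024, 5e-4, 1, "SlimPajama-6B", None)
STAGE2 = StageConfig("long_context", 32768, 5e-6, 3, "long documents",
                     "stage1/draft_model_weights.pth")
STAGE3 = StageConfig("long_cot", 32768, 5e-6, 3, "chain-of-thought",
                     "stage2/draft_model_weights.pth")
```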

An alternative sinkpi-slicing training variant is available at each stage, providing a different attention pattern for the draft model.

Usage

Use multi-stage training when training GLIDE draft models for long-context or reasoning applications. The stages must be run sequentially—each depends on the previous stage's output checkpoint. The full pipeline is:

  1. Run Stage 1 → produces draft_model_weights.pth
  2. Update Stage 2 config to point to Stage 1 output
  3. Run Stage 2 → produces updated draft_model_weights.pth
  4. Update Stage 3 config to point to Stage 2 output
  5. Run Stage 3 → produces final draft_model_weights.pth
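The five steps above amount to threading each stage's output checkpoint into the next stage's config. A minimal driver for that loop might look like the following sketch; run_pipeline and the launch callable are hypothetical helpers, not part of the LongSpec codebase, and the dry-run stub stands in for an actual training launch.

```python
import os
import tempfile

def run_pipeline(stages, launch):
    """Run training stages in order, wiring each stage's
    draft_model_weights.pth into the next stage's config.

    `launch` is a hypothetical callable: given one stage's config, it
    trains that stage and returns the checkpoint path it wrote.
    """
    prev_ckpt = None
    checkpoints = []
    for cfg in stages:
        # Each stage depends on the previous stage's output checkpoint.
        if prev_ckpt is not None and not os.path.exists(prev_ckpt):
            raise FileNotFoundError(
                f"previous stage checkpoint missing: {prev_ckpt}")
        stage_cfg = dict(cfg, draft_model_name_or_path=prev_ckpt)
        prev_ckpt = launch(stage_cfg)
        checkpoints.append(prev_ckpt)
    return checkpoints

# Dry-run usage: stub out training by just writing an empty checkpoint file.
workdir = tempfile.mkdtemp()

def fake_launch(cfg):
    out = os.path.join(workdir, cfg["name"], "draft_model_weights.pth")
    os.makedirs(os.path.dirname(out), exist_ok=True)
    open(out, "w").close()
    return out

ckpts = run_pipeline(
    [{"name": "stage1"}, {"name": "stage2"}, {"name": "stage3"}],
    fake_launch,
)
```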

Theoretical Basis

Multi-stage training follows curriculum learning principles:

  1. Start simple: Short contexts (1024 tokens) with high learning rate allow rapid convergence on basic patterns
  2. Increase difficulty: Longer contexts (32768 tokens) require the model to learn long-range dependencies, but training starts from the good initialization that Stage 1 provides
  3. Specialize: Task-specific data (Chain-of-Thought) optimizes for the target use case

The key configuration changes across stages:

Parameter                  Stage 1         Stage 2          Stage 3
Context length             1024            32768            32768
Learning rate              5e-4            5e-6             5e-6
ZeRO stage                 1               3                3
Data type                  General text    Long documents   Chain-of-thought
draft_model_name_or_path   None (fresh)    Stage 1 output   Stage 2 output
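Since only the parameters listed in the table change between stages, a later stage's config can be expressed as overrides on the Stage 1 base. This is an illustrative encoding, and the keys and checkpoint path are assumptions rather than the real config schema.

```python
# Stage 1 baseline; all later stages differ only in the table's five rows.
BASE = {
    "context_length": 1024,
    "learning_rate": 5e-4,
    "zero_stage": 1,
    "data": "general text",
    "draft_model_name_or_path": None,  # trained from scratch
}

# Stage 2 = base config plus the long-context overrides.
STAGE2_OVERRIDES = {
    "context_length": 32768,
    "learning_rate": 5e-6,
    "zero_stage": 3,
    "data": "long documents",
    "draft_model_name_or_path": "stage1/draft_model_weights.pth",
}

# Stage 3 keeps Stage 2's infrastructure and swaps only the data source
# and the checkpoint to resume from.
STAGE3_OVERRIDES = {
    **STAGE2_OVERRIDES,
    "data": "chain-of-thought",
    "draft_model_name_or_path": "stage2/draft_model_weights.pth",
}

stage2 = {**BASE, **STAGE2_OVERRIDES}
stage3 = {**BASE, **STAGE3_OVERRIDES}
```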
