Principle: sail-sg LongSpec Multi-Stage Training
| Knowledge Sources | |
|---|---|
| Domains | Training, Curriculum_Learning, Long_Context |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Training strategy that progressively builds draft model capability across multiple stages, increasing context length and specializing for long chain-of-thought reasoning.
Description
Multi-Stage Training for GLIDE draft models follows a curriculum learning approach where each stage builds upon the previous one:
- Stage 1 (Base Training): Train the GLIDE draft layer from scratch on general-purpose text (SlimPajama-6B) with short contexts (1024 tokens). Uses ZeRO-1 optimization and a higher learning rate (5e-4). This establishes basic next-token prediction capability.
- Stage 2 (Long-Context Fine-Tuning): Load the Stage 1 draft weights and fine-tune on long-context data (32768 tokens). Uses ZeRO-3 to handle the larger memory footprint and a lower learning rate (5e-6). This extends the draft model's effective context window.
- Stage 3 (Long Chain-of-Thought): Load Stage 2 draft weights and fine-tune on long chain-of-thought data. Same infrastructure as Stage 2 but specialized data. This optimizes the draft model for reasoning tasks like AIME mathematical problems.
Each stage produces a draft_model_weights.pth file that serves as input for the next stage. The target LLM remains frozen throughout all stages.
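The frozen-target setup above can be sketched as a single training step: gradients flow only into the draft layer while the target LLM's parameters stay fixed. The module names (`GlideDraftLayer`, the embedding stand-in for the target) are illustrative assumptions, not the actual GLIDE/LongSpec classes.

```python
# Minimal sketch of one Stage 1 training step: frozen target, trainable draft.
# GlideDraftLayer and the embedding "target" are stand-ins, not the repo's API.
import torch
import torch.nn as nn

class GlideDraftLayer(nn.Module):
    # Stand-in for the draft layer: a small projection plus LM head.
    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab)

    def forward(self, hidden_states):
        return self.lm_head(torch.relu(self.proj(hidden_states)))

hidden, vocab, seq = 64, 1000, 16
target = nn.Embedding(vocab, hidden)      # stand-in for the frozen target LLM
for p in target.parameters():
    p.requires_grad_(False)               # target stays frozen in every stage

draft = GlideDraftLayer(hidden, vocab)
opt = torch.optim.AdamW(draft.parameters(), lr=5e-4)  # Stage 1 learning rate

tokens = torch.randint(0, vocab, (2, seq))
with torch.no_grad():                     # no gradients through the target
    states = target(tokens)
logits = draft(states[:, :-1])            # draft predicts the next token
loss = nn.functional.cross_entropy(
    logits.reshape(-1, vocab), tokens[:, 1:].reshape(-1))
loss.backward()
opt.step()
```

Only the draft layer's parameters receive gradients; in later stages the same loop would load the previous stage's checkpoint and lower the learning rate.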
An alternative sinkpi-slicing training variant is available at each stage, providing a different attention pattern for the draft model.
Usage
Use multi-stage training when training GLIDE draft models for long-context or reasoning applications. The stages must be run sequentially—each depends on the previous stage's output checkpoint. The full pipeline is:
- Run Stage 1 → produces draft_model_weights.pth
- Update Stage 2 config to point to Stage 1 output
- Run Stage 2 → produces updated draft_model_weights.pth
- Update Stage 3 config to point to Stage 2 output
- Run Stage 3 → produces final draft_model_weights.pth
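The sequential pipeline above can be sketched as a loop in which each stage consumes the previous stage's checkpoint and emits the next. `run_stage` and the output directory layout are hypothetical stand-ins for the actual training launcher; only the checkpoint filename `draft_model_weights.pth` comes from this page.

```python
# Illustrative sketch of the three-stage pipeline's checkpoint chaining.
# run_stage is a stub: a real stage would launch training here.
import os
import tempfile
from typing import Optional

def run_stage(name: str, init_ckpt: Optional[str], out_dir: str) -> str:
    # Every stage after the first must resume from the previous checkpoint.
    if init_ckpt is not None:
        assert os.path.exists(init_ckpt), f"{name} needs the previous checkpoint"
    path = os.path.join(out_dir, name, "draft_model_weights.pth")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        f.write(b"stub")                  # placeholder for the trained weights
    return path

with tempfile.TemporaryDirectory() as out:
    ckpt = None
    for stage in ("stage1", "stage2", "stage3"):
        ckpt = run_stage(stage, ckpt, out)  # Stage N resumes from Stage N-1
    final = ckpt                            # Stage 3's draft_model_weights.pth
```

Because each stage overwrites nothing and only reads its predecessor's output, the stages cannot be run in parallel or out of order.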
Theoretical Basis
Multi-stage training follows curriculum learning principles:
- Start simple: Short contexts (1024 tokens) with high learning rate allow rapid convergence on basic patterns
- Increase difficulty: Longer contexts (32768 tokens) force the model to learn long-range dependencies, now starting from the good initialization that Stage 1 provides
- Specialize: Task-specific data (Chain-of-Thought) optimizes for the target use case
The key configuration changes across stages:
| Parameter | Stage 1 | Stage 2 | Stage 3 |
|---|---|---|---|
| Context length | 1024 | 32768 | 32768 |
| Learning rate | 5e-4 | 5e-6 | 5e-6 |
| ZeRO stage | 1 | 3 | 3 |
| Data type | General text | Long documents | Chain-of-thought |
| draft_model_name_or_path | None (fresh) | Stage 1 output | Stage 2 output |
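The table above can be captured as data with a small sanity check that the curriculum is monotone: context length never shrinks and the learning rate never rises across stages. The dictionary keys mirror the table columns; all values are taken from this page.

```python
# The stage table as data, plus a curriculum-monotonicity check.
STAGES = {
    1: {"context_length": 1024,  "lr": 5e-4, "zero_stage": 1,
        "data": "general text",     "draft_model_name_or_path": None},
    2: {"context_length": 32768, "lr": 5e-6, "zero_stage": 3,
        "data": "long documents",   "draft_model_name_or_path": "stage1 output"},
    3: {"context_length": 32768, "lr": 5e-6, "zero_stage": 3,
        "data": "chain-of-thought", "draft_model_name_or_path": "stage2 output"},
}

def curriculum_ok(stages: dict) -> bool:
    # Curriculum learning: difficulty (context length) is nondecreasing
    # and the learning rate is nonincreasing from stage to stage.
    ordered = [stages[k] for k in sorted(stages)]
    lengths_grow = all(a["context_length"] <= b["context_length"]
                       for a, b in zip(ordered, ordered[1:]))
    lr_decays = all(a["lr"] >= b["lr"] for a, b in zip(ordered, ordered[1:]))
    return lengths_grow and lr_decays
```

A check like this is a cheap guard against misconfigured stage files, since swapping two stage configs silently breaks the curriculum.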