

Workflow:Sail sg LongSpec GLIDE Draft Model Training

From Leeroopedia
Domains: LLMs, Speculative_Decoding, Distributed_Training
Last Updated: 2025-07-01 00:00 GMT

Overview

End-to-end multi-stage training pipeline for the GLIDE (Global-Local Informed Draft Engine) draft model used in LongSpec speculative decoding, progressing from base training on short sequences through long-context fine-tuning to long chain-of-thought adaptation.

Description

This workflow trains a lightweight draft model that accelerates inference for large language models (Llama, Qwen2) via speculative decoding. The GLIDE draft model uses a single cross-attention layer to access the target LLM's frozen KV cache, combined with sliding-window self-attention for local context. Training proceeds through three progressive stages:

Stage 1 (Base): Teacher-forced next-token prediction on SlimPajama-6B data with short sequences (1024 tokens), using DeepSpeed ZeRO-1. This teaches the draft model to approximate the target LLM's output distribution.

Stage 2 (Long-Context): Fine-tuning on long-context data (up to 32k tokens) with DeepSpeed ZeRO-3, reducing the learning rate and switching to a long-context-aware collator. This bridges the distribution gap between short training data and long-context inference.

Stage 3 (Long-CoT): Further fine-tuning on long chain-of-thought data (e.g., QwQ-LongCoT-130K), teaching the draft model to handle extended reasoning traces typical of mathematical and logical reasoning tasks.

Stages 2 and 3 each initialize from the checkpoint produced by the preceding stage, and every stage is configured through composable Hydra YAML files for reproducibility.
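The teacher-forced next-token objective named in Stage 1 can be sketched as mean cross-entropy with labels shifted by one position. This is a minimal plain-Python sketch, so it runs without PyTorch. It assumes hard token-id labels. The Step 5 description says predictions are scored against the target LLM outputs, and this sketch does not model that. The repository's actual loss may differ.

```python
import math


def next_token_cross_entropy(logits: list[list[float]], token_ids: list[int]) -> float:
    """Mean cross-entropy of predicting token_ids[t + 1] from logits[t].

    logits[t] is assumed to come from a forward pass over the true prefix
    token_ids[: t + 1] (teacher forcing), not over the model's own samples.
    The last row has no next token and is therefore unused.
    """
    if len(logits) != len(token_ids) or len(token_ids) < 2:
        raise ValueError("need one logits row per token and at least two tokens")
    total = 0.0
    for t in range(len(token_ids) - 1):
        row = logits[t]
        peak = max(row)  # subtract the max so exp() cannot overflow
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_norm - row[token_ids[t + 1]]  # -log softmax(row)[next token]
    return total / (len(token_ids) - 1)


if __name__ == "__main__":
    uniform = [[0.0] * 4 for _ in range(3)]
    print(next_token_cross_entropy(uniform, [0, 1, 2]))  # log(4) for a uniform guess
```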

Usage

Execute this workflow when you need to train a new GLIDE draft model for a specific target LLM (Llama or Qwen2 family). This is required when no pre-trained LongSpec weights exist for your target model, or when you want to customize the draft model for a specific domain or context length. Prerequisites include access to multi-GPU infrastructure (8x A100 GPUs recommended), the target model weights, and appropriate training datasets for each stage.

Execution Steps

Step 1: Environment_Setup

Install all required dependencies including PyTorch (>=2.6.0), DeepSpeed, Flash Attention, Hydra, wandb, and fairscale. Authenticate with Weights and Biases for experiment tracking. Ensure 8 GPUs are available and NCCL is properly configured for distributed communication.

Key considerations:

  • Python >= 3.12 is required
  • Flash Attention 2 must be installed for efficient attention computation
  • DeepSpeed and fairscale provide the distributed training infrastructure
  • wandb login must be completed before launching training
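A small preflight script can catch missing pieces before a multi-GPU launch. This sketch only checks two things: the Python version, and whether the packages from Step 1 can be imported. The import names (flash_attn for Flash Attention, hydra for Hydra) are the usual ones but are assumptions here. The script does not verify the wandb login, the GPU count, or NCCL.

```python
import importlib.util
import sys

# Assumed import names for the packages listed in Step 1 (pip names can differ,
# e.g. the hydra-core package provides the "hydra" module).
REQUIRED_MODULES = ["torch", "deepspeed", "flash_attn", "hydra", "wandb", "fairscale"]
MIN_PYTHON = (3, 12)


def python_ok(version_info: tuple = sys.version_info, minimum: tuple = MIN_PYTHON) -> bool:
    """True when the interpreter's major.minor version meets the minimum."""
    return tuple(version_info[:2]) >= minimum


def missing_modules(modules: list[str] = REQUIRED_MODULES) -> list[str]:
    """Return the top-level modules that cannot be found on the import path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python version OK:", python_ok())
    print("Missing modules:", missing_modules() or "none")
```

Once PyTorch is installed, torch.cuda.device_count() and torch.distributed.is_nccl_available() are quick ways to confirm the GPU and NCCL requirements.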

Step 2: Data_Preparation

Prepare training datasets in JSONL format for each stage. Stage 1 uses SlimPajama-6B general text data. Stage 2 uses long-context data (e.g., long-data/train_data_v2.jsonl). Stage 3 uses long chain-of-thought data (e.g., QwQ-LongCoT-130K). Each entry follows a source/target structure that the data collator transforms into model-ready format using configurable aligners and templates.

Key considerations:

  • Each stage requires its own dataset tailored to the training objective
  • Data paths must be configured in the corresponding YAML experiment config
  • The input aligner pipeline handles schema normalization (e.g., add_id_aligner)
  • The template system composes chat-format prompts from raw fields
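The data path in this step runs from JSONL records through aligners to a template. It can be sketched in plain Python. Everything here is hypothetical and stands in for the repository's actual aligners and templates: the field names source and target, the behavior of this add_id_aligner, and the chat markers in CHAT_TEMPLATE.

```python
import json
from typing import Callable, Iterable

Record = dict
Aligner = Callable[[list[Record]], list[Record]]

# Hypothetical chat template; real templates and special tokens may differ.
CHAT_TEMPLATE = "<|user|>\n{source}\n<|assistant|>\n{target}"


def load_jsonl(lines: Iterable[str]) -> list[Record]:
    """Parse one JSON object per non-empty line."""
    return [json.loads(line) for line in lines if line.strip()]


def add_id_aligner(records: list[Record]) -> list[Record]:
    """Hypothetical aligner: give each record an integer "id" if it has none."""
    return [{**r, "id": r.get("id", i)} for i, r in enumerate(records)]


def apply_aligners(records: list[Record], aligners: Iterable[Aligner]) -> list[Record]:
    """Run schema-normalizing aligners in order, each consuming the previous output."""
    for aligner in aligners:
        records = aligner(records)
    return records


def render(record: Record, template: str = CHAT_TEMPLATE) -> str:
    """Compose the training text from the raw source/target fields."""
    return template.format(source=record["source"], target=record["target"])


if __name__ == "__main__":
    raw = ['{"source": "What is 2+2?", "target": "4"}', "", '{"source": "Hi", "target": "Hello!", "id": 7}']
    records = apply_aligners(load_jsonl(raw), [add_id_aligner])
    print([r["id"] for r in records])  # [0, 7]
    print(render(records[0]))
```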

Step 3: Configuration_Selection

Select and customize the Hydra YAML configuration for the target training stage. Each stage uses a different experiment config that specifies the DeepSpeed strategy (ZeRO stage), data collator, learning rate, sequence length, and checkpoint path. Stage 1 uses ZeRO-1 with optimizer offload and 1024-token sequences. Stage 2 switches to ZeRO-3 with 32k-token sequences. Stage 3 continues with ZeRO-3 and a chain-of-thought-specific collator.

Key considerations:

  • Stage 1 config: qwq_glide_8gpu_slim6b.yaml (ZeRO-1, lr=5e-4, seq_len=1024)
  • Stage 2 config: qwq_glide_8gpu_slim6b_longv2-32k-zero3_5e-6-ligce-nomask.yaml (ZeRO-3, lr=5e-6, seq_len=32768)
  • Stage 3 config: the Stage 2 config name with a _longcot_5e-6 suffix added (ZeRO-3, lr=5e-6, seq_len=32768, LongCoT collator)
  • A sinkpi-slicing variant is available at each stage as an alternative training strategy
  • model_name_or_path must point to the previous stage checkpoint for stages 2 and 3
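One way to keep the three stage configs consistent is to record their key values in one place and check them. The ZeRO stage, learning rate, and sequence length below are transcribed from the bullets above. The following are hypothetical: the output_dir values, the <target-llm> placeholder, and the validate rules. The rules only encode the pattern this step describes: the learning rate never increases, and stages 2 and 3 start from the previous stage's output.

```python
from dataclasses import dataclass
from pathlib import PurePosixPath


@dataclass(frozen=True)
class StageConfig:
    name: str
    zero_stage: int
    lr: float
    seq_len: int
    model_name_or_path: str
    output_dir: str  # hypothetical output location for this stage


STAGES = [
    StageConfig("base", 1, 5e-4, 1024, "<target-llm>", "outputs/stage1"),
    StageConfig("long-context", 3, 5e-6, 32768, "outputs/stage1/checkpoint-last", "outputs/stage2"),
    StageConfig("long-cot", 3, 5e-6, 32768, "outputs/stage2/checkpoint-last", "outputs/stage3"),
]


def validate(stages: list[StageConfig]) -> list[str]:
    """Return a list of problems; an empty list means the schedule looks consistent."""
    problems = []
    for prev, cur in zip(stages, stages[1:]):
        # Stages 2 and 3 should start from a checkpoint inside the previous stage's output.
        if not PurePosixPath(cur.model_name_or_path).is_relative_to(prev.output_dir):
            problems.append(f"{cur.name}: model_name_or_path should be under {prev.output_dir}")
        # Later stages fine-tune, so the learning rate should not go up.
        if cur.lr > prev.lr:
            problems.append(f"{cur.name}: lr {cur.lr} exceeds {prev.name} lr {prev.lr}")
    return problems


if __name__ == "__main__":
    print(validate(STAGES) or "schedule OK")
```

The path check compares path components rather than raw strings, so outputs/stage10 is not mistaken for a child of outputs/stage1.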

Step 4: Model_Initialization

Load the target LLM (e.g., QwQ-32B-Preview) with its GLIDE draft model components. The model is instantiated via Hydra using Qwen2Glide.from_pretrained (or LlamaGlide.from_pretrained for Llama models), with Flash Attention 2 enabled and bfloat16 precision. For stages 2 and 3, the draft model weights are loaded from the previous stage's checkpoint. The target LLM parameters remain frozen, and only the draft model parameters are trained.
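The freezing rule in this step amounts to one condition: a parameter is trainable only if it belongs to the draft model. This sketch assumes that draft parameters carry a draft_model. name prefix, which the extraction in Step 6 implies. The other parameter names and sizes are made up for illustration. The sketch works on plain names so it runs without PyTorch, and a comment shows the equivalent requires_grad loop.

```python
DRAFT_PREFIX = "draft_model."


def is_trainable(param_name: str, prefix: str = DRAFT_PREFIX) -> bool:
    """Only draft-model parameters receive gradients; the target LLM stays frozen."""
    return param_name.startswith(prefix)


def trainable_fraction(param_sizes: dict[str, int]) -> float:
    """Share of all parameter elements that would be updated during training."""
    total = sum(param_sizes.values())
    trained = sum(n for name, n in param_sizes.items() if is_trainable(name))
    return trained / total


# With a real torch.nn.Module the same rule would be applied as:
#     for name, param in model.named_parameters():
#         param.requires_grad = is_trainable(name)

if __name__ == "__main__":
    toy_sizes = {  # hypothetical names and element counts
        "model.layers.0.self_attn.q_proj.weight": 9_000,
        "model.layers.0.mlp.up_proj.weight": 36_000,
        "draft_model.cross_attn.q_proj.weight": 5_000,
    }
    print(f"{trainable_fraction(toy_sizes):.0%} of elements are trainable")  # 10%
```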

Key considerations:

  • The GLIDE model wraps the target LLM and adds cross-attention and self-attention layers as the draft model
  • ignore_mismatched_sizes is set to True to handle the draft model's additional parameters
  • Optional: a separate draft_model_name_or_path can load pre-existing draft weights
  • The draft model is much smaller than the target (single cross-attention layer)
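On the local-context side, the draft uses the sliding-window self-attention mentioned in the Description. That attention limits each position to a fixed number of recent tokens. The sketch below builds such a causal sliding-window mask. The window size is an arbitrary illustration, not LongSpec's setting. The cross-attention into the target LLM's KV cache is not modeled.

```python
def sliding_window_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """Return mask[i][j] == True when query position i may attend key position j.

    Each position sees itself and the previous window - 1 positions, and never
    a future position (causal).
    """
    if window < 1:
        raise ValueError("window must be >= 1")
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


if __name__ == "__main__":
    for row in sliding_window_causal_mask(seq_len=5, window=2):
        print("".join("x" if allowed else "." for allowed in row))
    # x....
    # xx...
    # .xx..
    # ..xx.
    # ...xx
```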

Step 5: Distributed_Training_Launch

Launch multi-GPU training using DeepSpeed with the selected configuration. The trainer initializes distributed communication via NCCL, sets up data-parallel and optional tensor-parallel training, creates DataLoaders with distributed samplers, and begins the training loop. Each forward pass computes the GLIDE training loss (teacher-forced next-token prediction against the target LLM outputs), followed by backward pass and optimizer step managed by DeepSpeed.

Key considerations:

  • Launch command: deepspeed --include localhost:0,1,2,3,4,5,6,7 ./trainer_base_ds_mul_fs_tp.py -cp conf/exp/ -cn [config_name]
  • Training logs and metrics are tracked via wandb
  • Checkpoints are saved at regular intervals (default: every 200 steps)
  • The training loop supports resume from latest checkpoint
  • Gradient accumulation is used to reach the effective batch size (128 for Stage 1; 8 to 32 for later stages)
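The batch-size and checkpoint bookkeeping above reduces to simple arithmetic. This sketch reads the 128 figure as the global effective batch size and assumes a per-GPU micro-batch of 1; neither assumption is stated in the configs. It also rebuilds the documented launch command so the config name can be templated. The helper names are hypothetical.

```python
import shlex


def grad_accum_steps(effective_batch: int, micro_batch_per_gpu: int, num_gpus: int) -> int:
    """Accumulation steps such that micro_batch * gpus * steps == effective_batch."""
    per_step = micro_batch_per_gpu * num_gpus
    if effective_batch % per_step:
        raise ValueError(f"{effective_batch} is not divisible by {per_step}")
    return effective_batch // per_step


def should_save(step: int, interval: int = 200) -> bool:
    """True on the steps where a periodic checkpoint would be written."""
    return step > 0 and step % interval == 0


def launch_command(config_name: str, gpus: range = range(8)) -> str:
    """Rebuild the documented DeepSpeed launch line for a given Hydra config name."""
    include = "localhost:" + ",".join(str(g) for g in gpus)
    argv = ["deepspeed", "--include", include, "./trainer_base_ds_mul_fs_tp.py",
            "-cp", "conf/exp/", "-cn", config_name]
    return shlex.join(argv)


if __name__ == "__main__":
    print(grad_accum_steps(128, micro_batch_per_gpu=1, num_gpus=8))  # 16
    print([s for s in range(1, 1001) if should_save(s)])  # [200, 400, 600, 800, 1000]
    # Config name given without the .yaml extension, as is usual for Hydra.
    print(launch_command("qwq_glide_8gpu_slim6b"))
```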

Step 6: Checkpoint_Extraction

After training completes, extract the draft model weights from the combined model checkpoint. The save_model function filters state dict keys with the "draft_model." prefix, strips the prefix, and saves the extracted weights as draft_model_weights.pth. For ZeRO-3 training stages, the consolidated 16-bit state dict is first gathered across all ranks before extraction.

Key considerations:

  • Only draft_model.* parameters are extracted; target LLM weights are discarded
  • ZeRO-3 requires special consolidation via _zero3_consolidated_16bit_state_dict()
  • The tokenizer and training config are also saved alongside the weights
  • Extracted weights can be uploaded to Hugging Face Hub for distribution
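The extraction described in this step is a key filter over the state dict. This sketch uses plain Python values in place of tensors so it runs anywhere. The function name is hypothetical and is not the repository's save_model, and the key names in the example are made up. In a real run, the input would come from the model's state_dict(), or under ZeRO-3 from the DeepSpeed engine's _zero3_consolidated_16bit_state_dict(). The result would then be written with torch.save.

```python
from typing import Any, Mapping

DRAFT_PREFIX = "draft_model."


def extract_draft_state_dict(full_state: Mapping[str, Any], prefix: str = DRAFT_PREFIX) -> dict[str, Any]:
    """Keep only draft-model entries and strip the prefix from their keys.

    Target-LLM weights are dropped, so the result can be loaded directly into a
    standalone draft module.
    """
    return {
        key.removeprefix(prefix): value
        for key, value in full_state.items()
        if key.startswith(prefix)
    }


# Typical use with PyTorch/DeepSpeed (not executed here):
#     full = engine._zero3_consolidated_16bit_state_dict()  # call on every rank under ZeRO-3
#     if full is not None:                                  # the full dict is returned on rank 0 only
#         torch.save(extract_draft_state_dict(full), "draft_model_weights.pth")

if __name__ == "__main__":
    toy = {"model.embed_tokens.weight": "target", "draft_model.cross_attn.o_proj.weight": "draft"}
    print(extract_draft_state_dict(toy))  # {'cross_attn.o_proj.weight': 'draft'}
```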

Step 7: Stage_Progression

After completing each stage, update the next stage's configuration to point to the latest checkpoint directory. The model_name_or_path in the Stage 2 config references the Stage 1 output (e.g., checkpoint-last/). Similarly, Stage 3 references the Stage 2 output. Then repeat Steps 3-6 for each subsequent stage until all three stages are complete.

Key considerations:

  • Each stage builds on the previous stage's draft model weights
  • Stage progression: Base (SlimPajama) -> Long-Context (32k data) -> Long-CoT (reasoning data)
  • The sinkpi-slicing variant follows its own three-stage progression path
  • Final checkpoint is ready for use in speculative decoding inference
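Pointing the next stage at the latest checkpoint, or resuming an interrupted run as noted in Step 5, requires a reliable way to find that directory. This sketch prefers checkpoint-last, the name used in the Stage 2 example above. Otherwise it picks the highest-numbered checkpoint-<step> directory. That numbered naming scheme is an assumption and is not confirmed by this page. Step numbers are compared as integers, which avoids the string-sorting trap where checkpoint-400 would sort after checkpoint-1000.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional, Union

# Assumed naming for periodic checkpoints, e.g. "checkpoint-200".
_NUMBERED = re.compile(r"checkpoint-(\d+)")


def latest_checkpoint(output_dir: Union[str, Path]) -> Optional[Path]:
    """Return checkpoint-last if present, else the highest-step checkpoint, else None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    last = root / "checkpoint-last"
    if last.is_dir():
        return last
    numbered = [
        (int(m.group(1)), child)
        for child in root.iterdir()
        if child.is_dir() and (m := _NUMBERED.fullmatch(child.name))
    ]
    return max(numbered)[1] if numbered else None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("checkpoint-200", "checkpoint-400", "checkpoint-1000"):
            (Path(tmp) / name).mkdir()
        print(latest_checkpoint(tmp).name)  # checkpoint-1000
```

The returned path is the value that would go into the next stage's model_name_or_path.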
