Implementation:Sail sg LongSpec DeepSpeed Train Loop

From Leeroopedia
Knowledge Sources
Domains Distributed_Computing, Training
Last Updated 2026-02-14 05:00 GMT

Overview

A concrete tool for training GLIDE draft models under DeepSpeed-managed distributed training, combining ZeRO optimization, gradient accumulation, and multi-file dataset iteration.

Description

The train() function in trainer_base_ds_mul_fs_tp.py implements the complete training loop for GLIDE draft models. It wraps the model in a DeepSpeed engine, sets up distributed data loading, and iterates through training batches with gradient accumulation. The function also handles checkpointing, evaluation, and logging.

The companion forward_step() function performs a single forward pass, returning the scalar loss value and the model outputs.
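A minimal sketch of what forward_step() plausibly does, assuming a HuggingFace-style model whose forward() returns an output object with a scalar .loss attribute; the actual implementation lives in trainer_base_ds_mul_fs_tp.py and may differ in detail.

```python
# Minimal sketch of forward_step (assumption: HuggingFace-style model
# whose forward() returns an object exposing a scalar .loss tensor).
def forward_step(model, inputs):
    outputs = model(**inputs)       # single forward pass on one micro-batch
    loss = outputs.loss             # scalar loss tensor
    return loss.item(), outputs     # Python float for logging, plus outputs
```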

Usage

Called from the main() entry point after model initialization, and typically invoked indirectly via the DeepSpeed launcher rather than called directly by users.

Code Reference

Source Location

  • Repository: LongSpec
  • File: longspec/train/trainer_base_ds_mul_fs_tp.py
  • Lines: L116-334 (forward_step + train)

Signature

def forward_step(
    model,
    inputs: Dict[str, torch.Tensor],
) -> Tuple[float, Any]:
    """
    Single forward pass returning loss scalar and outputs.

    Args:
        model: DeepSpeed-wrapped model
        inputs: Dictionary of input tensors (input_ids, labels, attention_mask, etc.)

    Returns:
        Tuple of (loss.item(), model_outputs)
    """

def train(
    cfg: DictConfig,
    model,
    tokenizer: PreTrainedTokenizer,
    continue_from_global_step: int = 0,
) -> Tuple[int, float]:
    """
    Complete training loop with DeepSpeed.

    Args:
        cfg: Hydra DictConfig with all training parameters
        model: Qwen2Glide or LlamaGlide model instance
        tokenizer: HuggingFace tokenizer
        continue_from_global_step: Resume from this step (default 0)

    Returns:
        Tuple of (global_step, average_loss)
    """

Import

# The training functions are in the main trainer module:
from longspec.train.trainer_base_ds_mul_fs_tp import train, forward_step

# Key dependencies:
import deepspeed
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer

I/O Contract

Inputs

  • cfg (DictConfig, required): Full Hydra configuration with training hyperparameters, data paths, and the DeepSpeed config
  • model (Qwen2Glide / LlamaGlide, required): Initialized GLIDE model (target frozen, draft trainable)
  • tokenizer (PreTrainedTokenizer, required): Tokenizer saved alongside checkpoints
  • continue_from_global_step (int, optional): Step to resume from (default: 0)

Outputs

  • global_step (int): Final training step number
  • average_loss (float): Average training loss across all steps
  • checkpoints (files, side effect): Model checkpoints saved to cfg.output_dir at cfg.save_steps intervals
  • logs (WandB, side effect): Training metrics logged to Weights & Biases

Usage Examples

Standard Training Launch

# Launch via DeepSpeed on 8 GPUs (from train.sh):
deepspeed --num_gpus=8 trainer_base_ds_mul_fs_tp.py +exp=qwq_glide_8gpu_slim6b

Internal Call Flow

# In main() function (trainer_base_ds_mul_fs_tp.py:L337-463):
def main(cfg: DictConfig):
    # 1. Set random seed
    set_seed(cfg)

    # 2. Initialize distributed environment
    vanilla_torch_dist()

    # 3. Initialize model via Hydra
    model = hydra.utils.call(cfg.model, cfg.model_name_or_path)

    # 4. Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path)

    # 5. Run training
    global_step, avg_loss = train(cfg, model, tokenizer)

Key Configuration Parameters

# From qwq_glide_8gpu_slim6b.yaml:
per_gpu_train_batch_size: 2
gradient_accumulation_steps: 128
learning_rate: 5e-4
num_train_epochs: 1
warmup_proportion: 0.1
save_steps: 100
eval_steps: -1  # No evaluation during training
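Under the 8-GPU launch shown above, these values imply an effective global batch size of per_gpu_train_batch_size × gradient_accumulation_steps × num_gpus. The arithmetic below is a sanity check; num_gpus = 8 is assumed from the launch command, not read from the YAML.

```python
# Effective global batch size implied by the config above
# (num_gpus = 8 is taken from the launch command, not the YAML).
per_gpu_train_batch_size = 2
gradient_accumulation_steps = 128
num_gpus = 8

global_batch_size = per_gpu_train_batch_size * gradient_accumulation_steps * num_gpus
print(global_batch_size)  # 2048
```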

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
