Implementation:Sail sg LongSpec DeepSpeed Train Loop
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Concrete tool for training GLIDE draft models using DeepSpeed-managed distributed training with ZeRO optimization, gradient accumulation, and multi-file dataset iteration.
Description
The train() function in trainer_base_ds_mul_fs_tp.py implements the complete training loop for GLIDE draft models. It wraps the model in a DeepSpeed engine, sets up distributed data loading, and iterates through training batches with gradient accumulation. The function also handles checkpointing, evaluation, and logging.
The companion forward_step() function performs a single forward-backward pass, returning the scalar loss and model outputs.
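A minimal sketch of what a forward_step-style helper looks like, assuming an HF-style model whose forward returns an object with a `.loss` attribute. Note the description above mentions a forward-backward pass while the docstring below says "single forward pass", so the `model.backward(loss)` call shown here may instead live in the outer loop; treat this as an illustration, not the repository's code.

```python
# Hedged sketch of forward_step(); the real implementation is in
# trainer_base_ds_mul_fs_tp.py and may differ in detail.
from typing import Any, Dict, Tuple

import torch


def forward_step(model, inputs: Dict[str, torch.Tensor]) -> Tuple[float, Any]:
    outputs = model(**inputs)   # forward pass; labels in `inputs` produce .loss
    loss = outputs.loss
    model.backward(loss)        # DeepSpeed engines expose backward() directly
    return loss.item(), outputs
```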
Usage
Called from the main() entry point after model initialization. Typically invoked indirectly via the DeepSpeed launcher rather than called directly by users.

Code Reference
Source Location
- Repository: LongSpec
- File: longspec/train/trainer_base_ds_mul_fs_tp.py
- Lines: L116-334 (forward_step + train)
Signature
def forward_step(
    model,
    inputs: Dict[str, torch.Tensor],
) -> Tuple[float, Any]:
    """
    Single forward pass returning loss scalar and outputs.

    Args:
        model: DeepSpeed-wrapped model
        inputs: Dictionary of input tensors (input_ids, labels, attention_mask, etc.)

    Returns:
        Tuple of (loss.item(), model_outputs)
    """

def train(
    cfg: DictConfig,
    model,
    tokenizer: PreTrainedTokenizer,
    continue_from_global_step: int = 0,
) -> Tuple[int, float]:
    """
    Complete training loop with DeepSpeed.

    Args:
        cfg: Hydra DictConfig with all training parameters
        model: Qwen2Glide or LlamaGlide model instance
        tokenizer: HuggingFace tokenizer
        continue_from_global_step: Resume from this step (default 0)

    Returns:
        Tuple of (global_step, average_loss)
    """
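The control flow of train() can be sketched as a dependency-free outline. Everything below is illustrative: `train_loop`, `save_fn`, and the duck-typed `engine` stand in for the real DeepSpeed engine and `save_checkpoint()` calls, and whether `global_step` counts micro-batches or optimizer steps depends on the actual implementation.

```python
# Minimal sketch of the train() loop shape; the real function also wires in
# deepspeed.initialize(), a distributed DataLoader, WandB logging, and
# evaluation, none of which are reproduced here.
def train_loop(engine, batches, save_steps, save_fn, start_step=0):
    """engine must expose __call__, backward(), and step(), mirroring a
    DeepSpeed engine; save_fn stands in for engine.save_checkpoint()."""
    global_step, total_loss = start_step, 0.0
    for inputs in batches:
        loss = engine(inputs)          # forward pass returning a scalar loss
        engine.backward(loss)          # scaled backward; accumulation handled inside
        engine.step()                  # optimizer steps only at accumulation boundary
        total_loss += float(loss)
        global_step += 1
        if save_steps > 0 and global_step % save_steps == 0:
            save_fn(global_step)       # checkpoint hook
    steps_run = max(global_step - start_step, 1)
    return global_step, total_loss / steps_run
```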
Import
# The training functions are in the main trainer module:
from longspec.train.trainer_base_ds_mul_fs_tp import train, forward_step
# Key dependencies:
import deepspeed
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Full Hydra configuration with training hyperparameters, data paths, DeepSpeed config |
| model | Qwen2Glide / LlamaGlide | Yes | Initialized GLIDE model (target frozen, draft trainable) |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer for saving with checkpoints |
| continue_from_global_step | int | No | Step to resume from (default: 0) |
Outputs
| Name | Type | Description |
|---|---|---|
| global_step | int | Final training step number |
| average_loss | float | Average training loss across all steps |
| checkpoints | Files (side effect) | Model checkpoints saved to cfg.output_dir at cfg.save_steps intervals |
| logs | WandB (side effect) | Training metrics logged to Weights & Biases |
Usage Examples
Standard Training Launch
# Launch via DeepSpeed on 8 GPUs (from train.sh):
deepspeed --num_gpus=8 trainer_base_ds_mul_fs_tp.py +exp=qwq_glide_8gpu_slim6b
Internal Call Flow
# In main() function (trainer_base_ds_mul_fs_tp.py:L337-463):
def main(cfg: DictConfig):
    # 1. Set random seed
    set_seed(cfg)
    # 2. Initialize distributed environment
    vanilla_torch_dist()
    # 3. Initialize model via Hydra
    model = hydra.utils.call(cfg.model, cfg.model_name_or_path)
    # 4. Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path)
    # 5. Run training
    global_step, avg_loss = train(cfg, model, tokenizer)
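A sketch of what a vanilla_torch_dist()-style helper typically does: read the env vars that the DeepSpeed/torchrun launchers export, bind the process to its local GPU, and join the NCCL process group. This is a common pattern, not the repository's verified code; `launcher_env` is an illustrative helper name.

```python
# Hedged sketch of distributed initialization; the actual helper may differ
# in backend choice or env handling.
import os
from typing import Mapping, Tuple

import torch
import torch.distributed as dist


def launcher_env(environ: Mapping[str, str]) -> Tuple[int, int, int]:
    """Read (RANK, WORLD_SIZE, LOCAL_RANK) as set by deepspeed/torchrun."""
    return (int(environ.get("RANK", 0)),
            int(environ.get("WORLD_SIZE", 1)),
            int(environ.get("LOCAL_RANK", 0)))


def vanilla_torch_dist() -> int:
    """Bind this process to its local GPU and join the NCCL process group."""
    rank, world_size, local_rank = launcher_env(os.environ)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return local_rank
```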
Key Configuration Parameters
# From qwq_glide_8gpu_slim6b.yaml:
per_gpu_train_batch_size: 2
gradient_accumulation_steps: 128
learning_rate: 5e-4
num_train_epochs: 1
warmup_proportion: 0.1
save_steps: 100
eval_steps: -1 # No evaluation during training
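With the 8-GPU launch shown under Usage Examples, these values imply a global batch of 2 × 128 × 8 = 2048 sequences per optimizer step. A quick check:

```python
# Effective global batch size implied by the config above, assuming the
# 8-GPU launch from the Usage Examples section.
per_gpu_train_batch_size = 2
gradient_accumulation_steps = 128
num_gpus = 8

effective_batch_size = (per_gpu_train_batch_size
                        * gradient_accumulation_steps
                        * num_gpus)
print(effective_batch_size)  # 2048
```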