Principle:Microsoft DeepSpeedExamples DeepSpeed Engine Init

From Leeroopedia


Metadata

Field                  | Value
Page Type              | Principle
Repository             | Microsoft/DeepSpeedExamples
Title                  | DeepSpeed_Engine_Init
Sources                | Paper: ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, Doc: DeepSpeed Getting Started
Domains                | Distributed_Training, Deep_Learning
Related Implementation | Implementation:Microsoft_DeepSpeedExamples_DeepSpeed_Initialize_CIFAR

Overview

A technique for wrapping a PyTorch model with DeepSpeed's distributed training engine that manages optimization, mixed precision, and communication.

Description

deepspeed.initialize() is the central API call in any DeepSpeed migration. It replaces PyTorch's manual optimizer and scheduler setup, DataLoader creation, and distributed data parallel wrapping with a single call. The function accepts a raw PyTorch nn.Module and returns a tuple whose first element is a DeepSpeedEngine that transparently handles:

  • Distributed Data Parallelism -- Automatically wraps the model for multi-GPU/multi-node training with gradient synchronization across ranks
  • ZeRO Optimization -- Partitions optimizer states (Stage 1), gradients (Stage 2), and/or parameters (Stage 3) across data-parallel ranks to reduce per-GPU memory
  • Mixed Precision Training -- Manages FP16 or BF16 forward/backward passes with automatic loss scaling (for FP16) and master weight maintenance
  • Gradient Accumulation -- Handles micro-batch gradient accumulation when the effective batch size exceeds the micro-batch size
  • Learning Rate Scheduling -- Integrates the scheduler so that model_engine.step() handles both parameter updates and LR stepping
  • Distributed Data Loading -- Creates a distributed-aware DataLoader with proper sampling for multi-GPU training

The Initialize Call

A typical call to deepspeed.initialize() looks like this (optional arguments such as optimizer and lr_scheduler are omitted):

model_engine, optimizer, dataloader, lr_scheduler = deepspeed.initialize(
    args=args,              # CLI arguments (includes --local_rank, --deepspeed, etc.)
    model=model,            # Raw PyTorch nn.Module
    model_parameters=params, # Parameters to optimize (filter for requires_grad)
    training_data=dataset,  # PyTorch Dataset (DeepSpeed creates the DataLoader)
    config=ds_config,       # DeepSpeed JSON config dict or path to JSON file
)

What Gets Replaced

Standard PyTorch                     | DeepSpeed Equivalent        | Handled By
optimizer = optim.SGD(...)           | Created internally          | deepspeed.initialize()
scheduler = lr_scheduler.StepLR(...) | Created internally          | deepspeed.initialize()
DataLoader(dataset, ...)             | Returned from initialize()  | deepspeed.initialize()
model = DDP(model)                   | Returned model_engine       | deepspeed.initialize()
optimizer.zero_grad()                | Handled internally          | model_engine.step()
loss.backward()                      | model_engine.backward(loss) | Gradient scaling, communication
optimizer.step()                     | model_engine.step()         | Parameter update, LR step, ZeRO comm

Theoretical Basis

DeepSpeed Engine Composition

The DeepSpeed Engine is a composite object that unifies multiple training components:

DeepSpeedEngine = Model + Optimizer + Scheduler + Communication Backend

Engine.forward(x)        -->  Model.forward(x)   [with mixed precision casting]
Engine.backward(loss)     -->  loss.backward()    [with gradient scaling + ZeRO comm]
Engine.step()             -->  optimizer.step()   [with LR scheduling + ZeRO sync]
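
The delegation above can be illustrated with a toy, framework-free sketch. ToyEngine and all of its attributes are hypothetical names invented for this illustration, not DeepSpeed code; the sketch assumes a one-weight model y = w * x and only mimics two behaviours described on this page: gradients are averaged over the accumulation window in backward(), and step() applies an update (and clears gradients) only at an accumulation boundary. Unlike the real model_engine.backward(loss), the toy backward() takes the input and target directly.

```python
class ToyEngine:
    """Hypothetical stand-in for an engine wrapping a one-weight model y = w * x."""

    def __init__(self, weight, lr, accumulation_steps):
        self.weight = weight
        self.lr = lr
        self.accumulation_steps = accumulation_steps
        self.grad = 0.0        # accumulated gradient for the current window
        self.micro_steps = 0   # number of step() calls so far
        self.updates = 0       # number of real parameter updates

    def forward(self, x):
        return self.weight * x

    def backward(self, x, target):
        # Gradient of 0.5 * (w*x - target)**2 with respect to w, divided by the
        # window length so the accumulated value is an average over micro-batches.
        error = self.forward(x) - target
        self.grad += error * x / self.accumulation_steps

    def step(self):
        # Only update at an accumulation boundary; otherwise just count.
        self.micro_steps += 1
        if self.micro_steps % self.accumulation_steps != 0:
            return False
        self.weight -= self.lr * self.grad
        self.grad = 0.0        # plays the role of optimizer.zero_grad()
        self.updates += 1
        return True


engine = ToyEngine(weight=0.0, lr=0.1, accumulation_steps=2)
for x, target in [(1.0, 2.0)] * 4:
    engine.backward(x, target)
    engine.step()
```

With four micro-batches and an accumulation window of two, the toy engine performs two parameter updates, which is why the caller can invoke step() after every micro-batch without tracking the boundary itself.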

ZeRO Optimization Stages

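The engine's behavior changes significantly based on the ZeRO stage. The memory savings listed are the approximate upper bounds reported in the ZeRO paper for model states under mixed-precision Adam with a large number of data-parallel ranks: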

Stage   | Partitions                                | Memory Savings           | Communication Overhead
Stage 0 | None (standard DDP)                       | 1x                       | AllReduce
Stage 1 | Optimizer states                          | ~4x                      | AllReduce + AllGather
Stage 2 | Optimizer states + Gradients              | ~8x                      | ReduceScatter + AllGather
Stage 3 | Optimizer states + Gradients + Parameters | ~Nx (linear with N GPUs) | AllGather for forward/backward + ReduceScatter
Mixed Precision Management

When FP16 is enabled, the engine:

  1. Maintains FP32 master weights in the optimizer
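  2. Runs forward/backward with FP16 model weights; inputs must match that dtype, and in the CIFAR-10 example the training loop casts them itself after querying the engine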
  3. Applies dynamic loss scaling to prevent gradient underflow
  4. Accumulates gradients in FP16, converts to FP32 for optimizer step

When BF16 is enabled, the engine:

  1. Uses BF16 for forward/backward (no loss scaling needed due to larger exponent range)
  2. Maintains FP32 master weights for optimizer updates
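
The dynamic loss scaling used in the FP16 path can be sketched without any framework. This is a simplified illustration, not DeepSpeed's scaler: the class ToyDynamicLossScaler and its parameter names are hypothetical, and it omits refinements such as the hysteresis setting exposed in DeepSpeed's fp16 config. The idea it shows is the usual one: on an overflow (inf/NaN gradient) the step is skipped and the scale is reduced; after a window of clean steps the scale is raised again.

```python
import math


class ToyDynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler (no hysteresis)."""

    def __init__(self, init_scale=2.0 ** 16, scale_factor=2.0,
                 scale_window=1000, min_scale=1.0):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.good_steps = 0   # consecutive overflow-free steps

    def update(self, grads):
        """Inspect gradients; return True if the optimizer step should run."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            # Skip this step and back off the scale, but never below min_scale.
            self.scale = max(self.scale / self.scale_factor, self.min_scale)
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps % self.scale_window == 0:
            # A full window without overflow: try a larger scale again.
            self.scale *= self.scale_factor
        return True


scaler = ToyDynamicLossScaler(init_scale=8.0, scale_window=2)
applied = [scaler.update(g) for g in ([float("inf")], [0.1], [0.2])]
```

In the usage lines, the first step overflows and is skipped while the scale is halved; the next two clean steps complete a window and the scale doubles back. BF16 does not need this mechanism because its exponent range matches FP32.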

Distributed Data Loading

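When training_data is provided, deepspeed.initialize() creates a DataLoader with a DistributedSampler that:

  • Splits the training data evenly across all data-parallel ranks
  • Assigns each sample to a single rank per epoch (a few samples may be repeated as padding when the dataset size is not divisible by the number of ranks)
  • Supports deterministic shuffling with epoch-based seeding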
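
The partitioning behaviour described above can be mimicked in plain Python. The function shard_indices is a hypothetical helper written for this illustration; it follows the general scheme of a distributed sampler (shuffle with a seed derived from the epoch, pad so every rank gets the same count, then take a strided slice per rank) but uses Python's random module, so the actual orderings differ from PyTorch's. It assumes the dataset has at least as many samples as there are ranks.

```python
import random


def shard_indices(dataset_len, num_replicas, rank, epoch, seed=0, shuffle=True):
    """Return the sample indices a given rank would read in a given epoch."""
    assert dataset_len >= num_replicas, "sketch assumes at least one sample per rank"
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed + epoch on every rank gives every rank the same permutation.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating leading indices so the total divides evenly.
    per_rank = -(-dataset_len // num_replicas)   # ceiling division
    total = per_rank * num_replicas
    indices += indices[: total - dataset_len]
    # Strided slice: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total:num_replicas]


shards = [shard_indices(10, 4, rank, epoch=0) for rank in range(4)]
```

With 10 samples and 4 ranks, each rank receives 3 indices and two samples appear twice as padding. Changing the epoch argument changes the seed, which is the reason samplers of this kind need to be told the current epoch to reshuffle between epochs.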

Configuration Structure

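The config passed to deepspeed.initialize() follows the structure below. It is shown as a Python dict, hence True/False and # comments; an equivalent JSON file would use true/false and no comments: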

{
    "train_batch_size": 16,           # Global batch size across all GPUs
    "steps_per_print": 2000,          # Logging interval
    "optimizer": {
        "type": "Adam",               # Optimizer class name
        "params": { ... }             # Optimizer hyperparameters
    },
    "scheduler": {
        "type": "WarmupLR",           # Scheduler class name
        "params": { ... }             # Scheduler hyperparameters
    },
    "gradient_clipping": 1.0,         # Max gradient norm
    "fp16": { "enabled": True, ... }, # FP16 settings
    "bf16": { "enabled": False },     # BF16 settings
    "zero_optimization": {
        "stage": 0,                   # ZeRO stage
        ...                           # Stage-specific settings
    }
}
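
The train_batch_size value is tied to two other DeepSpeed config keys, train_micro_batch_size_per_gpu and gradient_accumulation_steps, by the relation train_batch_size = micro-batch size per GPU x gradient accumulation steps x number of GPUs. The helper below is a hypothetical sketch of that arithmetic (the function name is not part of DeepSpeed); it derives the accumulation steps from the other values and rejects combinations that do not divide evenly.

```python
def accumulation_steps_for(train_batch_size, micro_batch_per_gpu, world_size):
    """Derive gradient accumulation steps from the global batch-size relation."""
    samples_per_micro_step = micro_batch_per_gpu * world_size
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            "train_batch_size must be a multiple of "
            "micro_batch_per_gpu * world_size"
        )
    return train_batch_size // samples_per_micro_step


# A global batch of 16 on 4 GPUs with 2 samples per GPU per micro-step
# needs 2 accumulation steps before each optimizer update.
steps = accumulation_steps_for(train_batch_size=16, micro_batch_per_gpu=2, world_size=4)
```

Checking this relation before launching a job is a cheap way to catch inconsistent batch settings early.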

Initialization Sequence

The full initialization sequence in the CIFAR-10 example:

1. deepspeed.init_distributed()     -- Initialize NCCL/Gloo backend
2. get_accelerator().set_device()   -- Pin current process to its GPU
3. Net(args)                        -- Create the raw PyTorch model
4. filter(requires_grad, params)    -- Get trainable parameters
5. get_ds_config(args)              -- Build DeepSpeed config dict
6. deepspeed.initialize(...)        -- Create the engine
7. model_engine.local_rank          -- Query engine for device info
8. model_engine.bfloat16_enabled()  -- Query engine for dtype info
