

Principle:Microsoft DeepSpeedExamples Large Model Loading

From Leeroopedia


Metadata

Field                    Value
Page Type                Principle
Title                    Large_Model_Loading
Repository               Microsoft/DeepSpeedExamples
Sources                  Doc: HuggingFace Model Loading https://huggingface.co/docs/transformers/main_classes/model ; Paper: Flash Attention https://arxiv.org/abs/2205.14135
Domains                  NLP, Model_Architecture, Memory_Optimization
Status                   Active
Related Implementation   Implementation:Microsoft_DeepSpeedExamples_Load_Model_SuperOffload

Overview

A technique for loading large language models with optimized attention backends and gradient checkpointing for memory-efficient fine-tuning.

Description

Loading models for fine-tuning at the 8B-70B+ parameter scale requires careful configuration of multiple components that interact to determine peak memory usage and training throughput. The key configuration decisions are:

  • Attention implementation selection -- Choose between eager (standard PyTorch attention), sdpa (Scaled Dot-Product Attention via PyTorch 2.0+), or flash_attention_2 (Flash Attention 2). The choice affects both memory consumption and computation speed.
  • Gradient checkpointing -- Trade additional computation for reduced activation memory by recomputing intermediate activations during the backward pass instead of storing them.
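  • KV cache disabling -- The key-value cache speeds up autoregressive generation at inference time but is unused during training, where it only wastes memory. Disable it (use_cache = False); in HuggingFace Transformers it is also incompatible with gradient checkpointing.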
  • BF16 precision -- Using torch.bfloat16 reduces parameter memory by 50% compared to FP32, while maintaining the same dynamic range as FP32 (unlike FP16 which has a narrower range).
  • Tokenizer configuration -- Ensuring the pad_token is set (defaulting to eos_token if not defined) for proper padding during batch construction.
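The attention choice above can be made defensively. The following is a minimal sketch, not the repository's code: the helper name pick_attn_implementation and the fallback from flash_attention_2 to sdpa when the flash_attn package is not importable are assumptions made for illustration.

```python
import importlib.util

# Attention backends named in the description above.
SUPPORTED_ATTN = ("eager", "sdpa", "flash_attention_2")


def pick_attn_implementation(requested="flash_attention_2"):
    """Return a usable attention backend name (hypothetical helper).

    Assumption: if flash_attention_2 is requested but the flash_attn
    package cannot be found, fall back to PyTorch's sdpa backend.
    """
    if requested not in SUPPORTED_ATTN:
        raise ValueError(f"unknown attention implementation: {requested!r}")
    if requested == "flash_attention_2":
        if importlib.util.find_spec("flash_attn") is None:
            return "sdpa"
    return requested


if __name__ == "__main__":
    print(pick_attn_implementation())
```

The returned string is intended to be passed as the attn_implementation argument when loading the model.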

Theoretical Basis

Flash Attention 2

Standard attention computes softmax(QK^T / sqrt(d)) * V and materializes the full N x N attention matrix, resulting in:

  • Memory complexity: O(N^2) for the attention matrix
  • I/O complexity: O(N * d + N^2) HBM accesses, dominated by reading and writing the N x N attention matrix

Flash Attention 2 uses a tiling strategy that processes the attention computation in blocks, never materializing the full attention matrix:

  • Memory complexity: O(N) -- only stores the current block and running statistics
  • I/O complexity: O(N^2 * d^2 / M) where M is SRAM size -- significantly reduced HBM reads/writes
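  • Speed improvement: commonly reported as 2-4x faster than standard attention, depending on sequence length and hardware, due to reduced HBM traffic
  • Exactness: Computes exact attention (not an approximation) using online softmax rescaling; outputs match standard attention up to floating-point rounding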
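The online softmax rescaling behind the tiling strategy can be illustrated in plain Python for a single query vector. This is a didactic sketch only: the function names and block size are made up here, and the real kernels run fused on the GPU rather than in Python loops.

```python
import math


def _score(q, k):
    """Scaled dot product q.k / sqrt(d)."""
    return sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q))


def naive_attention(q, keys, values):
    """Reference: materialize all N scores, then softmax, then weight V."""
    scores = [_score(q, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]


def tiled_attention(q, keys, values, block_size=2):
    """Process keys/values block by block, keeping only running statistics."""
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        scores = [_score(q, k) for k in keys[start:start + block_size]]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, values[start:start + block_size]):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 3.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
reference = naive_attention(q, keys, values)
tiled = tiled_attention(q, keys, values)
```

Both functions compute the same attention output; the tiled version never holds more than one block of scores at a time, which is the property that removes the N x N matrix from memory.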

Gradient Checkpointing

Without gradient checkpointing, all intermediate activations from the forward pass must be stored for the backward pass:

  • Standard memory: O(L) where L is the number of layers -- all layer activations stored
  • With checkpointing: O(sqrt(L)) when a checkpoint is placed every sqrt(L) layers -- only the checkpointed activations are stored; the activations in between are recomputed during the backward pass

The trade-off is roughly one extra forward pass per training step. Assuming the backward pass costs about twice the forward pass, that is approximately 33% more total compute, in exchange for significantly reduced activation memory. For large models, this is essential because activation memory can dominate total memory usage.
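The memory and compute arithmetic can be checked with a small model. This sketch assumes every layer's activations cost one unit, that checkpoints sit at evenly spaced segment boundaries, and that a backward pass costs twice a forward pass; the function names are illustrative.

```python
import math


def stored_activations(num_layers, segment_size):
    """Units held at peak: one checkpoint per segment, plus the
    activations of the single segment currently being recomputed."""
    num_checkpoints = math.ceil(num_layers / segment_size)
    return num_checkpoints + segment_size


def best_segment_size(num_layers):
    """Segment size that minimizes stored activations (brute force)."""
    return min(range(1, num_layers + 1),
               key=lambda k: stored_activations(num_layers, k))


def compute_overhead(backward_to_forward_ratio=2.0):
    """Extra cost of one additional forward pass, relative to a
    normal step of forward + backward."""
    return 1.0 / (1.0 + backward_to_forward_ratio)


layers = 64
best = best_segment_size(layers)          # close to sqrt(64)
peak = stored_activations(layers, best)   # versus 64 without checkpointing
overhead = compute_overhead()             # one third under the stated ratio
```

Under these assumptions the minimum lands at a segment size near sqrt(L), which is where the O(sqrt(L)) figure comes from.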

BF16 Precision

Precision   Bytes/Parameter   Dynamic Range (max value)   Mantissa Bits
FP32        4                 ~3.4e38                     23
FP16        2                 ~6.5e4                      10
BF16        2                 ~3.4e38                     7

BF16 provides the same dynamic range as FP32 (avoiding the overflow/underflow issues of FP16) while using half the memory. The reduced mantissa precision (7 bits vs. 23) is acceptable for training because gradient updates are inherently noisy.
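The range and precision trade-off can be reproduced with the standard library alone. This is a rough sketch: to_bf16 emulates BF16 by truncating the low 16 bits of the FP32 encoding (real hardware typically rounds to nearest), and all helper names are illustrative.

```python
import struct


def to_bf16(x):
    """Emulate BF16 by keeping only the top 16 bits of the FP32 encoding
    (sign, 8 exponent bits, 7 mantissa bits). Truncation, not rounding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def fits_fp16(x):
    """True if x can be packed as IEEE half precision without overflow."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


def param_gigabytes(num_params, bytes_per_param):
    """Parameter memory in GB (1 GB = 1e9 bytes), weights only."""
    return num_params * bytes_per_param / 1e9


large = 1e38
print(fits_fp16(large))            # FP16 cannot represent this magnitude
print(to_bf16(large))              # BF16 keeps it finite, at lower precision
print(to_bf16(1.0 + 2 ** -8))      # below BF16 resolution near 1.0
print(param_gigabytes(8e9, 4), param_gigabytes(8e9, 2))
```

The memory figures cover parameters only; gradients, optimizer state, and activations come on top.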

Model Loading Flow

The model loading process consists of three sequential steps:

  1. Load tokenizer -- Initialize AutoTokenizer from the model name, set pad_token = eos_token if needed.
  2. Load model -- Initialize AutoModelForCausalLM with torch_dtype=torch.bfloat16 and the chosen attention implementation.
  3. Configure for training -- Enable gradient checkpointing with use_reentrant=False, disable KV cache.
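The three steps above can be sketched with the HuggingFace Transformers API as follows. This is an illustrative outline rather than the repository's implementation: the function names are hypothetical, and torch and transformers are imported lazily so that the validation helper can be used without them installed.

```python
SUPPORTED_ATTN = ("eager", "sdpa", "flash_attention_2")


def check_attn_implementation(name):
    """Validate the attention backend name (hypothetical helper)."""
    if name not in SUPPORTED_ATTN:
        raise ValueError(f"unknown attention implementation: {name!r}")
    return name


def load_model_and_tokenizer(model_name, attn_implementation="flash_attention_2"):
    """Load tokenizer and model, then configure the model for training."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Step 1: tokenizer, with pad_token defaulting to eos_token.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Step 2: model in BF16 with the chosen attention backend.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=check_attn_implementation(attn_implementation),
    )

    # Step 3: training configuration.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    model.config.use_cache = False
    return model, tokenizer
```

Calling load_model_and_tokenizer downloads or reads the checkpoint named by model_name, so it needs network or local access to the weights and, for flash_attention_2, a supported GPU.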

MoE Model Detection

For Mixture-of-Experts (MoE) models, additional configuration is needed. The loading process detects MoE models by checking for config attributes:

  • num_local_experts
  • moe_layers
  • num_experts
  • expert_capacity
  • router_aux_loss_coef

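When an MoE model is detected, its expert block class can be registered as a ZeRO-3 leaf module via DeepSpeed's set_z3_leaf_modules. ZeRO-3 then gathers the parameters of the whole block at once rather than hooking each expert separately, which enables proper partitioning of expert layers when only a subset of experts runs on a given forward pass.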
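A detection check along these lines might look as follows. It is a sketch under assumptions: the source only says the attributes are checked, so the use of hasattr, the helper names, and the moe_block_classes parameter (the model-specific expert block classes) are illustrative. deepspeed is imported lazily and only when a MoE model is found.

```python
from types import SimpleNamespace

# Config attributes listed above as MoE indicators.
MOE_CONFIG_ATTRS = (
    "num_local_experts",
    "moe_layers",
    "num_experts",
    "expert_capacity",
    "router_aux_loss_coef",
)


def is_moe_config(config):
    """True if the config exposes any of the MoE indicator attributes."""
    return any(hasattr(config, name) for name in MOE_CONFIG_ATTRS)


def mark_moe_leaf_modules(model, moe_block_classes):
    """Register MoE block classes as ZeRO-3 leaf modules when applicable.

    moe_block_classes is a caller-supplied list of module classes; which
    class to pass depends on the model architecture.
    """
    if not is_moe_config(model.config):
        return False
    from deepspeed.utils import set_z3_leaf_modules

    set_z3_leaf_modules(model, list(moe_block_classes))
    return True


# Stand-in configs for demonstration only.
dense_cfg = SimpleNamespace(hidden_size=4096, num_hidden_layers=32)
moe_cfg = SimpleNamespace(hidden_size=4096, num_local_experts=8)
print(is_moe_config(dense_cfg), is_moe_config(moe_cfg))
```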

Usage Pattern

  1. Select attention implementation (default: flash_attention_2).
  2. Load tokenizer and model using HuggingFace from_pretrained.
  3. Enable gradient checkpointing and disable KV cache.
  4. Pass the model to DeepSpeed initialization.
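The final step, handing the model to DeepSpeed, could be sketched as below. The configuration values (micro-batch size, ZeRO stage, optimizer type and learning rate) are placeholders chosen for illustration, not settings taken from the repository, and the function names are hypothetical.

```python
def build_ds_config(micro_batch_size=1, zero_stage=3, learning_rate=1e-5):
    """Build a minimal DeepSpeed config dict (illustrative values only)."""
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": zero_stage},
        "optimizer": {"type": "AdamW", "params": {"lr": learning_rate}},
    }


def init_engine(model, ds_config):
    """Wrap an already loaded model in a DeepSpeed engine."""
    import deepspeed  # lazy import: only needed when actually training

    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )
    return engine, optimizer


ds_config = build_ds_config()
```

Enabling bf16 in the DeepSpeed config keeps the engine consistent with a model loaded in torch.bfloat16.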
