Principle:Microsoft DeepSpeedExamples Large Model Loading
Metadata
| Field | Value |
|---|---|
| Page Type | Principle |
| Title | Large_Model_Loading |
| Repository | Microsoft/DeepSpeedExamples |
| Sources | Doc: HuggingFace Model Loading https://huggingface.co/docs/transformers/main_classes/model ; Paper: Flash Attention https://arxiv.org/abs/2205.14135 |
| Domains | NLP, Model_Architecture, Memory_Optimization |
| Status | Active |
| Related Implementation | Implementation:Microsoft_DeepSpeedExamples_Load_Model_SuperOffload |
Overview
A technique for loading large language models with optimized attention backends and gradient checkpointing for memory-efficient fine-tuning.
Description
Loading models for fine-tuning at the 8B-70B+ parameter scale requires careful configuration of multiple components that interact to determine peak memory usage and training throughput. The key configuration decisions are:
- Attention implementation selection -- Choose between `eager` (standard PyTorch attention), `sdpa` (Scaled Dot-Product Attention, available in PyTorch 2.0+), or `flash_attention_2` (Flash Attention 2). The choice affects both memory consumption and computation speed.
- Gradient checkpointing -- Trade additional computation for reduced activation memory by recomputing intermediate activations during the backward pass instead of storing them.
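- KV cache disabling -- The key-value cache speeds up autoregressive generation at inference time, but training processes each full sequence in a single forward pass, so the cache only wastes memory. It must be disabled (Hugging Face Transformers also treats it as incompatible with gradient checkpointing).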
- BF16 precision -- Using `torch.bfloat16` halves parameter memory compared to FP32 while keeping the same dynamic range as FP32 (unlike FP16, which has a much narrower range).
- Tokenizer configuration -- Ensure that `pad_token` is set (defaulting to `eos_token` if it is not defined) so that batches can be padded correctly.
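As a sketch of the attention-selection decision, the hypothetical helper `pick_attn_implementation` below falls back from `flash_attention_2` to `sdpa` to `eager`. It assumes that backend availability can be approximated by whether the `flash_attn` and `torch` packages are importable; the fallback order is an illustrative assumption, not taken from the repository.

```python
import importlib.util


def pick_attn_implementation(preferred: str = "flash_attention_2") -> str:
    """Return an `attn_implementation` string, falling back when a backend is unavailable.

    Hypothetical helper: it only checks whether packages are importable. It does
    not verify GPU support, dtype, or whether the model architecture supports
    the chosen backend.
    """
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    has_torch = importlib.util.find_spec("torch") is not None  # SDPA needs PyTorch 2.0+ (version not checked)

    if preferred == "flash_attention_2" and has_flash_attn:
        return "flash_attention_2"
    if preferred in ("flash_attention_2", "sdpa") and has_torch:
        return "sdpa"
    return "eager"  # standard PyTorch attention, needs no extra package


print(pick_attn_implementation())
```

A package check is cheap but incomplete; a real loader may still fail later if, for example, Flash Attention 2 is installed but the model is loaded in FP32.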
Theoretical Basis
Flash Attention 2
For sequence length N and head dimension d, standard attention computes softmax(QK^T / sqrt(d)) * V and materializes the full N x N attention matrix in GPU high-bandwidth memory (HBM), resulting in:
- Memory complexity: O(N^2) for the attention matrix
- I/O complexity: O(N * d + N^2) HBM accesses, dominated by writing and reading the N x N attention matrix
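The quadratic term is easy to underestimate, so here is a small worked calculation. The sequence length, head count, and batch size are assumed example values, and the function counts only one materialized copy of the scores for one layer (eager implementations may hold more than one such buffer).

```python
def attn_score_matrix_bytes(seq_len: int, num_heads: int, batch_size: int = 1,
                            bytes_per_elem: int = 2) -> int:
    """Bytes for one materialized copy of the N x N attention scores in one layer."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


# Assumed example: N = 8192 tokens, 32 heads, BF16 (2-byte) scores, batch of 1
size = attn_score_matrix_bytes(8192, 32)
print(f"{size / 2**30:.1f} GiB per layer")  # 4.0 GiB per layer
```

Doubling the sequence length to 16384 quadruples this figure to 16 GiB, which is why avoiding the N x N matrix matters more than any constant-factor saving.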
Flash Attention 2 uses a tiling strategy that processes the attention computation in blocks, never materializing the full attention matrix:
- Memory complexity: O(N) -- only stores the current block and running statistics
- I/O complexity: O(N^2 * d^2 / M) where M is SRAM size -- significantly reduced HBM reads/writes
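- Speed improvement: typically reported as 2-4x faster than standard attention, mainly due to reduced HBM traffic
- Exactness: computes exact (not approximate) attention; online softmax rescaling makes the result mathematically identical, although floating-point outputs can differ slightly from standard attention because the summation order changes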
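The online softmax rescaling at the heart of this tiling can be shown in plain Python. This is an illustrative sketch of the math only, not the Flash Attention 2 CUDA kernel: it walks over K/V blocks for each query row, keeps a running max, a running normalizer, and an unnormalized output accumulator, and never builds a full row of scores. The function names and block size are assumptions for illustration.

```python
import math
from typing import List

Matrix = List[List[float]]


def naive_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Reference: builds every full row of the N x N score matrix."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]  # numerically stable softmax numerators
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out


def tiled_attention(Q: Matrix, K: Matrix, V: Matrix, block: int = 2) -> Matrix:
    """Online softmax over K/V blocks: only per-row running statistics are kept."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        m = -math.inf             # running max of scores seen so far
        l = 0.0                   # running sum of exp(score - m)
        acc = [0.0] * len(V[0])   # running sum of exp(score - m) * v
        for start in range(0, len(K), block):
            kb, vb = K[start:start + block], V[start:start + block]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in kb]
            m_new = max(m, max(s))
            correction = math.exp(m - m_new)  # rescales earlier partial sums (0.0 on the first block)
            p = [math.exp(x - m_new) for x in s]
            l = l * correction + sum(p)
            acc = [a * correction + sum(pj * v[c] for pj, v in zip(p, vb))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # normalize once at the end
    return out


Q = [[0.1, 0.2], [0.3, -0.1]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.5], [0.1, -0.2], [0.0, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [0.3, 0.7]]
ref, tiled = naive_attention(Q, K, V), tiled_attention(Q, K, V, block=2)
# Any difference comes only from floating-point rounding order
print(max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t)))
```

The real kernel additionally tiles Q and keeps each block in on-chip SRAM, which is where the HBM savings come from; the rescaling logic is the same.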
Gradient Checkpointing
Without gradient checkpointing, all intermediate activations from the forward pass must be stored for the backward pass:
- Standard memory: O(L) where L is the number of layers -- all layer activations stored
- With checkpointing: O(sqrt(L)) -- store activations only at roughly sqrt(L) checkpoint boundaries (segments of about sqrt(L) layers) and recompute the activations inside each segment during the backward pass
The trade-off is roughly one additional forward pass per training step (each checkpointed segment is recomputed once), which is about 33% more total compute when the backward pass costs about twice the forward pass. In exchange, activation memory is significantly reduced. For large models, this is essential because activation memory can dominate total memory usage.
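The memory/compute trade-off can be made concrete with a toy model: a chain of scalar `tanh(w * x)` "layers" with hand-written backpropagation. This is a sketch under simplifying assumptions (scalar activations, loss equal to the final output, segments of `isqrt(L)` layers); frameworks such as Hugging Face Transformers typically checkpoint at decoder-layer boundaries rather than use this exact schedule. It checks that checkpointing gives the same gradients while storing fewer activations and running one extra forward pass.

```python
import math
from typing import List, Tuple


def layer(w: float, x: float) -> float:
    """Toy scalar 'layer': y = tanh(w * x)."""
    return math.tanh(w * x)


def backward_through(ws: List[float], xs: List[float], g: float,
                     gw: List[float], start: int) -> float:
    """Backprop through consecutive layers starting at index `start`.

    xs[j] is the input of layer start + j and xs[j + 1] is its output.
    Writes dL/dw into gw and returns dL/d(input of layer `start`).
    """
    for j in range(len(xs) - 2, -1, -1):
        i = start + j
        y = xs[j + 1]
        local = (1.0 - y * y) * g   # tanh'(u) * upstream gradient, with u = w * x
        gw[i] = local * xs[j]       # dL/dw_i
        g = local * ws[i]           # dL/dx_i, passed to the previous layer
    return g


def grads_full(ws: List[float], x0: float) -> Tuple[List[float], int, int]:
    """Store every activation: O(L) memory, L layer evaluations."""
    xs = [x0]
    for w in ws:
        xs.append(layer(w, xs[-1]))
    gw = [0.0] * len(ws)
    backward_through(ws, xs, 1.0, gw, 0)  # loss = final output, so dL/dy = 1
    return gw, len(xs), len(ws)


def grads_checkpointed(ws: List[float], x0: float) -> Tuple[List[float], int, int]:
    """Keep only segment inputs and recompute each segment during backward."""
    num_layers = len(ws)
    seg = max(1, math.isqrt(num_layers))  # ~sqrt(L) layers per segment
    checkpoints = {}                      # segment start index -> stored input activation
    x, calls = x0, 0
    for i, w in enumerate(ws):
        if i % seg == 0:
            checkpoints[i] = x
        x = layer(w, x)
        calls += 1

    gw = [0.0] * num_layers
    g, peak = 1.0, 0
    for start in sorted(checkpoints, reverse=True):
        xs = [checkpoints[start]]
        for i in range(start, min(start + seg, num_layers)):  # recompute this segment
            xs.append(layer(ws[i], xs[-1]))
            calls += 1
        peak = max(peak, len(checkpoints) + len(xs) - 1)  # xs[0] is already a checkpoint
        g = backward_through(ws, xs, g, gw, start)
    return gw, peak, calls


ws = [0.9 + 0.01 * i for i in range(16)]
full, mem_full, calls_full = grads_full(ws, 0.5)
ckpt, mem_ckpt, calls_ckpt = grads_checkpointed(ws, 0.5)
print(mem_full, mem_ckpt, calls_full, calls_ckpt)  # 17 8 16 32
```

With 16 layers, peak stored activations drop from 17 to 8 while layer evaluations double from 16 to 32 (one full extra forward pass); the backward pass, not counted here, is unchanged.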
BF16 Precision
| Precision | Bytes/Parameter | Max Finite Value (approx.) | Mantissa Bits |
|---|---|---|---|
| FP32 | 4 | ~3.4e38 | 23 |
| FP16 | 2 | ~6.5e4 | 10 |
| BF16 | 2 | ~3.4e38 | 7 |
BF16 provides the same dynamic range as FP32 (avoiding the overflow/underflow issues of FP16) while using half the memory. Its reduced mantissa precision (7 bits vs. 23, roughly 2-3 significant decimal digits) is generally acceptable for training because stochastic gradient updates are already noisy.
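The table's trade-offs can be checked with the standard library alone. This sketch relies on the fact that BF16 is the top 16 bits of an IEEE-754 float32: `to_bf16` rounds a value to BF16 (round-half-to-even, ignoring NaN edge cases), `fits_fp16` uses `struct`'s half-precision `"e"` format to test FP16 range, and `param_memory_gb` estimates weight memory. The helper names and the 70B parameter count are illustrative assumptions.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x (via float32) to the nearest bfloat16 value, round-half-to-even.

    bfloat16 keeps the top 16 bits of a float32: the sign, the same 8-bit
    exponent, and 7 stored mantissa bits. NaN handling is omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # raises OverflowError beyond float32 range
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round on the 16 bits about to be dropped
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def fits_fp16(x: float) -> bool:
    """True if x can be packed as a finite FP16 value (max 65504)."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}


def param_memory_gb(num_params: float, dtype: str) -> float:
    """Weights only; gradients, optimizer states, and activations are extra."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


print(fits_fp16(1e5), math.isfinite(to_bf16(1e5)))                  # False True
print(to_bf16(1 + 2**-9))                                           # 1.0
print(param_memory_gb(70e9, "fp32"), param_memory_gb(70e9, "bf16"))  # 280.0 140.0
```

The first line shows the range advantage over FP16, the second shows the precision cost (near 1.0, BF16 cannot resolve steps smaller than 2^-7 to 2^-8), and the third shows the 50% weight-memory saving.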
Model Loading Flow
The model loading process consists of three sequential steps:
- Load tokenizer -- Initialize `AutoTokenizer` from the model name and set `pad_token = eos_token` if needed.
- Load model -- Initialize `AutoModelForCausalLM` with `torch_dtype=torch.bfloat16` and the chosen attention implementation.
- Configure for training -- Enable gradient checkpointing with `use_reentrant=False` and disable the KV cache.
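The three steps above can be sketched with the Hugging Face Transformers API. The function name `load_model_for_training` is hypothetical, and the sketch assumes a recent `transformers` release that accepts `attn_implementation` in `from_pretrained` and `gradient_checkpointing_kwargs` in `gradient_checkpointing_enable` (newer releases may prefer `dtype` over `torch_dtype`). It is not the repository's exact code.

```python
def load_model_for_training(model_name: str, attn_implementation: str = "flash_attention_2"):
    """Hypothetical sketch of the tokenizer -> model -> training-config flow."""
    # Imported inside the function so this sketch can be imported without
    # torch/transformers installed; normally these would be module-level imports.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Step 1: tokenizer, with pad_token falling back to eos_token
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Step 2: BF16 weights with the chosen attention backend
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
    )

    # Step 3: configure for training
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model.config.use_cache = False  # KV cache only helps generation; disable it for training
    return model, tokenizer
```

`use_reentrant=False` selects PyTorch's non-reentrant checkpoint implementation, which is the variant PyTorch recommends for new code.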
MoE Model Detection
For Mixture-of-Experts (MoE) models, additional configuration is needed. The loading process detects MoE models by checking for config attributes:
- `num_local_experts`
- `moe_layers`
- `num_experts`
- `expert_capacity`
- `router_aux_loss_coef`
When an MoE model is detected, a leaf module can be set via `set_z3_leaf_modules` to enable proper ZeRO-3 partitioning of expert layers.
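A minimal sketch of the detection rule, assuming that the presence of any listed attribute on the config marks a model as MoE (a presence-only check may give false positives if a dense config happens to define one of these attributes). `is_moe_config` and `maybe_set_moe_leaf_modules` are hypothetical names; `deepspeed.utils.set_z3_leaf_modules(model, leaf_module_classes)` is the DeepSpeed utility named above, and the leaf class is model-specific (for example, DeepSpeed's documentation uses Mixtral's `MixtralSparseMoeBlock`).

```python
MOE_CONFIG_ATTRS = (
    "num_local_experts",
    "moe_layers",
    "num_experts",
    "expert_capacity",
    "router_aux_loss_coef",
)


def is_moe_config(config) -> bool:
    """Treat the model as MoE if its config defines any of the attributes above."""
    return any(hasattr(config, attr) for attr in MOE_CONFIG_ATTRS)


def maybe_set_moe_leaf_modules(model, leaf_module_classes) -> bool:
    """Register expert-block classes as ZeRO-3 leaf modules for MoE models only."""
    if not is_moe_config(model.config):
        return False
    # Imported lazily so dense-model code paths do not require DeepSpeed
    from deepspeed.utils import set_z3_leaf_modules

    set_z3_leaf_modules(model, leaf_module_classes)
    return True
```

Marking the whole expert block as a leaf lets ZeRO-3 gather its parameters as one unit, which matters because the router may activate different experts on each step.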
Usage Pattern
- Select an attention implementation (default: `flash_attention_2`).
- Load the tokenizer and model using Hugging Face `from_pretrained`.
- Enable gradient checkpointing and disable the KV cache.
- Pass the model to DeepSpeed initialization.
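The final step can be sketched as follows. `build_ds_config` and `init_deepspeed` are hypothetical helpers; the config keys shown (`train_micro_batch_size_per_gpu`, `bf16`, `zero_optimization`, `optimizer`) are standard DeepSpeed options, but the values are illustrative assumptions, and any SuperOffload-specific settings used by the related implementation are deliberately not shown.

```python
def build_ds_config(micro_batch_size: int = 1, zero_stage: int = 3) -> dict:
    """Minimal DeepSpeed config; values are illustrative, not the repository's."""
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "bf16": {"enabled": True},                   # matches the BF16 model weights
        "zero_optimization": {"stage": zero_stage},  # stage 3 partitions parameters too
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
    }


def init_deepspeed(model, ds_config: dict):
    """Wrap a loaded model in a DeepSpeed engine."""
    import deepspeed  # imported lazily so the config helper works without DeepSpeed

    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config=ds_config,
    )
    return engine, optimizer
```

`deepspeed.initialize` returns a 4-tuple (engine, optimizer, dataloader, scheduler); the sketch discards the last two because no training data or scheduler is passed.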