Principle:Deepspeedai DeepSpeed ZeRO Parameter Partitioning
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Memory_Optimization, Model_Parallelism |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A memory optimization technique that partitions model parameters across data-parallel ranks during model construction, enabling training of models that exceed single-GPU memory.
Description
ZeRO Parameter Partitioning (ZeRO Stage 3) partitions model parameters across all data-parallel GPUs, on top of the optimizer-state and gradient partitioning of Stages 1 and 2. Instead of each GPU holding a complete copy of the model, each GPU holds only 1/N of the parameters (where N is the data-parallel world size). Parameters are gathered on demand for forward and backward computation, and the gathered copies are released immediately afterward so that only the local shard remains.
The deepspeed.zero.Init context manager extends this partitioning to model construction. It intercepts torch.nn.Module.__init__ calls and partitions each module's parameters as soon as that module is created. The full model therefore never materializes on any single GPU, so models whose parameters are up to roughly N times larger than what a single GPU can hold can still be constructed.
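To make the interception idea concrete, the following is a pure-Python toy, not DeepSpeed's implementation. `ToyModule`, `toy_zero_init`, `shard`, and `Linear` are hypothetical names, parameters are plain lists of floats, and the rank is passed in explicitly instead of coming from a process group. The context manager wraps each existing subclass's `__init__` so that parameters are cut down to a 1/N shard as soon as the constructor returns, and it restores the original constructors on exit.

```python
import math
from contextlib import contextmanager


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: parameters are flat lists of floats."""

    def __init__(self):
        self.params = {}      # name -> values (full before sharding, local shard after)
        self.full_numel = {}  # name -> original length, recorded when sharded


def shard(values, rank, world_size):
    """Return rank's contiguous 1/world_size slice, zero-padded so all shards match."""
    chunk = math.ceil(len(values) / world_size)
    padded = values + [0.0] * (chunk * world_size - len(values))
    return padded[rank * chunk:(rank + 1) * chunk]


def _all_subclasses(cls):
    for sub in cls.__subclasses__():
        yield sub
        yield from _all_subclasses(sub)


@contextmanager
def toy_zero_init(rank, world_size):
    """Patch existing ToyModule subclasses so parameters are sharded right after __init__."""
    originals = {}
    for cls in _all_subclasses(ToyModule):
        orig = cls.__dict__.get("__init__")
        if orig is None or cls in originals:  # inherited __init__ is covered by the parent's wrapper
            continue
        originals[cls] = orig

        def wrapped(self, *args, _orig=orig, **kwargs):
            _orig(self, *args, **kwargs)
            for name, values in list(self.params.items()):
                if name not in self.full_numel:  # shard only parameters not yet sharded
                    self.full_numel[name] = len(values)
                    self.params[name] = shard(values, rank, world_size)

        cls.__init__ = wrapped
    try:
        yield
    finally:
        for cls, orig in originals.items():  # undo the patch outside the context
            cls.__init__ = orig


class Linear(ToyModule):
    def __init__(self, n):
        super().__init__()
        self.params["weight"] = [float(i) for i in range(n)]


with toy_zero_init(rank=1, world_size=4):
    layer = Linear(10)  # after construction, only rank 1's shard is kept
```

As a simplification, the toy only patches classes defined before the context is entered; the real context manager also has to cope with classes defined later and with parameters living on accelerator devices.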
Key characteristics:
- Construction-time partitioning: Parameters are sharded immediately upon creation, before the model is fully constructed
- On-demand gathering: Full parameters are assembled via AllGather before each forward/backward operation
- Immediate release: Gathered parameters are discarded after use to reclaim memory
- Offloading support: Parameter shards can be offloaded to CPU or NVMe for further memory savings
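Parameter offloading is selected in the DeepSpeed configuration. The helper below is a hypothetical sketch that builds the `zero_optimization` section as a Python dict: `stage`, `offload_param`, `device`, `nvme_path`, and `pin_memory` are DeepSpeed config keys, while the function name and any NVMe path passed to it are placeholders.

```python
def zero3_offload_section(device="cpu", nvme_path=None, pin_memory=True):
    """Build a ZeRO-3 "zero_optimization" block that offloads parameter shards.

    device: "cpu" or "nvme", the two offload targets for parameter shards.
    nvme_path: required for NVMe offload; supply a path on a local NVMe drive.
    """
    if device not in ("cpu", "nvme"):
        raise ValueError(f"unsupported offload device: {device!r}")
    offload = {"device": device, "pin_memory": pin_memory}
    if device == "nvme":
        if not nvme_path:
            raise ValueError("NVMe offload needs an nvme_path")
        offload["nvme_path"] = nvme_path
    # Parameter offloading only applies to ZeRO Stage 3.
    return {"stage": 3, "offload_param": offload}
```

The returned dict would sit under the `"zero_optimization"` key of the full DeepSpeed config.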
Usage
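Wrap model construction inside the deepspeed.zero.Init() context manager when using ZeRO Stage 3. This is required when the model is too large to be fully materialized on the device where it is constructed, for example a model that exceeds a single GPU's memory. For models that do fit, zero.Init() is optional: deepspeed.initialize() partitions the parameters after construction, and ZeRO Stage 3 still reduces the per-GPU memory footprint.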
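A minimal usage sketch follows, assuming DeepSpeed and PyTorch are installed and the script is started on every rank by a distributed launcher such as the `deepspeed` CLI. `deepspeed.zero.Init(config_dict_or_path=...)` and `deepspeed.initialize(model=..., model_parameters=..., config=...)` are DeepSpeed APIs; `ZERO3_CONFIG`, its batch size and optimizer settings, `build_zero3_engine`, and `model_factory` are hypothetical placeholders.

```python
import copy

# Hypothetical minimal ZeRO-3 config; batch size and optimizer are placeholders.
ZERO3_CONFIG = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},
}


def build_zero3_engine(model_factory, ds_config=None):
    """Construct the model inside zero.Init so parameters are sharded as they are created.

    model_factory: hypothetical zero-argument callable returning a torch.nn.Module.
    """
    import deepspeed  # imported lazily: needs DeepSpeed, PyTorch, and a distributed launch

    ds_config = copy.deepcopy(ds_config or ZERO3_CONFIG)
    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model = model_factory()  # each module's parameters are partitioned after its __init__
    # deepspeed.initialize returns (engine, optimizer, training_dataloader, lr_scheduler).
    engine, _, _, _ = deepspeed.initialize(
        model=model, model_parameters=model.parameters(), config=ds_config
    )
    return engine
```

Passing the same config to both calls keeps the construction-time partitioning consistent with the ZeRO-3 engine that later gathers and releases the shards.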
Theoretical Basis
ZeRO-3 parameter partitioning: For a model with P parameters across N GPUs, each GPU stores P/N parameters. Full parameters are gathered via AllGather before each forward/backward operation and discarded after.
Memory per GPU:
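- Parameters: O(P/N) persistent storage, plus temporary storage for the fully gathered parameters of the layer(s) currently in use (O(P) only in the worst case where every layer is gathered at once)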
- Gradients: O(P/N) after reduce-scatter
- Optimizer states: O(P/N) (e.g., 2 states for Adam: momentum and variance)
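The temporary parameter term depends on how many layers are gathered at once. A minimal sketch, assuming (hypothetically) that exactly one layer is gathered at a time with no prefetching, counts the parameter elements one GPU holds; `param_memory_per_gpu` is an illustrative name, and DeepSpeed itself bounds the live working set through settings such as `stage3_max_live_parameters` and its prefetching behavior.

```python
import math


def param_memory_per_gpu(layer_sizes, n_gpus):
    """Per-GPU parameter element counts under ZeRO-3: (persistent, peak upper bound).

    Assumes only one layer is fully gathered at any moment (no prefetching).
    """
    persistent = sum(math.ceil(size / n_gpus) for size in layer_sizes)  # padded local shards
    peak = persistent + max(layer_sizes)  # plus the single fully gathered layer
    return persistent, peak


persistent, peak = param_memory_per_gpu([100, 200, 300], n_gpus=4)
```

For these three layers, a full replica would hold 600 elements, whereas the sketch gives 150 persistent elements and at most 450 at the peak; the gap widens as the number of layers and GPUs grows.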
Communication overhead:
- Forward pass: One AllGather per layer to reconstruct full parameters
- Backward pass: One AllGather per layer (to reconstruct the full parameters needed for gradient computation) + one ReduceScatter per layer (to partition the gradients)
- Total volume: about 3x the parameter size per training step, vs. 2x for ZeRO-2 and standard data parallelism (a 1.5x increase)
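The volume comparison can be tallied directly. This sketch follows the ZeRO paper's accounting, in which each AllGather or ReduceScatter moves about one model's worth of data per GPU (exactly (N-1)/N of it for a ring over N ranks); `comm_volume_per_step` and the dictionary keys are illustrative names.

```python
def comm_volume_per_step(num_params, stage):
    """Approximate per-GPU communication volume for one training step, in elements."""
    if stage == 3:
        parts = {
            "forward_all_gather": num_params,   # rebuild parameters for the forward pass
            "backward_all_gather": num_params,  # rebuild them again for the backward pass
            "reduce_scatter": num_params,       # sum gradients and keep only this rank's shard
        }
    elif stage in (0, 1, 2):
        # Stage 0: gradient all-reduce = reduce-scatter + all-gather.
        # Stages 1-2: reduce-scatter gradients, then all-gather updated parameters.
        parts = {"reduce_scatter": num_params, "all_gather": num_params}
    else:
        raise ValueError(f"unknown ZeRO stage: {stage}")
    return parts, sum(parts.values())


_, zero2_total = comm_volume_per_step(10**9, stage=2)
_, zero3_total = comm_volume_per_step(10**9, stage=3)
```

The extra AllGather in the backward pass is the price of never keeping full parameters resident between the forward and backward passes.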
Memory reduction factor: With N GPUs, ZeRO-3 reduces the per-GPU memory footprint for model states (parameters, gradients, and optimizer states) from O(P * (K + 2)) to O(P * (K + 2) / N), where K is the number of optimizer state terms per parameter (K=2 for Adam).
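In bytes, the ZeRO paper's mixed-precision Adam accounting uses 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for the fp32 optimizer state (master weights, momentum, and variance). The sketch below applies that accounting to a 7.5B-parameter model on 64 GPUs, the configuration used in the paper's illustration; the function name and default arguments are illustrative.

```python
def model_state_bytes_per_gpu(num_params, n_gpus, param_bytes=2, grad_bytes=2,
                              optim_bytes=12, zero3=True):
    """Bytes of model state per GPU, with or without ZeRO-3 partitioning."""
    total = num_params * (param_bytes + grad_bytes + optim_bytes)
    return total / n_gpus if zero3 else total  # ZeRO-3 divides every term by N


baseline = model_state_bytes_per_gpu(7_500_000_000, 64, zero3=False)
sharded = model_state_bytes_per_gpu(7_500_000_000, 64)
print(f"{baseline / 1e9:.1f} GB -> {sharded / 1e9:.3f} GB per GPU")
# prints: 120.0 GB -> 1.875 GB per GPU
```

These figures cover model states only; activations and the temporarily gathered parameters come on top.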
Paper: ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
Pseudo-code:
```python
# Abstract ZeRO-3 parameter partitioning pattern (pseudo-code; helpers are abstract)
with zero_init_context(world_size=N):
    model = LargeModel()  # each parameter is sharded to 1/N as it is created

# Forward pass: gather each layer just in time, then release it
x = inputs
for layer in model.layers:
    full_params = all_gather(layer.partitioned_params)  # reconstruct full weights
    x = layer.forward(x, full_params)
    release(full_params)  # free the gathered copy; only the local shard remains

# Backward pass: re-gather in reverse order, then partition the gradients
grad = grad_of_loss(x)
for layer in reversed(model.layers):
    full_params = all_gather(layer.partitioned_params)
    grad, param_grads = layer.backward(grad, full_params)
    release(full_params)
    layer.partitioned_grads = reduce_scatter(param_grads)  # keep only this rank's 1/N
```
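The pattern can be checked numerically with a single-process simulation. In this sketch, all names are hypothetical: each layer is an elementwise product `y = x * w` (all layers share one width), the loss is `sum(y)`, and N simulated ranks each hold one input vector. `all_gather` and `reduce_scatter` are list operations standing in for the real collectives. Each rank keeps only its weight shard between uses, and the reduce-scattered gradient shards should match the shards of the gradient that ordinary data parallelism would all-reduce.

```python
import math


def shard(vec, rank, n):
    """Contiguous 1/n slice of vec for this rank, zero-padded so all shards match."""
    chunk = math.ceil(len(vec) / n)
    padded = vec + [0.0] * (chunk * n - len(vec))
    return padded[rank * chunk:(rank + 1) * chunk]


def all_gather(shards, numel):
    """Concatenate every rank's shard and drop the padding."""
    return [v for s in shards for v in s][:numel]


def reduce_scatter(per_rank_grads, n):
    """Sum gradients across ranks, then hand each rank only its own shard."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [shard(summed, r, n) for r in range(n)]


def zero3_step(weights, inputs):
    """One forward/backward over elementwise layers; returns grad shards [layer][rank]."""
    n, numel = len(inputs), len(weights[0])
    shards = [[shard(w, r, n) for r in range(n)] for w in weights]  # persistent 1/n state
    acts = [[list(x)] for x in inputs]  # per rank: input to each layer, then final output

    for layer_shards in shards:  # forward pass
        full_w = all_gather(layer_shards, numel)  # gather just in time
        for rank_acts in acts:
            rank_acts.append([a * w for a, w in zip(rank_acts[-1], full_w)])
        del full_w  # release the gathered copy

    grad_out = [[1.0] * numel for _ in range(n)]  # d(sum(output)) / d(output)
    grad_shards = [None] * len(weights)
    for i in reversed(range(len(weights))):  # backward pass
        full_w = all_gather(shards[i], numel)  # re-gather for this layer
        grad_w = [[g * a for g, a in zip(grad_out[r], acts[r][i])] for r in range(n)]
        grad_out = [[g * w for g, w in zip(grad_out[r], full_w)] for r in range(n)]
        del full_w
        grad_shards[i] = reduce_scatter(grad_w, n)  # each rank keeps 1/n of the sum
    return grad_shards


grads = zero3_step(weights=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                   inputs=[[1.0, 1.0, 1.0], [2.0, 0.0, 1.0]])
```

Concatenating the first layer's two rank shards and trimming the padding gives [12.0, 5.0, 12.0], the same summed gradient a plain data-parallel all-reduce would produce, but no rank ever stores more than half of it.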