Implementation:Microsoft DeepSpeedExamples Offload States Demo
| Knowledge Sources | |
|---|---|
| Domains | Memory Optimization, DeepSpeed ZeRO |
| Last Updated | 2026-02-07 12:00 GMT |
Overview
Demonstrates the DeepSpeed offload states API for ZeRO Stage 3, with memory profiling to measure offload and reload performance under various configurations.
Description
This script provides an end-to-end demonstration of the DeepSpeed offload_states and reload_states API for ZeRO Stage 3 training. It defines a SimpleModel consisting of configurable linear layers and a cross-entropy loss, along with helper functions for generating random training data. The run_model function initializes a DeepSpeed engine, runs training iterations, and after each iteration offloads model states to CPU memory via model.offload_states() and reloads them via model.reload_states().
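The per-iteration offload and reload step described above can be sketched roughly as follows. This is an illustrative sketch, not the script's own code: `timed_offload_cycle` is a hypothetical helper, `engine` is assumed only to expose `offload_states()` and `reload_states()` as the description states, and `memory_allocated` and `synchronize` are caller-supplied callables standing in for whatever device API is in use.

```python
import time
from typing import Callable, Dict


def timed_offload_cycle(engine,
                        memory_allocated: Callable[[], int],
                        synchronize: Callable[[], None] = lambda: None,
                        **offload_kwargs) -> Dict[str, float]:
    """Offload then reload once; return timings and memory readings.

    Hypothetical helper: `engine` only needs offload_states() and
    reload_states(). The two callables are passed in so the sketch does
    not assume a particular device API.
    """
    alloc_before = memory_allocated()

    synchronize()  # finish pending device work before starting the timer
    start = time.perf_counter()
    engine.offload_states(**offload_kwargs)
    synchronize()  # wait for possibly asynchronous copies to complete
    offload_time = time.perf_counter() - start
    alloc_after_offload = memory_allocated()

    synchronize()
    start = time.perf_counter()
    engine.reload_states()
    synchronize()
    reload_time = time.perf_counter() - start
    alloc_after_reload = memory_allocated()

    # The checks the description mentions: memory drops, then rises again.
    assert alloc_after_offload < alloc_before, "offload should free device memory"
    assert alloc_after_reload > alloc_after_offload, "reload should use device memory again"

    return {
        "offload_time": offload_time,
        "reload_time": reload_time,
        "alloc_before": alloc_before,
        "alloc_after_offload": alloc_after_offload,
        "alloc_after_reload": alloc_after_reload,
    }
```

Synchronizing before reading the clock matters most when non-blocking transfers are enabled, since the call may otherwise return before the copy has finished.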
The script profiles both offload and reload times, verifying through assertions that GPU memory allocation decreases after offloading and increases after reloading. It supports configurable pin_memory and non_blocking transfer options, as well as selective state offloading via the include parameter (which accepts specific OffloadStateTypeEnum values). After excluding warmup iterations, the script logs average offload and reload times to a file.
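The warmup exclusion and logging step can be illustrated with a small sketch. The function names here are hypothetical, and the row layout (pin_memory, non_blocking, average offload time, average reload time, with no header row) is an assumption based on the output description on this page, not a copy of the script's logging code.

```python
import csv
from statistics import mean
from typing import Sequence, Tuple


def average_after_warmup(times: Sequence[float], warmup: int) -> float:
    """Average the timings, skipping the first `warmup` measurements."""
    measured = list(times[warmup:])
    if not measured:
        raise ValueError("warmup must be smaller than the number of iterations")
    return mean(measured)


def append_result(path: str,
                  pin_memory: bool,
                  non_blocking: bool,
                  offload_times: Sequence[float],
                  reload_times: Sequence[float],
                  warmup: int) -> Tuple[float, float]:
    """Append one CSV row with the averaged timings (assumed layout)."""
    offload_avg = average_after_warmup(offload_times, warmup)
    reload_avg = average_after_warmup(reload_times, warmup)
    with open(path, "a", newline="") as f:
        csv.writer(f).writerow([pin_memory, non_blocking, offload_avg, reload_avg])
    return offload_avg, reload_avg
```

Dropping the first few iterations keeps one-time costs, such as initial buffer allocation, from skewing the averages.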
The main function provides a command-line interface with arguments for the included state type, pin_memory, non_blocking, number of model layers, hidden dimension, data type (bfloat16/float16/float32), and iteration/warmup counts. The DeepSpeed configuration uses ZeRO Stage 3 with the Adam optimizer.
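A minimal sketch of such a command-line interface and configuration is shown below. The flag names and defaults follow the I/O Contract on this page and the configuration values follow the usage example. `build_parser` and `build_config` are hypothetical names, and the mapping from `--dtype` to the `bf16`/`fp16` config sections is an assumption rather than the script's confirmed logic.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="Offload states demo (sketch)")
    parser.add_argument("--included_state", type=str, default=None,
                        help="State type to offload; None offloads all states")
    parser.add_argument("--pin_memory", action="store_true")
    parser.add_argument("--non_blocking", action="store_true")
    parser.add_argument("--nlayers", type=int, default=1)
    parser.add_argument("--hidden_dim", type=int, default=1024)
    parser.add_argument("--dtype",
                        choices=["torch.bfloat16", "torch.float16", "torch.float32"],
                        default="torch.bfloat16")
    parser.add_argument("--iteration", type=int, default=10)
    parser.add_argument("--warmup", type=int, default=5)
    return parser


def build_config(dtype_name: str) -> dict:
    """Build a ZeRO Stage 3 config dict with Adam (values from the usage example)."""
    config = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-6}},
        "zero_optimization": {"stage": 3},
    }
    # Assumed mapping: enable the matching mixed-precision section, if any.
    if dtype_name == "torch.bfloat16":
        config["bf16"] = {"enabled": True}
    elif dtype_name == "torch.float16":
        config["fp16"] = {"enabled": True}
    return config


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(build_config(args.dtype))
```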
Usage
Use this script to benchmark and validate the DeepSpeed offload states API in ZeRO Stage 3 scenarios. It is useful for understanding memory savings from state offloading, comparing pin_memory vs. non-pinned transfers, and measuring offload/reload latency for capacity planning.
Code Reference
Source Location
- Repository: Microsoft_DeepSpeedExamples
- File: training/offload_states/offload_states.py
- Lines: 1-152
Signature
class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        ...

    def forward(self, x, y):
        ...

def random_dataset(total_samples, hidden_dim, device, dtype):
    ...

def random_dataloader(model, total_samples, hidden_dim, device, dtype):
    ...

def run_model(model, config_dict, hidden_dim, dtype, include,
              pin_memory, non_blocking, iteration, warmup):
    ...

def main():
    ...
Import
from offload_states import SimpleModel, run_model, random_dataset, random_dataloader
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| --included_state | str | No | State type to include in offload (e.g., 'hp_params', 'lp_params', 'lp_grads', 'contiguous_grad_buffer') |
| --pin_memory | flag | No | Enable pinned memory for CPU offload transfers |
| --non_blocking | flag | No | Enable non-blocking CUDA transfers |
| --nlayers | int | No | Number of linear layers in the model (default: 1) |
| --hidden_dim | int | No | Hidden dimension size (default: 1024) |
| --dtype | str | No | Data type: 'torch.bfloat16', 'torch.float16', or 'torch.float32' (default: 'torch.bfloat16') |
| --iteration | int | No | Total number of training iterations (default: 10) |
| --warmup | int | No | Number of warmup iterations to exclude from timing (default: 5) |
Outputs
| Name | Type | Description |
|---|---|---|
| offload_states.log | file | CSV log file with pin_memory, non_blocking, offload_time, and load_time per run |
| stdout | text | Per-iteration memory usage and summary timing information |
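To compare configurations such as pinned and non-pinned transfers, the log can be post-processed with a short script. The sketch below assumes each line of offload_states.log holds four comma-separated fields in the order listed in the table above, with no header row. `summarize_log` is a hypothetical helper and is not part of the repository.

```python
import csv
from collections import defaultdict
from statistics import mean
from typing import Dict, Tuple


def summarize_log(path: str) -> Dict[Tuple[str, str], Tuple[float, float]]:
    """Group rows by (pin_memory, non_blocking) and average both timings.

    Assumed row layout: pin_memory, non_blocking, offload_time, load_time.
    """
    groups = defaultdict(list)
    with open(path, newline="") as f:
        for row in csv.reader(f):
            if len(row) != 4:
                continue  # skip blank or malformed lines
            pin, non_blocking, offload, load = (cell.strip() for cell in row)
            try:
                offload_s, load_s = float(offload), float(load)
            except ValueError:
                continue  # skip non-numeric rows, such as a header
            groups[(pin, non_blocking)].append((offload_s, load_s))
    return {
        key: (mean(o for o, _ in rows), mean(l for _, l in rows))
        for key, rows in groups.items()
    }


if __name__ == "__main__":
    for key, (offload_avg, load_avg) in summarize_log("offload_states.log").items():
        print(f"pin_memory={key[0]} non_blocking={key[1]} "
              f"offload={offload_avg:.6f}s load={load_avg:.6f}s")
```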
Usage Examples
# Run from command line with DeepSpeed launcher
# deepspeed offload_states.py --nlayers 4 --hidden_dim 2048 --pin_memory --non_blocking
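# Programmatic usage
# run_model initializes a DeepSpeed engine, so launch this under a
# distributed launcher (for example, deepspeed) from the directory that
# contains offload_states.py.
import torch
from offload_states import SimpleModel, run_model

model = SimpleModel(hidden_dim=1024, nlayers=2)
config_dict = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-6}},
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": True},
}
# include=None offloads all state types rather than a single selected one.
run_model(model, config_dict, hidden_dim=1024, dtype=torch.bfloat16,
          include=None, pin_memory=True, non_blocking=True,
          iteration=10, warmup=5)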