Implementation:Microsoft DeepSpeedExamples DeepSpeed Save Checkpoint
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation |
| Title | DeepSpeed_Save_Checkpoint |
| Repository | Microsoft/DeepSpeedExamples |
| Type | Wrapper Doc (wraps model_engine.save_checkpoint) |
| Code Reference | File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 363-371 |
| Import | import deepspeed, from deepspeed import comm as dist |
| Related Principle | Principle:Microsoft_DeepSpeedExamples_Distributed_Checkpoint_Saving |
Overview
Concrete usage of DeepSpeed's checkpoint saving for SuperOffload fine-tuned models. Wraps the model_engine.save_checkpoint() call with directory creation, tokenizer saving, and error handling.
Checkpoint Saving Implementation
Code Reference
File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 363-371
Implementation
    if args.save_checkpoint and dist.get_rank() == 0:
        try:
            logger.debug(f"Saving model to {args.output_dir}...")
            os.makedirs(args.output_dir, exist_ok=True)
            model_engine.save_checkpoint(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)
            logger.debug("Model saved successfully!")
        except Exception as e:
            logger.error(f"Error saving model: {e}")
Execution Flow
| Step | Operation | Description |
|---|---|---|
| 1 | Guard check | Verify args.save_checkpoint is True and current process is rank 0 |
| 2 | Directory creation | os.makedirs(args.output_dir, exist_ok=True) -- create output directory if needed |
| 3 | Model checkpoint | model_engine.save_checkpoint(args.output_dir) -- save DeepSpeed model and optimizer states |
| 4 | Tokenizer save | tokenizer.save_pretrained(args.output_dir) -- save tokenizer config and vocabulary |
| 5 | Error handling | Catch and log any exceptions without crashing the training process |
I/O Contract
Inputs
| Variable | Type | Description | Source |
|---|---|---|---|
| args.save_checkpoint | bool | Whether to save a checkpoint | CLI flag --save_checkpoint |
| args.output_dir | str | Directory path for saving the checkpoint | CLI argument --output_dir |
| model_engine | DeepSpeedEngine | The trained DeepSpeed model engine | From deepspeed.initialize() |
| tokenizer | AutoTokenizer | The HuggingFace tokenizer | From load_tokenizer() |
| dist.get_rank() | int | Current process rank in distributed training | DeepSpeed distributed communication |
Outputs
The checkpoint saving produces the following files in args.output_dir:
| File/Directory | Source | Description |
|---|---|---|
| global_step*/mp_rank_*_model_states.pt | model_engine.save_checkpoint() | Partitioned model parameter state |
| global_step*/zero_pp_rank_*_optim_states.pt | model_engine.save_checkpoint() | Partitioned optimizer state (Adam momentum and variance) |
| latest | model_engine.save_checkpoint() | Text file pointing to the latest checkpoint tag |
| tokenizer.json | tokenizer.save_pretrained() | Serialized fast tokenizer (vocabulary, merge rules, pre-processing pipeline) |
| tokenizer_config.json | tokenizer.save_pretrained() | Tokenizer settings (model_max_length, padding_side, etc.) |
| special_tokens_map.json | tokenizer.save_pretrained() | Special token definitions (bos, eos, pad, etc.) |
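To check which of these files a run actually produced, the `latest` pointer can be read and the matching tag directory listed. The following is a minimal sketch, assuming the layout in the table above; the helper name `describe_checkpoint` is hypothetical and not part of DeepSpeed or the example script.

```python
import os


def describe_checkpoint(output_dir):
    """Return (tag, sorted file names) for the checkpoint that `latest` points to.

    Hypothetical helper: assumes `output_dir/latest` holds a tag name and that
    `output_dir/<tag>/` contains the checkpoint shards.
    """
    latest_path = os.path.join(output_dir, "latest")
    with open(latest_path, "r", encoding="utf-8") as f:
        tag = f.read().strip()  # the tag name, such as a global_step directory

    tag_dir = os.path.join(output_dir, tag)
    files = sorted(os.listdir(tag_dir))
    return tag, files
```

Calling `describe_checkpoint(args.output_dir)` after a save is a quick way to confirm that both the model-state and optimizer-state shards are present before the job exits.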
model_engine.save_checkpoint()
The save_checkpoint method on the DeepSpeed engine performs the following operations internally:
- State collection -- Gathers the model state, optimizer state, and training metadata from the current rank.
- Directory creation -- Creates a subdirectory named global_step{N} within the specified output directory.
- State serialization -- Writes the partitioned model and optimizer states to disk using torch.save().
- Tag management -- Updates the latest file to point to the current checkpoint.
Method Signature (DeepSpeed API)
    model_engine.save_checkpoint(
        save_dir: str,
        tag: str = None,            # Optional checkpoint tag (defaults to global_step{N})
        client_state: dict = {},    # Optional additional state to save
        save_latest: bool = True    # Whether to update the "latest" pointer
    )
In the SuperOffload implementation, only save_dir is passed, using all defaults.
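For comparison, the optional arguments can carry an explicit tag and extra training state, which `load_checkpoint()` hands back as the second element of its return value. This is a minimal sketch rather than code from the example script: the helper names `save_with_metadata` and `resume`, the `step{N}` tag scheme, and the contents of `client_state` are all illustrative assumptions.

```python
def save_with_metadata(engine, save_dir, step, epoch):
    """Save under an explicit tag and attach extra state via client_state.

    `engine` is expected to behave like a DeepSpeed engine; the tag format
    below is a hypothetical naming scheme, not a DeepSpeed default.
    """
    tag = f"step{step}"
    engine.save_checkpoint(
        save_dir,
        tag=tag,
        client_state={"epoch": epoch, "step": step},
    )
    return tag


def resume(engine, load_dir, tag=None):
    """Load a checkpoint and return the saved client_state (or None).

    DeepSpeed's load_checkpoint returns (load_path, client_state); load_path
    is None when nothing could be loaded.
    """
    load_path, client_state = engine.load_checkpoint(load_dir, tag=tag)
    if load_path is None:
        return None
    return client_state
```

With `tag=None`, `load_checkpoint()` falls back to the tag recorded in the `latest` file, which is why `save_latest=True` is the convenient default.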
tokenizer.save_pretrained()
The tokenizer save is separate from the DeepSpeed checkpoint because the tokenizer is not part of the distributed model state. It saves:
- The tokenizer vocabulary and merge rules
- Configuration files for reloading the tokenizer
- Special token mappings
This ensures the tokenizer can be loaded alongside the model checkpoint for inference.
Guard Conditions
The checkpoint saving has two guard conditions:
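- args.save_checkpoint -- Must be explicitly enabled via the --save_checkpoint CLI flag. The launch scripts set SAVE_CHECKPOINT=false, so it is off by default.
- dist.get_rank() == 0 -- Only rank 0 enters the save block, which prevents duplicate directory creation and tokenizer saves. DeepSpeed coordinates the model-state save across ranks internally, but its documentation requires every rank to call save_checkpoint(), because each rank writes its own partition of the ZeRO state. With more than one process, a rank-0-only call can hang or leave an incomplete checkpoint, so the guard as written is safe only for single-process runs.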
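For multi-process runs, DeepSpeed's documentation asks every rank to call `save_checkpoint()`, so one possible restructuring keeps only the directory creation and tokenizer save behind the rank-0 check. The sketch below is an assumption about how that could look, not the code in finetune_zero3.py; the function name `save_on_all_ranks` and its parameters are hypothetical.

```python
import os


def save_on_all_ranks(engine, tokenizer, output_dir, rank, barrier=None):
    """Call save_checkpoint() on every rank; rank 0 handles the shared files.

    `engine` and `tokenizer` are duck-typed stand-ins for the DeepSpeed engine
    and the HuggingFace tokenizer. `barrier` is an optional callable such as
    deepspeed.comm.barrier.
    """
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
    if barrier is not None:
        barrier()  # wait until the directory exists before any rank writes

    # Every rank participates so each one can write its own state partition.
    engine.save_checkpoint(output_dir)

    if rank == 0:
        # The tokenizer is identical on all ranks, so one copy is enough.
        tokenizer.save_pretrained(output_dir)
```

In the training script this would be invoked on all ranks, for example as `save_on_all_ranks(model_engine, tokenizer, args.output_dir, dist.get_rank(), dist.barrier)`, with only the `args.save_checkpoint` check around it.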
Error Handling
The entire save operation is wrapped in a try-except block:
    try:
        # ... save operations ...
    except Exception as e:
        logger.error(f"Error saving model: {e}")
This ensures that checkpoint saving failures (disk full, permission errors, NFS issues) do not crash the training process. The error is logged but training continues (or completes normally if this is the final step).
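Because the exception is swallowed, the caller has no way to tell whether the checkpoint was written, and the log line carries only the message. A small variation, sketched here under the assumption that a boolean result is useful to the caller, logs the full traceback and reports success; `try_save` is a hypothetical helper, not part of the example script.

```python
import logging

logger = logging.getLogger(__name__)


def try_save(save_fn, *args, **kwargs):
    """Run save_fn, log any failure with its traceback, and return True/False."""
    try:
        save_fn(*args, **kwargs)
    except Exception:
        # logger.exception logs at ERROR level and appends the traceback.
        logger.exception("Error saving model")
        return False
    return True
```

A caller could then write `ok = try_save(model_engine.save_checkpoint, args.output_dir)` and decide whether to retry or skip follow-up steps that depend on the checkpoint.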
Enabling Checkpoint Saving
To enable checkpoint saving in the launch scripts, change the SAVE_CHECKPOINT variable:
# In finetune_llama-8b_1gpu.sh (or any launch script)
SAVE_CHECKPOINT=true # Changed from false to true
This adds the --save_checkpoint flag to the DeepSpeed launch command.
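On the Python side, a flag like this is typically declared as a boolean switch that defaults to off. The sketch below shows one way the two arguments named on this page could be declared with argparse; the help strings and the `./output` default are assumptions, and the actual definitions in finetune_zero3.py may differ.

```python
import argparse


def build_parser():
    """Build an illustrative parser for the two checkpoint-related arguments."""
    parser = argparse.ArgumentParser(description="Checkpoint options (illustrative)")
    # store_true makes the flag False unless --save_checkpoint is passed.
    parser.add_argument("--save_checkpoint", action="store_true",
                        help="Save a DeepSpeed checkpoint after training")
    # The default path here is a placeholder, not the script's real default.
    parser.add_argument("--output_dir", type=str, default="./output",
                        help="Directory for checkpoint and tokenizer files")
    return parser
```

Parsing an empty argument list leaves `save_checkpoint` as `False`, which matches the default-off behavior of the launch scripts.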
Usage Example
    import os
    import deepspeed
    from deepspeed import comm as dist

    # Assumes model_engine (from deepspeed.initialize()), tokenizer, and the
    # boolean save_checkpoint were defined earlier in the training script.

    # After training loop completes
    if save_checkpoint and dist.get_rank() == 0:
        try:
            output_dir = "./llama-8b_superoffload_output"
            os.makedirs(output_dir, exist_ok=True)
            # Save DeepSpeed checkpoint (model + optimizer states)
            model_engine.save_checkpoint(output_dir)
            # Save tokenizer for inference
            tokenizer.save_pretrained(output_dir)
            print("Checkpoint saved successfully!")
        except Exception as e:
            print(f"Error saving checkpoint: {e}")
WandB Cleanup
Immediately after checkpoint saving, the WandB run is finalized (Lines 373-378):
    if args.use_wandb and dist.get_rank() == 0:
        try:
            wandb.finish()
            logger.debug("WandB run finished successfully")
        except Exception as e:
            logger.error(f"Error finishing WandB run: {e}")
This ensures that all logged metrics are flushed and the WandB run is properly closed.