Implementation: Hugging Face Transformers DCP Save
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training, Checkpointing |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete tool for saving distributed training checkpoints using PyTorch's Distributed Checkpoint (DCP) API as used in the Hugging Face Transformers 3D parallel training example.
Description
This wrapper uses torch.distributed.checkpoint.save (dcp.save) to write the distributed training state to a checkpoint directory. The state is packaged in an AppState object, a Stateful wrapper that encapsulates both the model and the optimizer.
The AppState class implements:
- state_dict(): uses get_state_dict(model, optimizer) to extract distributed-aware state dicts that preserve DTensor metadata, FSDP flat-parameter structure, and TP placements.
- load_state_dict(state_dict): uses set_state_dict(model, optimizer, ...) to restore state from a checkpoint, handling redistribution across a potentially different mesh.
The checkpoint is saved to a directory named by the parallelism configuration (e.g., checkpoint_tp2_dp2_cp2). Each rank writes its local shard, and DCP coordinates the metadata to allow resharded loading.
For non-distributed training, the example falls back to the standard model.save_pretrained() API.
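The save dispatch described above can be sketched as a small helper. This is a minimal sketch, not code from the example script: `save_checkpoint` and its `app_state_cls` parameter are hypothetical names introduced here to show the branch between DCP and the standard fallback.

```python
# Hypothetical helper sketching the save dispatch described above.
import torch.distributed as dist
import torch.distributed.checkpoint as dcp


def save_checkpoint(model, optimizer, checkpoint_dir, app_state_cls):
    """Save with DCP when distributed, else fall back to save_pretrained()."""
    if dist.is_initialized():
        # Each rank writes its local shard; DCP coordinates the metadata.
        dcp.save(
            state_dict={"app": app_state_cls(model, optimizer)},
            checkpoint_id=checkpoint_dir,
        )
    else:
        # Non-distributed fallback: standard Hugging Face serialization.
        model.save_pretrained(checkpoint_dir)
```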
Usage
Use this wrapper at the end of training (or at periodic checkpoint intervals) when training in a distributed setting. It requires dist.is_initialized() to be True. The process group must remain active until after the save completes.
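For periodic checkpointing, one common pattern is to save every N steps under a step-suffixed directory. This is a hedged sketch of that pattern: the example script saves once at the end of training under a single topology-named directory, so the step suffix and the `maybe_save` helper are assumptions introduced here.

```python
# Hypothetical periodic-save helper; the step suffix is an assumption,
# the base directory name matches the example's convention.
import torch.distributed as dist
import torch.distributed.checkpoint as dcp


def checkpoint_dir(tp_size, dp_size, cp_size, step=None):
    """Build a checkpoint directory name from the parallelism configuration."""
    base = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
    return base if step is None else f"{base}_step{step}"


def maybe_save(step, every, model, optimizer, app_state_cls, tp, dp, cp):
    """Save a DCP checkpoint every `every` steps while the group is active."""
    if step % every == 0 and dist.is_initialized():
        dcp.save(
            state_dict={"app": app_state_cls(model, optimizer)},
            checkpoint_id=checkpoint_dir(tp, dp, cp, step),
        )
```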
Code Reference
Source Location
- Repository: transformers
- File: examples/3D_parallel.py
- Lines: 333-397 (save logic at 333-345, AppState class at 383-397)
Signature
dcp.save(
    state_dict={"app": AppState(model, optimizer)},
    checkpoint_id=CHECKPOINT_DIR,
)
Import
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
AppState Class
class AppState(Stateful):
    """Wrapper for checkpointing the Application State including model and optimizer."""

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model, self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| state_dict | dict[str, Stateful] | Yes | A dictionary mapping string keys to Stateful objects. Typically {"app": AppState(model, optimizer)}. |
| checkpoint_id | str | Yes | Path to the checkpoint directory, e.g. "checkpoint_tp2_dp2_cp2". |
Inputs (AppState)
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | The model (may be FSDP-wrapped with TP-sharded parameters). |
| optimizer | torch.optim.Optimizer | No | The optimizer whose state should be checkpointed alongside the model. |
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | Writes checkpoint files to the checkpoint_id directory. Each rank writes its local shard. |
Usage Examples
Basic Usage
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful


class AppState(Stateful):
    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model, self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


# Save checkpoint after training
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
if dist.is_initialized():
    state_dict = {"app": AppState(model, optimizer)}
    dcp.save(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)
Loading a Checkpoint
# Restore checkpoint (can use a different parallelism topology)
if dist.is_initialized():
    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)