
Implementation:Huggingface Transformers DCP Save

From Leeroopedia
Knowledge Sources
Domains Distributed_Computing, Training, Checkpointing
Last Updated 2026-02-13 00:00 GMT

Overview

Concrete tool for saving distributed training checkpoints using PyTorch's Distributed Checkpoint (DCP) API as used in the Hugging Face Transformers 3D parallel training example.

Description

This wrapper uses torch.distributed.checkpoint.save (dcp.save) to write the distributed training state to a checkpoint directory. The state is packaged in an AppState object -- a Stateful wrapper that encapsulates both the model and optimizer.

The AppState class implements:

  • state_dict(): Uses get_state_dict(model, optimizer) to extract distributed-aware state dicts that preserve DTensor metadata, FSDP flat parameter structure, and TP placements.
  • load_state_dict(state_dict): Uses set_state_dict(model, optimizer, ...) to restore state from a checkpoint, handling redistribution across a potentially different mesh.

The checkpoint is saved to a directory named by the parallelism configuration (e.g., checkpoint_tp2_dp2_cp2). Each rank writes its local shard, and DCP coordinates the metadata to allow resharded loading.
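The directory-naming scheme described above can be captured in a small helper. This is an illustrative sketch (the function name is ours, not from the example script):

```python
def checkpoint_dir(tp_size: int, dp_size: int, cp_size: int) -> str:
    """Build a checkpoint directory name from the parallelism configuration,
    mirroring the checkpoint_tp2_dp2_cp2 pattern used in the example."""
    return f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
```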

For non-distributed training, the example falls back to the standard model.save_pretrained() API.
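The fallback logic amounts to a single branch. A minimal sketch, with the two save backends injected as callables so the dispatch itself is clear (the names save_distributed and save_single are hypothetical; in the example they would wrap dcp.save(...) and model.save_pretrained(...) respectively):

```python
from typing import Callable


def save_checkpoint(
    is_distributed: bool,
    save_distributed: Callable[[], None],  # e.g. wraps dcp.save(...)
    save_single: Callable[[], None],       # e.g. wraps model.save_pretrained(...)
) -> str:
    """Dispatch between the DCP save path and the standard single-process path."""
    if is_distributed:
        save_distributed()  # every rank writes its local shard
        return "dcp"
    save_single()           # plain save_pretrained() fallback
    return "save_pretrained"
```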

Usage

Use this wrapper at the end of training (or at periodic checkpoint intervals) when training in a distributed setting. It requires dist.is_initialized() to be True. The process group must remain active until after the save completes.
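Periodic checkpointing at fixed step intervals can be gated with a simple predicate before calling dcp.save. This helper is our own illustration (not part of the example script); the interval value is up to the training loop:

```python
def should_checkpoint(step: int, interval: int, total_steps: int) -> bool:
    """True at every `interval`-th step (skipping step 0) and at the final step."""
    if step == total_steps:
        return True
    return step > 0 and step % interval == 0
```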

Code Reference

Source Location

  • Repository: transformers
  • File: examples/3D_parallel.py
  • Lines: 333-397 (save logic at 333-345, AppState class at 383-397)

Signature

dcp.save(
    state_dict={"app": AppState(model, optimizer)},
    checkpoint_id=CHECKPOINT_DIR,
)

Import

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

AppState Class

class AppState(Stateful):
    """Wrapper for checkpointing the Application State including model and optimizer."""

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model, self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )

I/O Contract

Inputs

  • state_dict (dict[str, Stateful], required): A dictionary mapping string keys to Stateful objects. Typically {"app": AppState(model, optimizer)}.
  • checkpoint_id (str, required): Path to the checkpoint directory, e.g. "checkpoint_tp2_dp2_cp2".

Inputs (AppState)

  • model (nn.Module, required): The model (may be FSDP-wrapped with TP-sharded parameters).
  • optimizer (torch.optim.Optimizer, optional): The optimizer whose state should be checkpointed alongside the model.

Outputs

  • Return value: None. As a side effect, dcp.save writes checkpoint files to the checkpoint_id directory; each rank writes its local shard.

Usage Examples

Basic Usage

import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful


class AppState(Stateful):
    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model, self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


# Save checkpoint after training
# (tp_size / dp_size / cp_size come from the run's parallelism configuration)
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
if dist.is_initialized():
    state_dict = {"app": AppState(model, optimizer)}
    dcp.save(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)

Loading a Checkpoint

# Restore checkpoint (can use different parallelism topology)
if dist.is_initialized():
    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)

Related Pages

Implements Principle

Requires Environment
