Implementation:Deepspeedai DeepSpeed ZeRO3 API Header

From Leeroopedia


Knowledge Sources
Domains: Graph_Compilation, ZeRO_Optimization, API_Interface
Last Updated: 2026-02-09 00:00 GMT

Overview

The ZeRO3 API header (z3.h) declares the public interface for ZeRO Stage 3 DeepCompile operations, including parameter allgather, prefetch, release, offload, and lifecycle management.

Description

The z3.h header file provides the complete public API surface for ZeRO Stage 3 operations within DeepSpeed's graph compilation framework. It declares:

  • Parameter Management: Functions for registering ZeRO-3 partitioned parameters and graph executors
  • Allgather Operations: Both single-parameter (allgather_param) and batched (prefetch_params_fused) parameter gathering
  • Memory Management: Parameter release, persistence control, and invalidation of gathered parameters for profiling
  • Synchronization: Wait operations for allgather completion
  • Offload/Reload: CPU offload and reload operations for both activations and parameters
  • Lifecycle Hooks: End-of-backward coordination
  • Meta Variants: Shape-only implementations for torch.compile's symbolic tracing

All functions are declared in the dc namespace and are designed to be called both from C++ and, through pybind11 bindings, from Python.

Usage

This header is included by z3.cpp (implementation), init.cpp (registration), and potentially user code needing direct C++ access to ZeRO-3 operations.

Code Reference

Source Location

Signature

#pragma once
#include "deepcompile.h"

namespace dc {

// Registration functions
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids);

void register_graph_ops_z3(long graph_id,
                           const std::vector<std::string>& op_names,
                           const std::vector<long>& n_args);

void register_bwd_graph_ops_z3(long graph_id,
                               const std::vector<std::string>& op_names,
                               const std::vector<long>& n_args);

void register_z3_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       bool persistent);

// Parameter operations
at::Tensor allgather_param(at::Tensor param_tensor,
                           long graph_id,
                           long ds_id,
                           std::optional<at::ScalarType> dtype);

void set_persistent(long ds_id);

void prefetch_params_fused(long graph_id,
                           const std::vector<at::Tensor>& params,
                           const std::vector<long>& ds_ids,
                           const std::optional<std::vector<at::ScalarType>>& dtypes);

at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);

at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id);

// Profiling support
void invalidate_gathered_param(long ds_id);
void clear_all_gathered_params();

// Meta variants (shape inference only)
at::Tensor allgather_param_meta(at::Tensor param_tensor,
                                long graph_id,
                                long ds_id,
                                std::optional<at::ScalarType> dtype);

void prefetch_params_fused_meta(long graph_id,
                                const std::vector<at::Tensor>& params,
                                const std::vector<long>& ds_ids,
                                const std::optional<std::vector<at::ScalarType>>& dtypes);

at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users);

at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id);

// Offload/Reload operations
at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id);
at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id);
at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id);
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id);

void reload_parameter(at::Tensor tensor, long graph_id, long id);
void offload_parameter(at::Tensor tensor, long graph_id, long id);
void reload_parameter_meta(at::Tensor tensor, long graph_id, long id);
void offload_parameter_meta(at::Tensor tensor, long graph_id, long id);

// Lifecycle
void end_backward(long graph_id);

}  // namespace dc

Import

// C++ usage
#include "z3.h"

// Register parameter
dc::register_z3_param(0, {1024, 768}, param_shard, grad_buffer, false);

// Register graph
dc::register_graph_z3(0, {0, 1, 2});

// Allgather in C++
at::Tensor gathered = dc::allgather_param(param_shard, 0, 0, std::nullopt);

# Python usage (via pybind11 bindings)
import torch
from deepspeed.ops import dc

# Register parameter
dc.register_z3_param(0, [1024, 768], param_shard, grad_buffer, False)

# Register graph
dc.register_graph_z3(0, [0, 1, 2])

# Allgather via custom operator
gathered = torch.ops.dc.allgather_param(param_shard, 0, 0, None)

I/O Contract

Registration Functions

| Function | Parameters | Description |
|---|---|---|
| register_graph_z3 | (long, vector<long>) | Creates a Z3CustomOpExecutor for the graph with the given parameter IDs |
| register_graph_ops_z3 | (long, vector<string>, vector<long>) | Registers forward graph operations (unused in the current implementation) |
| register_bwd_graph_ops_z3 | (long, vector<string>, vector<long>) | Registers backward graph operations (unused in the current implementation) |
| register_z3_param | (long, vector<int64_t>, Tensor, Tensor, bool) | Registers a partitioned parameter after validating it |
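register_z3_param is described as validating shard sizes before registering a parameter. The exact check lives in z3.cpp and is not reproduced on this page; the Python sketch below only illustrates the arithmetic under an assumed scheme in which every rank holds ceil(numel / world_size) elements and the tail is zero-padded. ShardSpec, expected_shard, validate_shard, and world_size are hypothetical names, not DeepSpeed APIs.

```python
import math
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ShardSpec:
    full_numel: int   # elements in the unpartitioned parameter (product of ds_shape)
    shard_numel: int  # elements each rank holds, including padding
    padding: int      # extra elements appended so all shards have equal size


def expected_shard(ds_shape: Sequence[int], world_size: int) -> ShardSpec:
    """Hypothetical helper: equal-size shards, padded at the end (assumed scheme)."""
    full_numel = math.prod(ds_shape)
    shard_numel = (full_numel + world_size - 1) // world_size  # integer ceil division
    return ShardSpec(full_numel, shard_numel, shard_numel * world_size - full_numel)


def validate_shard(ds_shape: Sequence[int], local_numel: int, world_size: int) -> ShardSpec:
    """Hypothetical check: reject a local shard whose size does not fit the scheme."""
    spec = expected_shard(ds_shape, world_size)
    if local_numel != spec.shard_numel:
        raise ValueError(
            f"shard has {local_numel} elements, expected {spec.shard_numel} "
            f"for shape {tuple(ds_shape)} on {world_size} ranks"
        )
    return spec


# The {1024, 768} parameter from the examples, split over 8 ranks:
# 98304 elements per rank and no padding.
spec = validate_shard([1024, 768], 98304, world_size=8)
```

A shape such as (10, 3) on 4 ranks would need 8 elements per rank with 2 padding elements under this assumed scheme.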

Parameter Operations

| Function | Parameters | Returns | Description |
|---|---|---|---|
| allgather_param | (Tensor, long, long, optional<ScalarType>) | Tensor | Gathers a partitioned parameter from all ranks |
| set_persistent | (long) | void | Marks a parameter as persistent (it stays gathered) |
| prefetch_params_fused | (long, vector<Tensor>, vector<long>, optional<vector<ScalarType>>) | void | Batches the allgathers of multiple parameters |
| release_param | (Tensor, long, long, long) | Tensor | Decrements the use count and frees the gathered copy of a non-persistent parameter |
| wait_allgather | (Tensor, long, long) | Tensor | Makes the current stream wait for the parameter's allgather to complete |
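The use-count behavior summarized for release_param, together with the persistence flag set by set_persistent, can be modeled in a few lines of pure Python. This is a sketch under stated assumptions, not the z3.cpp code: gathering is synchronous (so wait_allgather has no counterpart), shards are plain lists, and the first release_param call for a parameter arms a counter with n_users that later calls decrement. Z3ParamModel and its methods are hypothetical.

```python
from typing import Dict, List


class Z3ParamModel:
    """Hypothetical model of allgather_param / set_persistent / release_param."""

    def __init__(self) -> None:
        self.shards: Dict[int, List[List[float]]] = {}  # ds_id -> one flat shard per rank
        self.numel: Dict[int, int] = {}                  # ds_id -> element count without padding
        self.persistent: Dict[int, bool] = {}
        self.gathered: Dict[int, List[float]] = {}       # ds_id -> gathered full parameter
        self.use_count: Dict[int, int] = {}              # ds_id -> releases still expected

    def register(self, ds_id: int, numel: int, shards: List[List[float]],
                 persistent: bool = False) -> None:
        self.shards[ds_id] = shards
        self.numel[ds_id] = numel
        self.persistent[ds_id] = persistent

    def set_persistent(self, ds_id: int) -> None:
        self.persistent[ds_id] = True

    def allgather_param(self, ds_id: int) -> List[float]:
        # Reuse a gathered copy that is still alive (e.g. a persistent parameter).
        if ds_id not in self.gathered:
            flat = [x for shard in self.shards[ds_id] for x in shard]  # concatenate rank shards
            self.gathered[ds_id] = flat[: self.numel[ds_id]]           # drop padding
        return self.gathered[ds_id]

    def release_param(self, ds_id: int, n_users: int) -> None:
        # Assumption: a counter at 0 is not armed yet, so arm it with n_users first.
        count = self.use_count.get(ds_id, 0) or n_users
        count -= 1
        self.use_count[ds_id] = count
        if count == 0 and not self.persistent[ds_id]:
            self.gathered.pop(ds_id, None)  # free the gathered copy


model = Z3ParamModel()
model.register(0, numel=5, shards=[[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]])  # 2 ranks, 1 padding element
full = model.allgather_param(0)    # [1.0, 2.0, 3.0, 4.0, 5.0]
model.release_param(0, n_users=2)  # one of two users done: still gathered
still_gathered = 0 in model.gathered
model.release_param(0, n_users=2)  # last user done: freed
```

A persistent parameter goes through the same counting but is never freed, which matches the "keep gathered" description of set_persistent.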

Profiling Functions

| Function | Parameters | Description |
|---|---|---|
| invalidate_gathered_param | (long) | Invalidates the gathered copy of a single parameter (profiling) |
| clear_all_gathered_params | () | Clears all non-persistent gathered parameters |
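The two profiling helpers differ in scope: invalidate_gathered_param targets one ds_id, while clear_all_gathered_params drops every gathered copy except persistent ones. The toy cache below makes that difference concrete. GatheredCacheModel is hypothetical, and letting invalidate_gathered_param apply to persistent parameters too is an assumption of this model rather than something this page states.

```python
from typing import Dict, Iterable, List


class GatheredCacheModel:
    """Hypothetical cache of gathered (full) parameters, keyed by ds_id."""

    def __init__(self, persistent_ids: Iterable[int]) -> None:
        self.persistent = set(persistent_ids)
        self.gathered: Dict[int, str] = {}

    def gather(self, ds_id: int) -> None:
        self.gathered[ds_id] = f"full_param_{ds_id}"  # placeholder for a gathered tensor

    def invalidate_gathered_param(self, ds_id: int) -> None:
        # Single-parameter invalidation (applied regardless of persistence in this model).
        self.gathered.pop(ds_id, None)

    def clear_all_gathered_params(self) -> None:
        # Drop every non-persistent gathered copy; persistent ones survive.
        self.gathered = {i: t for i, t in self.gathered.items() if i in self.persistent}

    def ids(self) -> List[int]:
        return sorted(self.gathered)


cache = GatheredCacheModel(persistent_ids=[0])
for ds_id in (0, 1, 2):
    cache.gather(ds_id)
cache.invalidate_gathered_param(1)  # gathered ids: [0, 2]
cache.clear_all_gathered_params()   # gathered ids: [0]
```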

Meta Variants

| Function | Purpose |
|---|---|
| allgather_param_meta | Returns an empty tensor with the correct shape/dtype for symbolic tracing |
| prefetch_params_fused_meta | No-op for symbolic tracing |
| release_param_meta | Pass-through for symbolic tracing |
| wait_allgather_meta | Pass-through for symbolic tracing |
| reload_parameter_meta | No-op for symbolic tracing |
| offload_parameter_meta | No-op for symbolic tracing |

Offload/Reload Operations

| Function | Parameters | Returns | Description |
|---|---|---|---|
| offload_tensor | (Tensor, long, long) | Tensor | Offloads an activation to CPU pinned memory |
| reload_tensor | (Tensor, long, long) | Tensor | Reloads an activation from CPU to GPU |
| wait_offload | (Tensor, long, long) | Tensor | Waits for offload completion |
| wait_reload | (Tensor, long, long) | Tensor | Waits for reload completion |
| offload_parameter | (Tensor, long, long) | void | Offloads a parameter to CPU |
| reload_parameter | (Tensor, long, long) | void | Reloads a parameter from CPU |
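The activation offload/reload functions come in start/wait pairs: offload_tensor and reload_tensor begin a copy, and wait_offload and wait_reload block until it has finished, which lets the copy overlap with other work. The sketch below reproduces that split-phase pattern with a single worker thread standing in for a copy stream and Python lists standing in for device and pinned host buffers. OffloadModel is hypothetical, and the graph_id argument is omitted.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List


class OffloadModel:
    """Hypothetical split-phase offload/reload, keyed by tensor id."""

    def __init__(self) -> None:
        self._copier = ThreadPoolExecutor(max_workers=1)  # stands in for a copy stream
        self._pending: Dict[int, Future] = {}
        self.host: Dict[int, List[float]] = {}              # stands in for pinned CPU memory

    def offload_tensor(self, data: List[float], tensor_id: int) -> List[float]:
        # Start the device-to-host copy and return immediately.
        self._pending[tensor_id] = self._copier.submit(list, data)
        return data

    def wait_offload(self, data: List[float], tensor_id: int) -> List[float]:
        # Block until the host copy exists; only then may the device copy be freed.
        self.host[tensor_id] = self._pending.pop(tensor_id).result()
        return data

    def reload_tensor(self, data: List[float], tensor_id: int) -> List[float]:
        # Start the host-to-device copy from the saved host buffer.
        self._pending[tensor_id] = self._copier.submit(list, self.host[tensor_id])
        return data

    def wait_reload(self, data: List[float], tensor_id: int) -> List[float]:
        # Block until the reloaded buffer is ready and hand it back.
        return self._pending.pop(tensor_id).result()

    def close(self) -> None:
        self._copier.shutdown(wait=True)


model = OffloadModel()
activation = [0.5, 1.5, 2.5]
model.offload_tensor(activation, 100)
model.wait_offload(activation, 100)
model.reload_tensor(activation, 100)
reloaded = model.wait_reload(activation, 100)
model.close()
```

The id 100 mirrors the C++ example; in this model it simply keys the pending copy and the host buffer.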

Lifecycle

| Function | Parameters | Description |
|---|---|---|
| end_backward | (long) | Flushes reduce buckets and synchronizes streams |
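end_backward is summarized as flushing reduce buckets. A hedged sketch of why a final flush is needed: gradients are queued into a fixed-capacity bucket that launches a reduction whenever the next gradient would overflow it, so the last, partially filled bucket is only reduced when end_backward runs. ReduceBucketModel, its capacity unit (elements), and the reduce_fn callback are assumptions for illustration, not the DeepSpeed implementation.

```python
from typing import Callable, List, Tuple

Bucket = List[Tuple[int, int]]  # (ds_id, numel) entries


class ReduceBucketModel:
    """Hypothetical gradient bucket: reduce when full, flush the remainder at end_backward."""

    def __init__(self, capacity: int, reduce_fn: Callable[[Bucket], None]) -> None:
        self.capacity = capacity    # max elements per bucket (assumed unit)
        self.reduce_fn = reduce_fn  # stands in for launching a reduction
        self.items: Bucket = []
        self.size = 0

    def add_grad(self, ds_id: int, numel: int) -> None:
        # Launch the current bucket first if this gradient would overflow it.
        if self.items and self.size + numel > self.capacity:
            self._flush()
        self.items.append((ds_id, numel))
        self.size += numel

    def _flush(self) -> None:
        self.reduce_fn(self.items)
        self.items, self.size = [], 0

    def end_backward(self) -> None:
        # Without this, gradients left in a partial bucket would never be reduced.
        if self.items:
            self._flush()


launched: List[List[int]] = []
bucket = ReduceBucketModel(100, lambda items: launched.append([i for i, _ in items]))
bucket.add_grad(0, 60)
bucket.add_grad(1, 30)
bucket.add_grad(2, 50)  # 90 + 50 > 100: reduces [0, 1], then queues 2
bucket.end_backward()   # reduces the leftover [2]
```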

Usage Examples

// C++ usage example
#include "z3.h"
#include <torch/torch.h>

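// Assumed to be defined elsewhere: num_params, param_shards, and grad_buffers
// (one partitioned shard and one gradient buffer per parameter)
// Register ZeRO-3 parameters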
std::vector<long> param_ids;
for (int i = 0; i < num_params; i++) {
    dc::register_z3_param(
        i,                          // ds_id
        {1024, 768},               // full (unpartitioned) shape
        param_shards[i],           // partitioned tensor
        grad_buffers[i],           // gradient buffer
        false                      // persistent
    );
    param_ids.push_back(i);
}

// Register graph
dc::register_graph_z3(0, param_ids);

// Mark first parameter as persistent
dc::set_persistent(0);

// Allgather parameter
at::Tensor gathered = dc::allgather_param(
    param_shards[0],
    0,                              // graph_id
    0,                              // ds_id
    std::optional<at::ScalarType>(at::kHalf)   // gather as fp16
);
gathered = dc::wait_allgather(gathered, 0, 0);

// Prefetch multiple parameters
std::vector<at::Tensor> param_list = {param_shards[1], param_shards[2]};
std::vector<long> id_list = {1, 2};
std::vector<at::ScalarType> dtype_list = {at::kHalf, at::kBFloat16};
dc::prefetch_params_fused(0, param_list, id_list, dtype_list);

// Release parameter after use
at::Tensor dummy = torch::empty({0});
dummy = dc::release_param(dummy, 0, 0, 3);  // 3 users

// Offload/reload activations
at::Tensor activation = torch::randn({1024, 1024}, at::kCUDA);
at::Tensor offloaded = dc::offload_tensor(activation, 0, 100);
offloaded = dc::wait_offload(offloaded, 0, 100);

at::Tensor reloaded = dc::reload_tensor(offloaded, 0, 100);
reloaded = dc::wait_reload(reloaded, 0, 100);

// End backward pass
dc::end_backward(0);

# Python usage via custom operators
import torch
from deepspeed.ops import dc

# Register parameters (Python binding)
dc.register_z3_param(0, [1024, 768], param_shard, grad_buffer, False)
dc.register_graph_z3(0, [0, 1, 2])

# Use custom operators in torch.compile'd code
# (x is assumed to be an fp16 input on the same device as the gathered parameter)
@torch.compile
def forward_with_z3_ops(x, shard, graph_id, param_id):
    # Allgather the full parameter and wait for the collective to complete
    param = torch.ops.dc.allgather_param(shard, graph_id, param_id, torch.float16)
    param = torch.ops.dc.wait_allgather(param, graph_id, param_id)

    # Use the gathered parameter
    output = torch.matmul(x, param.T)

    # Release the gathered copy (this parameter has one user)
    dummy = torch.empty(0)
    dummy = torch.ops.dc.release_param(dummy, graph_id, param_id, 1)

    return output

# Prefetch multiple parameters (p1 and p2 are the local shards of ds_ids 1 and 2)
torch.ops.dc.prefetch_params_fused(
    0,                                  # graph_id
    [p1, p2],                           # params
    [1, 2],                             # ds_ids
    [torch.float16, torch.bfloat16],    # dtypes
)

# End backward
torch.ops.dc.end_backward(0)

API Design Principles

Dual Interface

Functions are designed to be called from two sides:

  1. C++: z3.cpp implements the functions, delegating to Z3CustomOpExecutor methods
  2. Python: init.cpp exposes the functions through pybind11 bindings

Optional Parameters

Nullable parameters use std::optional:

  • dtype: passing None (std::nullopt in C++) keeps the parameter's original dtype
  • dtypes: an optional list that gives prefetch_params_fused one dtype per parameter
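The optional dtype arguments have simple semantics that are easy to state as code. The helpers below are hypothetical and use strings in place of at::ScalarType: a missing dtype keeps the parameter's own dtype, and a dtypes list supplies one entry per prefetched parameter. The length check is an assumption of this sketch.

```python
from typing import List, Optional, Sequence


def resolve_dtype(original: str, dtype: Optional[str]) -> str:
    """allgather_param-style argument: None keeps the original dtype."""
    return original if dtype is None else dtype


def resolve_dtypes(originals: Sequence[str], dtypes: Optional[Sequence[str]]) -> List[str]:
    """prefetch_params_fused-style argument: None keeps every dtype, else one per parameter."""
    if dtypes is None:
        return list(originals)
    if len(dtypes) != len(originals):
        raise ValueError(f"expected {len(originals)} dtypes, got {len(dtypes)}")
    return list(dtypes)


# Mirrors the C++ example: two parameters gathered as fp16 and bf16.
print(resolve_dtypes(["float32", "float32"], ["float16", "bfloat16"]))  # ['float16', 'bfloat16']
```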

Meta Variants

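The gather, release, and parameter offload/reload operations have _meta variants (offload_tensor, reload_tensor, wait_offload, and wait_reload do not):

  • Enable torch.compile symbolic tracing without actual execution
  • Tensor-returning variants return tensors with the correct shape/dtype but no data; void variants are no-ops
  • Critical for graph capture in profile mode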
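The contract of a meta variant can be illustrated without PyTorch: it computes only the output's shape and dtype and touches no data. In the sketch below, FakeTensor is a hypothetical shape/dtype record, and allgather_param_meta receives the full shape directly. The real function takes (param_tensor, graph_id, ds_id, dtype), so how it obtains the full shape (for example, from the registered ds_shape) is an assumption here.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a meta tensor: shape and dtype only, no storage."""
    shape: Tuple[int, ...]
    dtype: str


def allgather_param_meta(shard: FakeTensor, ds_shape: Sequence[int],
                         dtype: Optional[str]) -> FakeTensor:
    # The output has the full parameter shape (not the shard shape) and the requested dtype.
    return FakeTensor(tuple(ds_shape), shard.dtype if dtype is None else dtype)


def release_param_meta(dummy: FakeTensor) -> FakeTensor:
    return dummy  # pass-through: only the dependency matters during tracing


def prefetch_params_fused_meta(*_args) -> None:
    return None  # no-op: prefetching produces no output to describe


shard = FakeTensor((98304,), "float32")  # 1/8 of a 1024 x 768 parameter
full = allgather_param_meta(shard, (1024, 768), "float16")
```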

Graph ID Scoping

Most operations require a graph_id parameter:

  • Allows multiple independent compiled graphs
  • Each graph has its own Z3CustomOpExecutor instance
  • Enables modular compilation of model components
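Graph ID scoping amounts to keeping one executor object per compiled graph, looked up by graph_id on every call, so that use counts and similar per-graph state never leak between graphs. A minimal sketch follows; GraphExecutorModel, register_graph, and get_executor are hypothetical names that only mirror the roles of register_graph_z3 and Z3CustomOpExecutor.

```python
from typing import Dict, List


class GraphExecutorModel:
    """Hypothetical per-graph executor holding state private to one compiled graph."""

    def __init__(self, graph_id: int, ds_ids: List[int]) -> None:
        self.graph_id = graph_id
        self.ds_ids = list(ds_ids)
        self.use_count: Dict[int, int] = {}


_executors: Dict[int, GraphExecutorModel] = {}


def register_graph(graph_id: int, ds_ids: List[int]) -> None:
    # Analogous in role to register_graph_z3: one executor per graph_id.
    _executors[graph_id] = GraphExecutorModel(graph_id, ds_ids)


def get_executor(graph_id: int) -> GraphExecutorModel:
    if graph_id not in _executors:
        raise KeyError(f"graph {graph_id} has not been registered")
    return _executors[graph_id]


register_graph(0, [0, 1, 2])  # e.g. one model component compiled as graph 0
register_graph(1, [3, 4])     # e.g. another component compiled as graph 1
get_executor(0).use_count[0] = 3
```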

Dummy Tensor Pattern

release_param takes and returns a dummy tensor:

  • Enables dependency tracking in computation graph
  • Ensures proper execution order without actual data flow
  • Common pattern for side-effect operations in graph mode
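Why a dummy tensor helps can be shown with a plain dependency graph. A graph compiler may run any operation as soon as its inputs are ready, so an operation with only side effects needs a data edge to pin its position. The sketch below assumes the dummy consumed by release_param is tied to the output of the parameter's last user (here a matmul); the node names are illustrative, and Python's standard graphlib module (3.9+) stands in for the compiler's scheduler.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set

Graph = Dict[str, Set[str]]  # node -> set of nodes it depends on


def depends_on(graph: Graph, node: str, target: str) -> bool:
    """True if `node` transitively depends on `target`."""
    stack, seen = list(graph.get(node, ())), set()
    while stack:
        cur = stack.pop()
        if cur == target:
            return True
        if cur not in seen:
            seen.add(cur)
            stack.extend(graph.get(cur, ()))
    return False


# Without a token, release_param only depends on the gather, so a scheduler may
# legally free the parameter before the matmul reads it.
no_token: Graph = {
    "wait_allgather": {"allgather_param"},
    "matmul": {"wait_allgather"},
    "release_param": {"wait_allgather"},
}

# With a token, release_param consumes a dummy tied to the matmul's output,
# so every valid schedule frees the parameter after its last use.
with_token: Graph = {
    "wait_allgather": {"allgather_param"},
    "matmul": {"wait_allgather"},
    "dummy": {"matmul"},
    "release_param": {"dummy"},
}

order: List[str] = list(TopologicalSorter(with_token).static_order())
```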

Relationship to Implementation

| Header Function | Implementation | Description |
|---|---|---|
| register_z3_param | z3.cpp:register_z3_param | Validates shard sizes, registers in param_registry |
| register_graph_z3 | z3.cpp:register_graph_z3 | Creates a Z3CustomOpExecutor instance |
| allgather_param | z3.cpp:allgather_param -> Z3CustomOpExecutor::allgatherParam | Performs an NCCL allgather or a symmetric memory gather |
| prefetch_params_fused | z3.cpp:prefetch_params_fused -> Z3CustomOpExecutor::prefetchParamsFused | Batches allgathers with ncclGroupStart/End |
| release_param | z3.cpp:release_param -> Z3CustomOpExecutor::releaseParam | Reference-counts and frees non-persistent params |
| wait_allgather | z3.cpp:wait_allgather -> Z3CustomOpExecutor::waitAllgather | Blocks the current stream on allgather completion |
| offload_tensor | z3.cpp:offload_tensor -> Z3CustomOpExecutor::offloadTensor | Async copy to pinned CPU memory |
| reload_tensor | z3.cpp:reload_tensor -> Z3CustomOpExecutor::reloadTensor | Async copy back to GPU |
