Implementation:Deepspeedai DeepSpeed ZeRO3 API Header
| Knowledge Sources | |
|---|---|
| Domains | Graph_Compilation, ZeRO_Optimization, API_Interface |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
ZeRO3 API Header declares the public interface for ZeRO Stage 3 DeepCompile operations including parameter allgather, prefetch, release, offload, and lifecycle management.
Description
The z3.h header file provides the complete public API surface for ZeRO Stage 3 operations within DeepSpeed's graph compilation framework. It declares:
- Parameter Management: Functions for registering ZeRO-3 partitioned parameters and graph executors
- Allgather Operations: Both single-parameter (allgather_param) and batched (prefetch_params_fused) parameter gathering
- Memory Management: Parameter release, persistence control, and invalidation of gathered parameters for profiling
- Synchronization: Wait operations for allgather completion
- Offload/Reload: CPU offload and reload operations for both activations and parameters
- Lifecycle Hooks: End-of-backward coordination
- Meta Variants: Shape-only implementations for torch.compile's symbolic tracing
All functions are declared in the dc namespace. They are designed to be called both from C++ code and from Python through pybind11 bindings.
Usage
This header is included by z3.cpp (implementation), init.cpp (registration), and potentially user code needing direct C++ access to ZeRO-3 operations.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/compile/z3.h
Signature
```cpp
#pragma once
#include "deepcompile.h"

namespace dc {

// Registration functions
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids);
void register_graph_ops_z3(long graph_id,
                           const std::vector<std::string>& op_names,
                           const std::vector<long>& n_args);
void register_bwd_graph_ops_z3(long graph_id,
                               const std::vector<std::string>& op_names,
                               const std::vector<long>& n_args);
void register_z3_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       bool persistent);

// Parameter operations
at::Tensor allgather_param(at::Tensor param_tensor,
                           long graph_id,
                           long ds_id,
                           std::optional<at::ScalarType> dtype);
void set_persistent(long ds_id);
void prefetch_params_fused(long graph_id,
                           const std::vector<at::Tensor>& params,
                           const std::vector<long>& ds_ids,
                           const std::optional<std::vector<at::ScalarType>>& dtypes);
at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);
at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id);

// Profiling support
void invalidate_gathered_param(long ds_id);
void clear_all_gathered_params();

// Meta variants (shape inference only)
at::Tensor allgather_param_meta(at::Tensor param_tensor,
                                long graph_id,
                                long ds_id,
                                std::optional<at::ScalarType> dtype);
void prefetch_params_fused_meta(long graph_id,
                                const std::vector<at::Tensor>& params,
                                const std::vector<long>& ds_ids,
                                const std::optional<std::vector<at::ScalarType>>& dtypes);
at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users);
at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id);

// Offload/Reload operations
at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id);
at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id);
at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id);
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id);
void reload_parameter(at::Tensor tensor, long graph_id, long id);
void offload_parameter(at::Tensor tensor, long graph_id, long id);
void reload_parameter_meta(at::Tensor tensor, long graph_id, long id);
void offload_parameter_meta(at::Tensor tensor, long graph_id, long id);

// Lifecycle
void end_backward(long graph_id);

} // namespace dc
```
Import
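```cpp
// C++ usage
#include "z3.h"

// param_shard, grad_buffer, and shard are assumed to be existing at::Tensor objects.
// Register parameter 0
dc::register_z3_param(0, {1024, 768}, param_shard, grad_buffer, false);
// Register graph 0 (parameters 1 and 2 are assumed to be registered the same way)
dc::register_graph_z3(0, {0, 1, 2});
// Allgather in C++ (std::nullopt keeps the original dtype)
at::Tensor gathered = dc::allgather_param(shard, 0, 0, std::nullopt);
```

```python
# Python usage (via pybind11 bindings)
import torch
from deepspeed.ops import dc

# param_shard, grad_buffer, and shard are assumed to be existing tensors.
# Register parameter
dc.register_z3_param(0, [1024, 768], param_shard, grad_buffer, False)
# Register graph
dc.register_graph_z3(0, [0, 1, 2])
# Allgather via custom operator (None keeps the original dtype)
gathered = torch.ops.dc.allgather_param(shard, 0, 0, None)
```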
I/O Contract
Registration Functions
| Function | Parameters | Description |
|---|---|---|
| register_graph_z3 | (long, vector<long>) | Creates Z3CustomOpExecutor for graph with parameter IDs |
| register_graph_ops_z3 | (long, vector<string>, vector<long>) | Registers forward graph operations (unused in current impl) |
| register_bwd_graph_ops_z3 | (long, vector<string>, vector<long>) | Registers backward graph operations (unused in current impl) |
| register_z3_param | (long, vector<int64_t>, Tensor, Tensor, bool) | Registers a partitioned parameter after validating its shard size |
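The exact checks in register_z3_param are not spelled out here. The hypothetical sketch below gives a rough idea of a shard-size check. It assumes each rank holds an equal slice of the flattened parameter, padded up with ceiling division, and rejects shards of any other size. `expected_shard_numel` and `validate_shard` are illustrative names, not DeepSpeed functions.

```python
from math import prod
from typing import Sequence


def expected_shard_numel(ds_shape: Sequence[int], world_size: int) -> int:
    """Elements per rank if the flattened parameter is padded to split evenly.

    Ceiling division is an assumption for illustration, not DeepSpeed's rule.
    """
    numel = prod(ds_shape)
    return (numel + world_size - 1) // world_size


def validate_shard(ds_shape: Sequence[int], shard_numel: int, world_size: int) -> int:
    """Raise if a local shard does not match the assumed partition size."""
    expected = expected_shard_numel(ds_shape, world_size)
    if shard_numel != expected:
        raise ValueError(
            f"shard has {shard_numel} elements, expected {expected} "
            f"for shape {list(ds_shape)} on {world_size} ranks")
    return expected


print(validate_shard([1024, 768], 98304, 8))  # 98304
```

Integer ceiling division avoids floating-point rounding for very large parameters.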
Parameter Operations
| Function | Parameters | Returns | Description |
|---|---|---|---|
| allgather_param | (Tensor, long, long, optional<ScalarType>) | Tensor | Gathers partitioned parameter from all ranks |
| set_persistent | (long) | void | Marks parameter as persistent (keep gathered) |
| prefetch_params_fused | (long, vector<Tensor>, vector<long>, optional<vector<ScalarType>>) | void | Batches multiple allgathers |
| release_param | (Tensor, long, long, long) | Tensor | Decrements use count, frees non-persistent params |
| wait_allgather | (Tensor, long, long) | Tensor | Makes the current stream wait for the parameter's allgather to complete |
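The hypothetical pure-Python sketch below models how set_persistent and release_param interact; it is not the z3.cpp logic. It assumes release_param is invoked once per user. When the n_users-th call arrives, it frees the gathered buffer, unless the parameter was marked persistent. The class and attribute names are made up for illustration.

```python
from typing import Any, Dict, Set


class ReleaseModel:
    """Hypothetical bookkeeping for set_persistent/release_param."""

    def __init__(self) -> None:
        self.gathered: Dict[int, Any] = {}  # ds_id -> gathered buffer
        self.persistent: Set[int] = set()   # ds_ids kept gathered
        self.releases: Dict[int, int] = {}  # ds_id -> release calls so far

    def set_persistent(self, ds_id: int) -> None:
        self.persistent.add(ds_id)

    def allgather_param(self, ds_id: int, buffer: Any) -> Any:
        self.gathered[ds_id] = buffer
        return buffer

    def release_param(self, dummy: Any, ds_id: int, n_users: int) -> Any:
        count = self.releases.get(ds_id, 0) + 1
        if count < n_users:
            self.releases[ds_id] = count  # other users still need the buffer
        else:
            self.releases[ds_id] = 0      # last user done: reset for next step
            if ds_id not in self.persistent:
                self.gathered.pop(ds_id, None)
        return dummy  # returned unchanged to keep the graph dependency


m = ReleaseModel()
m.allgather_param(1, "full weight 1")
m.release_param(None, 1, n_users=2)
print(1 in m.gathered)  # True: one user left
m.release_param(None, 1, n_users=2)
print(1 in m.gathered)  # False: freed after the last user
```

Counting calls up to n_users is equivalent to decrementing a remaining-user counter, as the table describes.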
Profiling Functions
| Function | Parameters | Description |
|---|---|---|
| invalidate_gathered_param | (long) | Invalidates the gathered copy of a single parameter (profiling) |
| clear_all_gathered_params | () | Clears all non-persistent gathered parameters |
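Here is a minimal model of the two profiling helpers. It assumes gathered buffers live in a dictionary keyed by ds_id. Invalidating one parameter drops just that entry, while clearing all drops every entry except persistent ones. Whether invalidate_gathered_param also applies to persistent parameters is an assumption (the sketch drops the entry unconditionally).

```python
from typing import Any, Dict, Iterable


class GatheredCache:
    """Hypothetical view of gathered parameters kept between ops."""

    def __init__(self, persistent: Iterable[int] = ()) -> None:
        self.gathered: Dict[int, Any] = {}
        self.persistent = set(persistent)

    def invalidate_gathered_param(self, ds_id: int) -> None:
        # Assumed: the entry is dropped even if the parameter is persistent,
        # so the next profiled run has to gather it again.
        self.gathered.pop(ds_id, None)

    def clear_all_gathered_params(self) -> None:
        # Persistent parameters stay resident; everything else is dropped.
        for ds_id in list(self.gathered):
            if ds_id not in self.persistent:
                del self.gathered[ds_id]


cache = GatheredCache(persistent=[0])
cache.gathered = {0: "w0", 1: "w1", 2: "w2"}
cache.invalidate_gathered_param(1)
cache.clear_all_gathered_params()
print(sorted(cache.gathered))  # [0]
```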
Meta Variants
| Function | Purpose |
|---|---|
| allgather_param_meta | Returns empty tensor with correct shape/dtype for symbolic tracing |
| prefetch_params_fused_meta | No-op for symbolic tracing |
| release_param_meta | Pass-through for symbolic tracing |
| wait_allgather_meta | Pass-through for symbolic tracing |
| reload_parameter_meta | No-op for symbolic tracing |
| offload_parameter_meta | No-op for symbolic tracing |
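The hypothetical sketch below shows what shape-only execution means. Real tensors are replaced with a `FakeTensor` record that carries only a shape and a dtype. It assumes allgather_param_meta reports the full shape registered through register_z3_param rather than the shard's shape; here that shape comes from a hard-coded `DS_SHAPES` table. It also assumes a None dtype keeps the input's dtype. The pass-through and no-op variants follow the table above.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a meta tensor: it carries shape and dtype but no data."""
    shape: Tuple[int, ...]
    dtype: str


# Hypothetical full (unpartitioned) shapes, as given to register_z3_param.
DS_SHAPES = {0: (1024, 768)}


def allgather_param_meta(shard: FakeTensor, graph_id: int, ds_id: int,
                         dtype: Optional[str] = None) -> FakeTensor:
    # Assumed: the result has the full registered shape, not the shard's shape,
    # and dtype=None keeps the shard's dtype.
    out_dtype = shard.dtype if dtype is None else dtype
    return FakeTensor(DS_SHAPES[ds_id], out_dtype)


def wait_allgather_meta(v: FakeTensor, graph_id: int, ds_id: int) -> FakeTensor:
    return v  # pass-through: no communication to wait for


def release_param_meta(dummy: FakeTensor, graph_id: int, ds_id: int,
                       n_users: int) -> FakeTensor:
    return dummy  # pass-through: nothing to free


def prefetch_params_fused_meta(graph_id: int, params: Sequence[FakeTensor],
                               ds_ids: Sequence[int],
                               dtypes: Optional[Sequence[str]] = None) -> None:
    return None  # no-op during tracing


# Tracing sees only shapes: a 98304-element shard becomes a (1024, 768) parameter.
shard = FakeTensor((98304,), "float32")
param = wait_allgather_meta(allgather_param_meta(shard, 0, 0, "float16"), 0, 0)
print(param)  # FakeTensor(shape=(1024, 768), dtype='float16')
```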
Offload/Reload Operations
| Function | Parameters | Returns | Description |
|---|---|---|---|
| offload_tensor | (Tensor, long, long) | Tensor | Offloads activation to CPU pinned memory |
| reload_tensor | (Tensor, long, long) | Tensor | Reloads activation from CPU to GPU |
| wait_offload | (Tensor, long, long) | Tensor | Waits for offload completion |
| wait_reload | (Tensor, long, long) | Tensor | Waits for reload completion |
| offload_parameter | (Tensor, long, long) | void | Offloads parameter to CPU |
| reload_parameter | (Tensor, long, long) | void | Reloads parameter from CPU |
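The offload/wait pairs split each transfer into an asynchronous launch and a later synchronization point. The sketch below models that split with a single background worker thread standing in for a copy stream, and Python lists stand in for tensors. Keying transfers by (graph_id, id) is an assumption based on the signatures. This is not how z3.cpp performs the copy.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List, Tuple

Key = Tuple[int, int]  # (graph_id, id)


class OffloadModel:
    """Toy model of offload_tensor/wait_offload and reload_tensor/wait_reload.

    A one-thread pool stands in for a side copy stream, and list copies stand
    in for device<->host transfers (both are assumptions for illustration).
    """

    def __init__(self) -> None:
        self._copy_stream = ThreadPoolExecutor(max_workers=1)
        self._pending: Dict[Key, Future] = {}
        self._host: Dict[Key, List[float]] = {}

    def offload_tensor(self, tensor: List[float], graph_id: int, id_: int):
        # Launch the copy and return immediately; it may still be in flight.
        self._pending[(graph_id, id_)] = self._copy_stream.submit(list, tensor)
        return tensor

    def wait_offload(self, tensor, graph_id: int, id_: int):
        # Block until the host copy exists, then keep it for the reload.
        key = (graph_id, id_)
        self._host[key] = self._pending.pop(key).result()
        return tensor

    def reload_tensor(self, tensor, graph_id: int, id_: int):
        key = (graph_id, id_)
        self._pending[key] = self._copy_stream.submit(list, self._host[key])
        return tensor

    def wait_reload(self, tensor, graph_id: int, id_: int) -> List[float]:
        # The reloaded copy is only usable after this wait.
        return self._pending.pop((graph_id, id_)).result()


m = OffloadModel()
activation = [1.0, 2.0, 3.0]
t = m.wait_offload(m.offload_tensor(activation, 0, 100), 0, 100)
reloaded = m.wait_reload(m.reload_tensor(t, 0, 100), 0, 100)
print(reloaded)  # [1.0, 2.0, 3.0]
```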
Lifecycle
| Function | Parameters | Description |
|---|---|---|
| end_backward | (long) | Flushes reduce buckets, synchronizes streams |
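The hypothetical model below shows why end_backward needs to flush reduce buckets. Gradients are batched until a bucket fills, so without an explicit flush the final, partially filled bucket would never be reduced. `ReduceBucket` and `flush_at_end_of_backward` are illustrative names, and the reduce callback stands in for the actual collective. Bucket sizing in DeepSpeed may work differently.

```python
from typing import Callable, Dict, List, Tuple

Grad = List[float]


class ReduceBucket:
    """Toy gradient bucket: gradients accumulate until the bucket holds at
    least `capacity` elements, then the whole batch is reduced at once."""

    def __init__(self, capacity: int,
                 reduce_fn: Callable[[List[Tuple[int, Grad]]], None]) -> None:
        self.capacity = capacity
        self.reduce_fn = reduce_fn  # stands in for the reduce collective
        self.pending: List[Tuple[int, Grad]] = []
        self.size = 0

    def add(self, ds_id: int, grad: Grad) -> None:
        self.pending.append((ds_id, grad))
        self.size += len(grad)
        if self.size >= self.capacity:
            self.flush()

    def flush(self) -> None:
        if self.pending:
            self.reduce_fn(self.pending)
        self.pending, self.size = [], 0


def flush_at_end_of_backward(buckets: Dict[str, ReduceBucket]) -> None:
    # Without this, a partially filled bucket would never be reduced.
    for bucket in buckets.values():
        bucket.flush()


reduced: List[List[int]] = []
bucket = ReduceBucket(4, lambda items: reduced.append([i for i, _ in items]))
bucket.add(0, [1.0, 2.0])
bucket.add(1, [3.0, 4.0])   # bucket full: ds_ids 0 and 1 are reduced together
bucket.add(2, [5.0])        # left pending
flush_at_end_of_backward({"fp16": bucket})
print(reduced)  # [[0, 1], [2]]
```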
Usage Examples
```cpp
// C++ usage example
#include "z3.h"
#include <torch/torch.h>

// Assumed to exist in the caller: num_params, plus param_shards and
// grad_buffers (one partitioned tensor and one gradient buffer per parameter).

// Register ZeRO-3 parameters
std::vector<long> param_ids;
for (int i = 0; i < num_params; i++) {
    dc::register_z3_param(
        i,                // ds_id
        {1024, 768},      // full (unpartitioned) shape
        param_shards[i],  // partitioned tensor
        grad_buffers[i],  // gradient buffer
        false             // persistent
    );
    param_ids.push_back(i);
}

// Register graph 0 with all parameter IDs
dc::register_graph_z3(0, param_ids);

// Mark first parameter as persistent (kept gathered after use)
dc::set_persistent(0);

// Allgather parameter 0 in fp16 (at::kHalf), then wait for completion
at::Tensor gathered = dc::allgather_param(
    param_shards[0],
    0,  // graph_id
    0,  // ds_id
    std::optional<at::ScalarType>(at::kHalf)
);
gathered = dc::wait_allgather(gathered, 0, 0);

// Prefetch multiple parameters with per-parameter dtypes
std::vector<at::Tensor> param_list = {param_shards[1], param_shards[2]};
std::vector<long> id_list = {1, 2};
std::vector<at::ScalarType> dtype_list = {at::kHalf, at::kBFloat16};
dc::prefetch_params_fused(0, param_list, id_list, dtype_list);

// Release parameter after use (n_users = 3); ds_id 0 was marked
// persistent above, so it is not freed
at::Tensor dummy = torch::empty({0});
dummy = dc::release_param(dummy, 0, 0, 3);

// Offload/reload an activation (id 100) through CPU memory
at::Tensor activation =
    torch::randn({1024, 1024}, torch::TensorOptions().device(torch::kCUDA));
at::Tensor offloaded = dc::offload_tensor(activation, 0, 100);
offloaded = dc::wait_offload(offloaded, 0, 100);
at::Tensor reloaded = dc::reload_tensor(offloaded, 0, 100);
reloaded = dc::wait_reload(reloaded, 0, 100);

// End backward pass: flush reduce buckets and synchronize streams
dc::end_backward(0);
```
```python
# Python usage via custom operators
import torch
from deepspeed.ops import dc

# Assumed to exist: param_shard, grad_buffer, and the shards p1 and p2.

# Register parameters (Python binding)
dc.register_z3_param(0, [1024, 768], param_shard, grad_buffer, False)
dc.register_graph_z3(0, [0, 1, 2])

# Use custom operators in torch.compile'd code
@torch.compile
def forward_with_z3_ops(x, shard, graph_id, param_id):
    # Allgather the full parameter in fp16 and wait for it
    param = torch.ops.dc.allgather_param(shard, graph_id, param_id, torch.float16)
    param = torch.ops.dc.wait_allgather(param, graph_id, param_id)
    # Use the gathered parameter (x is the layer input)
    output = torch.matmul(x, param.T)
    # Release it (single user)
    dummy = torch.empty(0)
    dummy = torch.ops.dc.release_param(dummy, graph_id, param_id, 1)
    return output

# Prefetch multiple parameters: (graph_id, params, ds_ids, dtypes)
torch.ops.dc.prefetch_params_fused(
    0,
    [p1, p2],
    [1, 2],
    [torch.float16, torch.bfloat16],
)

# End backward for graph 0
torch.ops.dc.end_backward(0)
```
API Design Principles
Dual Interface
Functions are designed to be called from both:
1. C++ implementations (z3.cpp calls into Z3CustomOpExecutor methods)
2. Python bindings (exposed via pybind11 in init.cpp)
Optional Parameters
std::optional is used for nullable parameters:
- dtype: passing None (std::nullopt in C++) preserves the parameter's original dtype
- dtypes: an optional list that specifies a dtype for each prefetched parameter
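The dtype rules above can be expressed as two small helpers. In this sketch dtypes are written as strings. The length check on the dtypes list is an assumption, not documented behavior.

```python
from typing import List, Optional, Sequence


def resolve_dtype(original: str, dtype: Optional[str]) -> str:
    # None (std::nullopt on the C++ side) keeps the parameter's own dtype.
    return original if dtype is None else dtype


def resolve_dtypes(originals: Sequence[str],
                   dtypes: Optional[Sequence[str]]) -> List[str]:
    # No list at all keeps every original dtype.
    if dtypes is None:
        return list(originals)
    # Assumed: exactly one dtype per prefetched parameter.
    if len(dtypes) != len(originals):
        raise ValueError("dtypes must have one entry per parameter")
    return list(dtypes)


print(resolve_dtype("float32", None))  # float32
print(resolve_dtypes(["float32", "float32"], ["float16", "bfloat16"]))  # ['float16', 'bfloat16']
```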
Meta Variants
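The allgather, prefetch, release, wait_allgather, and parameter offload/reload operations have _meta variants:
- Enable torch.compile symbolic tracing without actual execution
- Tensor-returning variants produce tensors with the correct shape/dtype but no data; void variants are no-ops
- Critical for graph capture in profile mode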
Graph ID Scoping
Most operations take a graph_id parameter:
- Allows multiple independent compiled graphs
- Each graph has its own Z3CustomOpExecutor instance
- Enables modular compilation of model components
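Graph ID scoping amounts to a registry keyed by graph_id, where each entry owns its own state. The sketch below is a hypothetical stand-in for the Z3CustomOpExecutor lookup. `Z3ExecutorModel`, `register_graph`, `get_executor`, and `allgather` are illustrative names, and the executor only records which parameters it has gathered.

```python
from typing import Dict, Iterable, List, Set


class Z3ExecutorModel:
    """Hypothetical stand-in for one graph's Z3CustomOpExecutor."""

    def __init__(self, graph_id: int, ds_ids: Iterable[int]) -> None:
        self.graph_id = graph_id
        self.ds_ids: List[int] = list(ds_ids)
        self.gathered: Set[int] = set()


_executors: Dict[int, Z3ExecutorModel] = {}  # graph_id -> executor


def register_graph(graph_id: int, ds_ids: Iterable[int]) -> None:
    _executors[graph_id] = Z3ExecutorModel(graph_id, ds_ids)


def get_executor(graph_id: int) -> Z3ExecutorModel:
    if graph_id not in _executors:
        raise KeyError(f"graph {graph_id} has not been registered")
    return _executors[graph_id]


def allgather(graph_id: int, ds_id: int) -> None:
    # Every op first resolves its graph's executor, so state never leaks
    # between independently compiled graphs.
    executor = get_executor(graph_id)
    if ds_id not in executor.ds_ids:
        raise KeyError(f"ds_id {ds_id} is not part of graph {graph_id}")
    executor.gathered.add(ds_id)


register_graph(0, [0, 1])   # e.g. an encoder block
register_graph(1, [2])      # e.g. a separately compiled head
allgather(0, 1)
print(get_executor(0).gathered, get_executor(1).gathered)  # {1} set()
```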
Dummy Tensor Pattern
release_param takes a dummy tensor and returns it:
- Enables dependency tracking in computation graph
- Ensures proper execution order without actual data flow
- Common pattern for side-effect operations in graph mode
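The ordering guarantee can be checked on a toy dependency graph. Assume the dummy argument passed to release_param is produced by (or ordered after) the parameter's last user. Threading it through then adds an edge from that user to the release. Without the edge, a schedule that frees the buffer first is still legal. The node names are illustrative, and `graphlib` (Python 3.9+) is used only to produce a valid order.

```python
from graphlib import TopologicalSorter  # Python 3.9+


def is_valid_order(order, deps):
    """True if every node appears after all of the nodes it depends on."""
    pos = {node: i for i, node in enumerate(order)}
    return all(pos[d] < pos[n] for n, inputs in deps.items() for d in inputs)


# node -> nodes whose outputs it consumes
no_dummy = {
    "allgather": set(),
    "wait": {"allgather"},
    "matmul": {"wait"},   # last user of the gathered parameter
    "release": set(),     # no inputs: nothing orders it after matmul
}
# Passing (a value ordered after) matmul's output as release_param's dummy
# argument adds the edge matmul -> release.
with_dummy = dict(no_dummy, release={"matmul"})

free_first = ["release", "allgather", "wait", "matmul"]
print(is_valid_order(free_first, no_dummy))    # True: early free is allowed
print(is_valid_order(free_first, with_dummy))  # False: early free is ruled out
print(list(TopologicalSorter(with_dummy).static_order()))  # ['allgather', 'wait', 'matmul', 'release']
```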
Relationship to Implementation
| Header Function | Implementation | Description |
|---|---|---|
| register_z3_param | z3.cpp:register_z3_param | Validates shard sizes, registers in param_registry |
| register_graph_z3 | z3.cpp:register_graph_z3 | Creates Z3CustomOpExecutor instance |
| allgather_param | z3.cpp:allgather_param -> Z3CustomOpExecutor::allgatherParam | Performs NCCL allgather or symmetric memory gather |
| prefetch_params_fused | z3.cpp:prefetch_params_fused -> Z3CustomOpExecutor::prefetchParamsFused | Batches allgathers with ncclGroupStart/End |
| release_param | z3.cpp:release_param -> Z3CustomOpExecutor::releaseParam | Reference counts and frees non-persistent params |
| wait_allgather | z3.cpp:wait_allgather -> Z3CustomOpExecutor::waitAllgather | Blocks current stream on allgather completion |
| offload_tensor | z3.cpp:offload_tensor -> Z3CustomOpExecutor::offloadTensor | Async copy to pinned CPU memory |
| reload_tensor | z3.cpp:reload_tensor -> Z3CustomOpExecutor::reloadTensor | Async copy back to GPU |
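prefetch_params_fused is described as grouping several allgathers between ncclGroupStart and ncclGroupEnd. The sketch below mimics that batching by writing to a log instead of calling NCCL. Skipping parameters that are already gathered and deduplicating repeated IDs are assumptions for illustration. `launch_fused` is a hypothetical name.

```python
from typing import Iterable, List, Set


def launch_fused(ds_ids: Iterable[int], gathered: Set[int], log: List[str]) -> None:
    """Issue one grouped batch of allgathers for parameters not yet gathered."""
    # Assumed: skip resident parameters and drop duplicate IDs, keeping order.
    pending = list(dict.fromkeys(i for i in ds_ids if i not in gathered))
    if not pending:
        return  # nothing to launch, so no empty group either
    log.append("group_start")             # stands in for ncclGroupStart()
    for ds_id in pending:
        log.append(f"allgather {ds_id}")  # one collective per parameter
        gathered.add(ds_id)
    log.append("group_end")               # stands in for ncclGroupEnd()


log: List[str] = []
gathered = {1}
launch_fused([1, 2, 3, 2], gathered, log)
launch_fused([2, 3], gathered, log)   # everything resident: no second group
print(log)  # ['group_start', 'allgather 2', 'allgather 3', 'group_end']
```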