Implementation:Deepspeedai DeepSpeed ZeRO3 DeepCompile


Knowledge Sources
Domains: Graph_Compilation, ZeRO_Optimization, Distributed_Training
Last Updated: 2026-02-09 00:00 GMT

Overview

ZeRO Stage 3 DeepCompile implements torch.compile-compatible graph optimizations for ZeRO-3 parameter partitioning: allgather of partitioned parameters, reduce-scatter of gradients, parameter prefetching, and CPU offloading of activations.

Description

The Z3CustomOpExecutor extends CustomOpExecutor to provide ZeRO Stage 3 functionality within DeepSpeed's graph compilation framework. It manages distributed parameter gathering through NCCL allgather operations, handles gradient reduction via reduce-scatter, and supports advanced features including:

  • Allgather Operations: Fetches partitioned parameters from all ranks with uniform shard size validation and optional symmetric memory support
  • Prefetch Optimization: Batches multiple parameter allgathers using ncclGroupStart/ncclGroupEnd for reduced communication overhead
  • Parameter Lifecycle: Tracks usage counts to release non-persistent gathered parameters and reclaim memory
  • CPU Offloading: Asynchronously moves activations to pinned CPU memory with stream synchronization
  • Gradient Accumulation: Uses temporary buffers during reduce-scatter to handle accumulated gradients correctly
  • Stream Management: Coordinates five specialized CUDA streams (allgather, reduce-scatter, copy, offload, reload) with event-based synchronization

The implementation supports both standard NCCL collectives and symmetric memory for allgather, validates padded shard sizes across ranks during registration, and manages persistent vs. non-persistent parameters with different memory lifetimes.

Usage

Z3CustomOpExecutor is instantiated per compiled graph during registration and invoked through torch custom operators (torch.ops.dc.*) inserted by the DeepCompile graph rewriting pass. Users register ZeRO-3 partitioned parameters and graphs before forward/backward execution.

Code Reference

Source Location

Signature

class Z3CustomOpExecutor : public CustomOpExecutor {
public:
    Z3CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
                       std::shared_ptr<DSParamRegistry> param_registry,
                       std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
                       std::vector<long> ds_ids,
                       ncclComm_t nccl_comm,
                       at::cuda::CUDAStream ag_stream,
                       at::cuda::CUDAStream rs_stream,
                       at::cuda::CUDAStream copy_stream,
                       at::cuda::CUDAStream offload_stream,
                       at::cuda::CUDAStream reload_stream,
                       bool pre_div_reduce);

    at::Tensor allgatherParam(long ds_id,
                              std::optional<at::ScalarType> dtype,
                              c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem);

    void prefetchParamsFused(const std::vector<long>& ds_ids,
                             const std::optional<std::vector<at::ScalarType>> dtypes,
                             c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem);

    void releaseParam(long ds_id, long n_users);
    void flushReduceBucket(at::ScalarType scalar_type) override;
    at::Tensor offloadTensor(at::Tensor tensor, long id);
    at::Tensor reloadTensor(at::Tensor tensor, long id);
};

// Public API functions
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids);
void register_z3_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       bool persistent);

at::Tensor allgather_param(at::Tensor param_tensor,
                           long graph_id,
                           long ds_id,
                           std::optional<at::ScalarType> dtype);

void prefetch_params_fused(long graph_id,
                           const std::vector<at::Tensor>& params,
                           const std::vector<long>& ds_ids,
                           const std::optional<std::vector<at::ScalarType>>& dtypes);

at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);
at::Tensor wait_allgather(at::Tensor v, long graph_id, long ds_id);

Import

import torch
from deepspeed.ops import dc

# Register ZeRO-3 parameter
dc.register_z3_param(
    ds_id=param_id,
    ds_shape=param.shape,
    ds_tensor=param_shard,
    grad_buffer=grad_buffer,
    persistent=False
)

# Register graph with parameter IDs
dc.register_graph_z3(graph_id=0, ds_ids=[0, 1, 2])

# Use in torch.compile'd code via custom ops
gathered = torch.ops.dc.allgather_param(param_shard, graph_id=0, ds_id=0, dtype=None)
gathered = torch.ops.dc.wait_allgather(gathered, graph_id=0, ds_id=0)

# Prefetch multiple parameters
torch.ops.dc.prefetch_params_fused(
    graph_id=0,
    params=[p1, p2, p3],
    ids=[0, 1, 2],
    dtypes=None
)

# Release parameter memory
dummy = torch.ops.dc.release_param(dummy, graph_id=0, ds_id=0, n_users=3)

I/O Contract

allgatherParam

Parameter | Type                                | Description
ds_id     | long                                | DeepSpeed parameter identifier
dtype     | std::optional<at::ScalarType>       | Target dtype for gathered parameter (None uses original)
symm_mem  | c10::intrusive_ptr<SymmetricMemory> | Symmetric memory workspace (None uses NCCL)
Returns   | at::Tensor                          | Gathered parameter tensor with true shape (padded internally)

prefetchParamsFused

Parameter | Type                                       | Description
ds_ids    | std::vector<long>                          | List of parameter IDs to prefetch
dtypes    | std::optional<std::vector<at::ScalarType>> | Target dtypes for each parameter
symm_mem  | c10::intrusive_ptr<SymmetricMemory>        | Symmetric memory workspace
Returns   | void                                       | Launches batched allgather in background
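
The batching idea behind prefetchParamsFused can be illustrated without NCCL. The following is a minimal pure-Python simulation, not the DeepSpeed implementation: FakeComm, its group() context manager (standing in for ncclGroupStart/ncclGroupEnd), and prefetch_fused are hypothetical names, and skipping parameters that are already gathered is an assumption made for the sketch.

```python
from contextlib import contextmanager


class FakeComm:
    """Hypothetical stand-in for a communicator; records what would be submitted."""

    def __init__(self):
        self.pending = []    # allgathers queued inside an open group
        self.launches = []   # one tuple of ds_ids per submission
        self.in_group = False

    @contextmanager
    def group(self):
        # Plays the role of ncclGroupStart() ... ncclGroupEnd().
        self.in_group = True
        try:
            yield
        finally:
            self.in_group = False
            if self.pending:
                # Everything queued in the group goes out as one submission.
                self.launches.append(tuple(self.pending))
                self.pending = []

    def allgather(self, ds_id):
        if self.in_group:
            self.pending.append(ds_id)
        else:
            self.launches.append((ds_id,))


def prefetch_fused(comm, ds_ids, already_gathered):
    # Assumption: parameters that are already gathered need no new allgather.
    todo = [i for i in ds_ids if i not in already_gathered]
    with comm.group():
        for ds_id in todo:
            comm.allgather(ds_id)
    return todo


comm = FakeComm()
issued = prefetch_fused(comm, ds_ids=[0, 1, 2], already_gathered={1})
print(issued, comm.launches)
```

In this simulation the two outstanding parameters are submitted together as a single batch, whereas calling allgather outside the group would record one submission per parameter.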

releaseParam

Parameter | Type | Description
ds_id     | long | Parameter ID to release
n_users   | long | Total number of uses for this parameter in graph
Effect    | -    | Decrements use count; frees non-persistent param when count reaches 0
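
The use-count contract described above can be sketched in a few lines of Python. This is an illustrative model only: ParamTracker and its methods are hypothetical names, and initializing the counter from n_users on the first release call is an assumption about how the count is seeded.

```python
class ParamTracker:
    """Hypothetical model of use-count based release of gathered parameters."""

    def __init__(self):
        self.gathered = {}      # ds_id -> gathered buffer
        self.persistent = set() # ds_ids that are never freed
        self.remaining = {}     # ds_id -> uses left before the buffer is freed

    def gather(self, ds_id, buffer, persistent=False):
        self.gathered[ds_id] = buffer
        if persistent:
            self.persistent.add(ds_id)

    def release(self, ds_id, n_users):
        """Return True if this call freed the gathered buffer."""
        if ds_id in self.persistent:
            return False  # persistent parameters stay resident
        # Assumption: the counter starts at n_users on the first release call.
        left = self.remaining.get(ds_id, n_users) - 1
        if left > 0:
            self.remaining[ds_id] = left
            return False
        self.remaining.pop(ds_id, None)
        self.gathered.pop(ds_id, None)  # reclaim memory
        return True


tracker = ParamTracker()
tracker.gather(0, buffer=[1.0, 2.0])
results = [tracker.release(0, n_users=3) for _ in range(3)]

tracker.gather(1, buffer=[3.0], persistent=True)
persistent_freed = tracker.release(1, n_users=1)
print(results, persistent_freed)
```

With n_users=3 only the third release frees the buffer, and a persistent parameter is never freed regardless of how often it is released.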

flushReduceBucket

Parameter   | Type           | Description
scalar_type | at::ScalarType | Data type of gradients to reduce
Effect      | -              | Performs NCCL reduce-scatter with gradient accumulation support

offloadTensor / reloadTensor

Parameter | Type       | Description
tensor    | at::Tensor | Tensor to offload/reload
id        | long       | Unique identifier for offload buffer
Returns   | at::Tensor | Pinned CPU tensor (offload) or reloaded GPU tensor (reload)

Usage Examples

import torch
import torch.distributed as dist
from deepspeed.ops import dc

# Initialize DeepSpeed DeepCompile
config = {
    "symmetric_memory": False,
    "double_buffer": True,
    "free_activation_threshold": 1024 * 1024,
    "sync_before_reduce": False,
    "sync_after_reduce": False,
    "sync_before_allgather": False,
    "sync_after_allgather": False,
}
process_group = dist.new_group(ranks=list(range(dist.get_world_size())))
dc.init(process_group, config, initial_reduce_bucket_size=100_000_000)

# Register ZeRO-3 partitioned parameters
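# model_params is assumed to hold (full_shape, param_shard, grad_buffer) tuples
for param_id, (full_shape, param_shard, grad_buffer) in enumerate(model_params):
    dc.register_z3_param(
        ds_id=param_id,
        ds_shape=full_shape,  # Unpartitioned shape of the parameter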
        ds_tensor=param_shard,  # Partitioned tensor on this rank
        grad_buffer=grad_buffer,
        persistent=False  # Set True for frequently used params
    )

# Register compiled graph
dc.register_graph_z3(graph_id=0, ds_ids=list(range(len(model_params))))

# Mark persistent parameters (they stay gathered in memory)
dc.set_persistent(ds_id=0)  # First layer kept resident

# Compiled forward/backward with DeepCompile ops
@torch.compile(backend="deepspeed")
def training_step(inputs):
    dc.start_forward()

    # Prefetch parameters before module execution
    torch.ops.dc.prefetch_params_fused(
        graph_id=0,
        params=[torch.empty(0) for _ in range(3)],
        ids=[0, 1, 2],
        dtypes=[torch.float16, torch.float16, torch.bfloat16]
    )

    # Allgather parameter for computation
    param_0 = torch.ops.dc.allgather_param(
        param_shards[0],
        graph_id=0,
        ds_id=0,
        dtype=torch.float16
    )
    param_0 = torch.ops.dc.wait_allgather(param_0, graph_id=0, ds_id=0)

    # Forward computation
    output = model(inputs, param_0)

    # Release parameter memory after use
    dummy = torch.empty(0)
    dummy = torch.ops.dc.release_param(dummy, graph_id=0, ds_id=0, n_users=1)

    dc.end_forward()

    # Backward pass
    dc.start_backward(update=True)
    loss = output.sum()
    loss.backward()

    # Gradient reduce happens automatically via hooks
    torch.ops.dc.end_backward(graph_id=0)

    return loss

# Optional: Offload activations to CPU
activation = torch.randn(1024, 1024, device='cuda')
offloaded = torch.ops.dc.offload_tensor(activation, graph_id=0, id=100)
offloaded = torch.ops.dc.wait_offload(offloaded, graph_id=0, id=100)

# Reload when needed
reloaded = torch.ops.dc.reload_tensor(offloaded, graph_id=0, id=100)
reloaded = torch.ops.dc.wait_reload(reloaded, graph_id=0, id=100)

# Cleanup
dc.cleanup()

Implementation Details

Allgather Fast Path

The implementation uses a fast path for uniform shard sizes (standard ZeRO-3 partitioning):

  • Assumes all ranks have equal-sized shards (padded to uniform size)
  • Performs direct ncclAllGather into pre-allocated output buffer
  • Validates shard uniformity at parameter registration with allgather verification
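
The uniform-shard assumption can be made concrete with a small pure-Python simulation. This is a sketch under stated assumptions rather than the actual NCCL path: padded_shard_size, partition, and allgather are hypothetical helpers, lists stand in for tensors, and zero is assumed as the padding value.

```python
import math


def padded_shard_size(numel, world_size):
    # Elements per rank once the flat parameter is padded to a multiple of world_size.
    return math.ceil(numel / world_size)


def partition(flat, world_size):
    # Pad with zeros (assumed padding value) so that every rank gets an equal shard.
    shard = padded_shard_size(len(flat), world_size)
    padded = list(flat) + [0.0] * (shard * world_size - len(flat))
    return [padded[r * shard:(r + 1) * shard] for r in range(world_size)]


def allgather(shards, true_numel):
    # The fast path is only valid when all shards have the same length.
    sizes = {len(s) for s in shards}
    if len(sizes) != 1:
        raise ValueError(f"non-uniform shard sizes: {sorted(sizes)}")
    gathered = [x for s in shards for x in s]  # concatenate in rank order
    return gathered[:true_numel]               # drop padding to restore the true size


param = [float(i) for i in range(10)]
shards = partition(param, world_size=4)
restored = allgather(shards, true_numel=len(param))
print([len(s) for s in shards], restored == param)
```

A 10-element parameter split over 4 ranks is padded to 12 elements, so each rank holds 3, and trimming after the gather restores the original 10.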

Symmetric Memory Path

When symmetric memory is enabled:

  • Uses get_buffer() to access remote rank memory directly
  • Performs explicit barriers before and after remote copies
  • Copies each rank's shard in round-robin order
  • Provides an alternative to NCCL for specific hardware configurations
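
The copy pattern listed above can be modeled as follows. This is a hedged simulation: symm_allgather is a hypothetical function, plain lists stand in for the buffers that get_buffer() would expose, the barriers appear only as comments, and starting the round-robin at the local rank is an assumption about the ordering.

```python
def symm_allgather(peer_buffers, rank):
    """Gather every rank's shard by reading peer buffers directly.

    peer_buffers[r] stands in for the remote buffer of rank r.
    Returns the gathered output and the order in which peers were visited.
    """
    world_size = len(peer_buffers)
    shard = len(peer_buffers[rank])
    out = [None] * (shard * world_size)
    order = []
    # A barrier would go here so that all peers have published their shards.
    for step in range(world_size):
        # Assumed ordering: start at the local rank and walk around the ring.
        remote = (rank + step) % world_size
        out[remote * shard:(remote + 1) * shard] = peer_buffers[remote]
        order.append(remote)
    # A second barrier would go here so that no peer reuses its buffer too early.
    return out, order


buffers = [[0, 1], [2, 3], [4, 5]]
gathered, visit_order = symm_allgather(buffers, rank=1)
print(gathered, visit_order)
```

Each rank writes a peer's shard into the slot indexed by that peer's rank, so the output layout is identical on every rank even though the visiting order differs.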

Gradient Accumulation

The reduce-scatter implementation handles gradient accumulation correctly:

  • Allocates temporary buffer for accumulated gradients
  • Performs reduce-scatter into temporary buffer for accumulated parameters
  • Adds temporary results to existing gradient buffers
  • Prevents overwriting accumulated gradients during reduce-scatter
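
The steps above can be sketched with lists in place of tensors. This is a simplified model, not the DeepSpeed code: reduce_scatter_mean and flush are hypothetical names, averaging across ranks is assumed as the reduction, and the sketch always reduces into a temporary list, whereas the text above reserves the temporary buffer for parameters that already hold accumulated gradients.

```python
def reduce_scatter_mean(per_rank_grads, rank):
    # Each rank contributes a full flat gradient; this rank keeps only its shard
    # of the element-wise mean (averaging is an assumption of this sketch).
    world_size = len(per_rank_grads)
    shard = len(per_rank_grads[0]) // world_size
    start = rank * shard
    return [
        sum(g[i] for g in per_rank_grads) / world_size
        for i in range(start, start + shard)
    ]


def flush(grad_buffer, per_rank_grads, rank, has_accumulated):
    reduced = reduce_scatter_mean(per_rank_grads, rank)  # temporary buffer
    if has_accumulated:
        # Add to what earlier micro-steps left behind instead of overwriting it.
        for i, value in enumerate(reduced):
            grad_buffer[i] += value
    else:
        grad_buffer[:] = reduced
    return grad_buffer


micro_batch = [[2.0, 4.0, 6.0, 8.0], [4.0, 6.0, 8.0, 10.0]]  # gradients on rank 0 and rank 1
grad_buffer = [0.0, 0.0]                                     # rank 0's gradient shard

flush(grad_buffer, micro_batch, rank=0, has_accumulated=False)
after_first = list(grad_buffer)
flush(grad_buffer, micro_batch, rank=0, has_accumulated=True)
after_second = list(grad_buffer)
print(after_first, after_second)
```

Writing the second reduce-scatter result directly into grad_buffer would have left it at the first micro-step's values; adding the temporary result preserves the accumulated sum.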

Stream Synchronization

Five specialized streams coordinate operations:

  • ag_stream: Allgather communication
  • rs_stream: Reduce-scatter communication
  • copy_stream: Asynchronous tensor copies
  • offload_stream: CPU offload operations
  • reload_stream: GPU reload operations

Events recorded on these streams express cross-stream dependencies, so the default stream waits only at the points where it actually needs a result.
