

Implementation:Deepspeedai DeepSpeed DeepCompile Init

From Leeroopedia


Knowledge Sources
Domains: Graph_Compilation, PyTorch_Extension, Operator_Registration
Last Updated: 2026-02-09 00:00 GMT

Overview

DeepCompile Init registers custom PyTorch operators and Python bindings for all ZeRO stages, enabling torch.compile integration with DeepSpeed optimizations.

Description

The init.cpp file uses PyTorch's TORCH_LIBRARY macro to define custom operators in the "dc" namespace, and TORCH_LIBRARY_IMPL to provide dispatch implementations for the CPU, CUDA, and Meta backends. It includes:

  • Operator Definitions: Declares 13 custom operators including allgather_param, prefetch_params_fused, reduce_grad, wait_allgather, release_param, offload/reload tensor operations, and lifecycle hooks
  • Multi-Backend Dispatch: Registers implementations for CPU, CUDA, and Meta (shape inference) backends via TORCH_LIBRARY_IMPL
  • Meta Kernel Support: Provides shape-only implementations for torch.compile's symbolic tracing without executing communication
  • Python Bindings: Exposes registration functions and lifecycle management via pybind11 (PYBIND11_MODULE)
  • Lifecycle Management: Registers end_backward under the "Undefined" dispatch key, since it takes no tensor arguments from which a backend key could be derived
  • ZeRO Stage Integration: Connects Python registration calls to C++ implementations in z1.cpp, z2.cpp, and z3.cpp
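Every custom operator takes a graph_id, and registration is split between per-parameter calls and per-graph calls (register_graph_z1/z2/z3, which "register a graph executor"). The plain-Python model below is a hypothetical sketch of that arrangement, not DeepSpeed code: one executor per graph_id, looked up by each operator call before it touches a parameter ds_id. All function and class names here are illustrative.

```python
from dataclasses import dataclass, field


@dataclass
class GraphExecutor:
    """Hypothetical stand-in for a per-graph C++ executor (z1/z2/z3)."""
    stage: int
    ds_ids: list
    calls: list = field(default_factory=list)


# Toy global state, analogous to what the extension keeps between calls.
param_shapes = {}   # ds_id -> full (unpartitioned) shape
executors = {}      # graph_id -> GraphExecutor


def register_param_shape(ds_id, ds_shape):
    param_shapes[ds_id] = tuple(ds_shape)


def register_graph(stage, graph_id, ds_ids):
    # A graph may only reference parameters that were registered first.
    unknown = [i for i in ds_ids if i not in param_shapes]
    if unknown:
        raise KeyError(f"unregistered ds_ids: {unknown}")
    executors[graph_id] = GraphExecutor(stage, list(ds_ids))


def route_reduce_grad(graph_id, ds_id):
    # Each operator call is routed to the executor that owns its graph.
    executor = executors[graph_id]
    if ds_id not in executor.ds_ids:
        raise KeyError(f"ds_id {ds_id} is not part of graph {graph_id}")
    executor.calls.append(("reduce_grad", ds_id))
    return executor


register_param_shape(0, [1024, 768])
register_param_shape(1, [768])
register_graph(3, graph_id=0, ds_ids=[0, 1])
print(route_reduce_grad(0, 1).calls)  # [('reduce_grad', 1)]
```

Keying state by graph_id is what lets several compiled graphs (for example, separate forward and backward graphs, or different models) coexist in one process without their parameter bookkeeping colliding.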

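The operator schema definitions use PyTorch's custom operator syntax with aliasing annotations (e.g., Tensor(a)) to declare that an output shares storage with an input, which helps the compiler track memory dependencies. A plain Tensor(a) declares aliasing only; an in-place mutation would be written Tensor(a!), which none of these schemas use.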
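As a reading aid for the schema strings, here is a small stdlib-only parser. It is a hypothetical helper, not PyTorch's schema parser, and covers only the subset of syntax used in this file (list types, optional types, defaults, and alias sets). It splits a schema into arguments and returns and reports which outputs alias which inputs.

```python
import re

# One argument or return entry, e.g. "Tensor(a) a", "int[] ids", "ScalarType? dtype = None".
ARG_RE = re.compile(
    r"^(?P<type>[A-Za-z]+(?:\[\])?)"    # base type, optionally a list such as int[]
    r"(?:\((?P<alias>[a-z]+!?)\))?"      # optional alias set: (a), or (a!) for mutation
    r"(?P<optional>\?)?"                 # optional-type marker
    r"(?:\s+(?P<name>\w+))?"             # argument name (returns usually have none)
    r"(?:\s*=\s*(?P<default>\S+))?$"     # default value
)


def parse_entry(text):
    m = ARG_RE.match(text.strip())
    if m is None:
        raise ValueError(f"unsupported schema fragment: {text!r}")
    alias = m.group("alias")
    return {
        "type": m.group("type"),
        "name": m.group("name"),
        "optional": m.group("optional") is not None,
        "default": m.group("default"),
        "alias_set": alias.rstrip("!") if alias else None,
        "mutates": bool(alias) and alias.endswith("!"),
    }


def parse_schema(schema):
    head, ret = schema.split("->")
    name, args = re.match(r"^\s*(\w+)\((.*)\)\s*$", head).groups()
    ret = ret.strip()
    # "()" means no outputs; "(A, B)" would be a tuple of outputs.
    rets = ret[1:-1].split(",") if ret.startswith("(") else [ret]
    return {
        "name": name,
        "args": [parse_entry(a) for a in args.split(",") if a.strip()],
        "returns": [parse_entry(r) for r in rets if r.strip()],
    }


def output_aliases(parsed):
    """Map each aliasing output index to the input argument it shares storage with."""
    by_set = {a["alias_set"]: a["name"] for a in parsed["args"] if a["alias_set"]}
    return {i: by_set[r["alias_set"]]
            for i, r in enumerate(parsed["returns"]) if r["alias_set"] in by_set}


p = parse_schema("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)")
print(output_aliases(p), p["args"][0]["mutates"])  # {0: 'a'} False
```

For anything beyond this subset, PyTorch's own parser (exposed privately as torch._C.parse_schema) is the authority.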

Usage

This file is compiled into a PyTorch extension module that Python imports as deepspeed.ops.dc, making custom operators available via torch.ops.dc.* and providing registration functions accessible from Python.

Code Reference

Source Location

Signature

// Operator registrations in TORCH_LIBRARY(dc, m)
TORCH_LIBRARY(dc, m) {
    m.def("allgather_param(Tensor a, int graph_id, int id, ScalarType? dtype = None) -> Tensor");
    m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids, ScalarType[]? dtypes = None) -> ()");
    m.def("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)");
    m.def("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)");
    m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
    m.def("free_tensors(Tensor[] a) -> ()");
    m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
    m.def("reload_tensor(Tensor a, int id, int id) -> Tensor");
    m.def("wait_offload(Tensor a, int id, int id) -> Tensor");
    m.def("wait_reload(Tensor a, int id, int id) -> Tensor");
    m.def("offload_parameter(Tensor a, int id, int id) -> ()");
    m.def("reload_parameter(Tensor a, int id, int id) -> ()");
    m.def("end_backward(int graph_id) -> ()");
}

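// Backend implementations (bodies elided; each registers kernels with m.impl)
TORCH_LIBRARY_IMPL(dc, CUDA, m) { /* maps to z1/z2/z3 implementations */ }
TORCH_LIBRARY_IMPL(dc, CPU, m) { /* same implementations (handled via ProcessGroup) */ }
TORCH_LIBRARY_IMPL(dc, Meta, m) { /* shape inference only (_meta variants) */ }
TORCH_LIBRARY_IMPL(dc, Undefined, m) { /* end_backward (no tensor arguments) */ }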

// Python bindings in PYBIND11_MODULE
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("init", &dc::init);
    m.def("cleanup", &dc::cleanup);
    m.def("register_param", &dc::register_param);           // Z1
    m.def("register_graph_z1", &dc::register_graph_z1);     // Z1
    m.def("register_graph_z2", &dc::register_graph_z2);     // Z2
    m.def("register_z3_param", &dc::register_z3_param);     // Z3
    m.def("register_graph_z3", &dc::register_graph_z3);     // Z3
    m.def("set_persistent", &dc::set_persistent);           // Z3
    m.def("start_forward", &dc::start_forward);
    m.def("end_forward", &dc::end_forward);
    m.def("start_backward", &dc::start_backward);
    m.def("enable_profiling", &dc::enable_profiling);
    m.def("is_profiling", &dc::is_profiling);
    m.def("reset", &dc::reset);
    m.def("invalidate_gathered_param", &dc::invalidate_gathered_param);  // Z3
    m.def("clear_all_gathered_params", &dc::clear_all_gathered_params);  // Z3
}

Import

import torch
from deepspeed.ops import dc  # Imports compiled extension

# Use registered custom operators
result = torch.ops.dc.allgather_param(tensor, graph_id=0, id=0, dtype=None)
result = torch.ops.dc.wait_allgather(result, graph_id=0, id=0)

torch.ops.dc.prefetch_params_fused(
    graph_id=0,
    params=[p1, p2],
    ids=[0, 1],
    dtypes=[torch.float16, torch.bfloat16]
)

# Reduce gradients
torch.ops.dc.reduce_grad(grad_tensor, graph_id=0, id=0)

# Lifecycle hooks
torch.ops.dc.end_backward(graph_id=0)

# Use Python bindings for registration
# (the pybind11 bindings declare no py::arg names, so arguments are positional)
dc.init(process_group, config, bucket_size)
dc.register_graph_z3(0, [0, 1, 2])  # graph_id, ds_ids
dc.register_z3_param(0, [1024, 768], shard, grad, False)  # ds_id, ds_shape, ds_tensor, grad_buffer, persistent
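The allgather_param / wait_allgather pair splits a collective into a launch and a wait, so independent computation can run in between. The stdlib sketch below models only that split-phase ordering: a single-worker thread pool stands in for the communication stream, and fake_allgather is a hypothetical placeholder that concatenates per-rank shards instead of running a real collective.

```python
from concurrent.futures import ThreadPoolExecutor

comm_stream = ThreadPoolExecutor(max_workers=1)  # stands in for a side communication stream
pending = {}  # (graph_id, param_id) -> Future

# Hypothetical per-rank shards of one parameter (world size 2).
shards = {0: [1.0, 2.0], 1: [3.0, 4.0]}


def fake_allgather(param_id):
    # Placeholder for the collective: concatenate every rank's shard in rank order.
    return [x for rank in sorted(shards) for x in shards[rank]]


def allgather_param(graph_id, param_id):
    # Launch only: return immediately without waiting for the result.
    pending[(graph_id, param_id)] = comm_stream.submit(fake_allgather, param_id)


def wait_allgather(graph_id, param_id):
    # Block until the launched gather has finished, then hand out the full parameter.
    return pending.pop((graph_id, param_id)).result()


allgather_param(graph_id=0, param_id=0)
overlapped = sum(range(10))  # independent work that can overlap the "communication"
full_param = wait_allgather(graph_id=0, param_id=0)
print(full_param, overlapped)  # [1.0, 2.0, 3.0, 4.0] 45
comm_stream.shutdown()
```

In the real operators the wait synchronizes with the allgather stream on the device (per the I/O contract); this host-side model illustrates ordering only, not the CUDA mechanics.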

I/O Contract

Custom Operators

Operator | Schema | Description
--- | --- | ---
allgather_param | (Tensor, int, int, ScalarType?) -> Tensor | Gathers a partitioned parameter from all ranks
prefetch_params_fused | (int, Tensor[], int[], ScalarType[]?) -> () | Batches multiple allgathers into one call
wait_allgather | (Tensor(a), int, int) -> Tensor(a) | Synchronizes with the allgather stream
release_param | (Tensor(a), int, int, int) -> Tensor(a) | Releases gathered parameter memory
reduce_grad | (Tensor, int, int) -> Tensor | Reduces a gradient across ranks
free_tensors | (Tensor[]) -> () | Frees large activation tensors
offload_tensor | (Tensor, int, int) -> Tensor | Offloads a tensor to CPU pinned memory
reload_tensor | (Tensor, int, int) -> Tensor | Reloads a tensor from CPU to GPU
wait_offload | (Tensor, int, int) -> Tensor | Waits for offload completion
wait_reload | (Tensor, int, int) -> Tensor | Waits for reload completion
offload_parameter | (Tensor, int, int) -> () | Offloads a parameter to CPU
reload_parameter | (Tensor, int, int) -> () | Reloads a parameter from CPU
end_backward | (int) -> () | Flushes reduce buckets at the end of the backward pass
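end_backward is needed because gradients passed to reduce_grad are collected into reduce buckets (init takes an initial_reduce_bucket_size) rather than necessarily reduced one by one, so whatever remains in a partially filled bucket at the end of backward must be flushed explicitly. The single-process sketch below models that bucketing idea; the class, its flush-before-overflow policy, and the reduced log are assumptions for illustration, not the extension's actual data structures.

```python
class ReduceBucket:
    """Hypothetical gradient bucket: batches reduce_grad calls by element count."""

    def __init__(self, capacity):
        self.capacity = capacity  # max elements per bucket before a flush
        self.items = []           # ds_ids waiting to be reduced
        self.size = 0             # elements currently buffered
        self.reduced = []         # log of flushed batches, for inspection

    def reduce_grad(self, ds_id, numel):
        # Flush first if this gradient would overflow a non-empty bucket.
        if self.items and self.size + numel > self.capacity:
            self.flush()
        self.items.append(ds_id)
        self.size += numel

    def flush(self):
        if self.items:
            # A real implementation would launch one collective for the whole batch here.
            self.reduced.append(list(self.items))
            self.items.clear()
            self.size = 0

    def end_backward(self):
        # Gradients still sitting in a partially filled bucket must not be lost.
        self.flush()


bucket = ReduceBucket(capacity=10)
for ds_id, numel in [(0, 4), (1, 4), (2, 4), (3, 8)]:
    bucket.reduce_grad(ds_id, numel)
print(bucket.reduced)  # [[0, 1], [2]]
bucket.end_backward()
print(bucket.reduced)  # [[0, 1], [2], [3]]
```

Without the final end_backward call, gradient 3 would never be reduced, which is exactly the failure the lifecycle hook guards against.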

Python Bindings

Function | Parameters | Description
--- | --- | ---
init | (ProcessGroup, config, int64_t) | Initializes the runtime with a process group
cleanup | () | Destroys the NCCL communicator and clears state
register_param | (int, int64_t[], Tensor, Tensor, int64_t) | Registers a Z1 parameter
register_graph_z1 | (int, int[]) | Registers a Z1 graph executor
register_graph_z2 | (int, int[]) | Registers a Z2 graph executor
register_z3_param | (int, int64_t[], Tensor, Tensor, bool) | Registers a Z3 partitioned parameter
register_graph_z3 | (int, int[]) | Registers a Z3 graph executor
set_persistent | (int) | Marks a Z3 parameter as persistent (kept gathered)
start_forward | () | Begins the forward-pass lifecycle
end_forward | () | Ends the forward-pass lifecycle
start_backward | (bool) | Begins the backward pass with an update flag
enable_profiling | (bool) | Enables or disables profiling mode
is_profiling | () -> bool | Reports whether profiling is enabled
reset | () | Clears graph executors
invalidate_gathered_param | (int) | Invalidates a Z3 gathered parameter (profiling)
clear_all_gathered_params | () | Clears all Z3 gathered parameters
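The lifecycle bindings, together with the end_backward operator, imply a per-step call order. The sketch below encodes that order as a small state machine that could catch mis-ordered calls in a custom training loop. The checker is hypothetical, and the exact transitions are an assumption read off the usage example, not behavior the extension is documented to enforce.

```python
class LifecycleChecker:
    """Hypothetical guard for the call order of one training step."""

    # current phase -> {allowed call: next phase}
    TRANSITIONS = {
        "idle": {"start_forward": "forward"},
        "forward": {"end_forward": "forward_done"},
        "forward_done": {"start_backward": "backward"},
        "backward": {"end_backward": "idle"},
    }

    def __init__(self):
        self.phase = "idle"
        self.update = None  # last flag passed to start_backward

    def call(self, name, update=None):
        allowed = self.TRANSITIONS[self.phase]
        if name not in allowed:
            raise RuntimeError(f"{name} is not allowed during phase {self.phase!r}")
        if name == "start_backward":
            self.update = update
        self.phase = allowed[name]


step = LifecycleChecker()
step.call("start_forward")
step.call("end_forward")
step.call("start_backward", update=True)
step.call("end_backward")
print(step.phase, step.update)  # idle True
```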

Usage Examples

import torch
import torch.distributed as dist
from deepspeed.ops import dc

# Placeholders assumed to exist: param_shards (list of (shard, grad_buffer) pairs),
# original_shapes, x, sample_input, activations

# Initialize extension
dist.init_process_group("nccl")
pg = dist.new_group()
config = {"symmetric_memory": False, "double_buffer": True, "free_activation_threshold": 1_000_000,
          "sync_before_reduce": False, "sync_after_reduce": False,
          "sync_before_allgather": False, "sync_after_allgather": False}
# pybind11 bindings without py::arg names accept positional arguments only
dc.init(pg, config, 100_000_000)  # initial_reduce_bucket_size

# Register ZeRO-3 parameters: ds_id, ds_shape, ds_tensor, grad_buffer, persistent
for i, (shard, grad_buf) in enumerate(param_shards):
    dc.register_z3_param(i, original_shapes[i], shard, grad_buf, False)

dc.register_graph_z3(0, list(range(len(param_shards))))  # graph_id, ds_ids

# Using custom operators in compiled model
@torch.compile(backend="inductor")
def forward_with_allgather(x, param_shard, graph_id, param_id):
    # Launch the allgather, then wait for the full parameter
    param = torch.ops.dc.allgather_param(param_shard, graph_id, param_id, dtype=torch.float16)
    param = torch.ops.dc.wait_allgather(param, graph_id, param_id)

    # Use parameter
    output = torch.matmul(x, param.T)

    # Release memory: threading `output` through release_param orders the release
    # after the parameter's last user and keeps its result from being discarded as dead code
    output = torch.ops.dc.release_param(output, graph_id, param_id, n_users=1)

    return output

# Training loop with lifecycle management
dc.start_forward()
output = forward_with_allgather(x, param_shards[0][0], 0, 0)
dc.end_forward()

dc.start_backward(True)  # update=True
loss = output.sum()
loss.backward()
torch.ops.dc.end_backward(graph_id=0)

# Profiling mode
dc.enable_profiling(True)
_ = forward_with_allgather(sample_input, param_shards[0][0], 0, 0)
dc.enable_profiling(False)

# Activation memory management
large_activations = [act for act in activations if act.numel() > 1_000_000]
torch.ops.dc.free_tensors(large_activations)

# Cleanup
dc.cleanup()

Implementation Details

Operator Aliasing Annotations

The schema uses PyTorch's aliasing notation:

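  • Tensor(a) indicates the output aliases the input (shares underlying memory); without a trailing !, as in Tensor(a!), it declares aliasing only, not mutation
  • Used by wait_allgather and release_param, which return a tensor aliasing their input rather than a fresh allocation
  • Lets torch.compile know that these outputs share storage with their inputs, so it tracks memory dependencies correctly instead of treating the results as independent buffers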
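What Tensor(a) promises can be illustrated without PyTorch: an aliasing output is another view of the same storage, so a write through one name is visible through the other, while a non-aliasing output owns fresh storage. In this stdlib sketch, bytearray and memoryview stand in for tensor storage and views; both functions are hypothetical.

```python
def aliasing_op(buf):
    # Like a schema "(Tensor(a) a) -> Tensor(a)": return a view of the same storage.
    return memoryview(buf)


def copying_op(buf):
    # Like "(Tensor a) -> Tensor": the output owns fresh storage.
    return bytearray(buf)


storage = bytearray(b"\x00\x00\x00\x00")
alias = aliasing_op(storage)
fresh = copying_op(storage)

storage[0] = 7  # write through the original name
print(alias[0], fresh[0])  # 7 0
```

A compiler that mistook the aliasing result for a fresh buffer could reorder or reuse memory around it incorrectly; the annotation rules that out.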

Meta Backend Implementation

Meta kernels provide shape inference without execution:

  • allgather_param_meta returns an empty (data-less) tensor with the correct shape and dtype
  • reduce_grad_meta returns an empty tensor
  • Enables symbolic tracing and graph capture in profiling mode
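A meta kernel computes only output metadata. For allgather_param, a plausible rule, assumed here rather than taken from the source, is that the output has the full, unpartitioned shape recorded at registration (the ds_shape passed to register_z3_param) and the requested dtype, falling back to the shard's dtype when dtype is None. The sketch models this with a hypothetical MetaTensor record that carries shape and dtype but no data, and it also assumes flattened 1-D shards.

```python
from dataclasses import dataclass
from math import prod


@dataclass(frozen=True)
class MetaTensor:
    """Shape and dtype only; no storage, like a tensor on the meta device."""
    shape: tuple
    dtype: str


# ds_id -> full shape, as recorded at registration time (hypothetical values).
registered_shapes = {0: (1024, 768)}


def allgather_param_meta(shard, graph_id, ds_id, dtype=None):
    # No communication happens: the result is derived purely from metadata.
    return MetaTensor(registered_shapes[ds_id], dtype or shard.dtype)


world_size = 4
shard = MetaTensor((prod(registered_shapes[0]) // world_size,), "float32")
out = allgather_param_meta(shard, graph_id=0, ds_id=0, dtype="float16")
print(shard.shape, out.shape, out.dtype)  # (196608,) (1024, 768) float16
```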

Dispatch Key Selection

  • CPU/CUDA dispatch keys route to actual implementations (z1/z2/z3.cpp)
  • Meta dispatch key routes to _meta variants for shape inference
  • Undefined dispatch key used for end_backward (no tensor arguments)
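These routing rules amount to a table from (operator, dispatch key) to kernel, which is what the TORCH_LIBRARY_IMPL blocks populate. The toy dispatcher below is a hypothetical simplification, not PyTorch's dispatcher: CPU and CUDA share one kernel, Meta gets the _meta variant, and end_backward, which has no tensor argument to derive a key from, is found under Undefined. The kernels just return strings.

```python
kernels = {}  # (op_name, dispatch_key) -> callable


def impl(op_name, key, fn):
    kernels[(op_name, key)] = fn


# Toy kernels standing in for the z1/z2/z3 implementations and their _meta variants.
def reduce_grad(grad, graph_id, ds_id):
    return f"reduced grad {ds_id} on real device"


def reduce_grad_meta(grad, graph_id, ds_id):
    return "metadata only"


def end_backward(graph_id):
    return f"flushed buckets for graph {graph_id}"


# CPU and CUDA share one implementation; Meta gets the shape-only variant.
for key in ("CPU", "CUDA"):
    impl("reduce_grad", key, reduce_grad)
impl("reduce_grad", "Meta", reduce_grad_meta)
impl("end_backward", "Undefined", end_backward)


def dispatch(op_name, *args, device=None):
    # Simplification: the key comes from a tensor argument's device; no tensors -> Undefined.
    key = device if device is not None else "Undefined"
    return kernels[(op_name, key)](*args)


print(dispatch("reduce_grad", None, 0, 3, device="CUDA"))  # reduced grad 3 on real device
print(dispatch("reduce_grad", None, 0, 3, device="Meta"))  # metadata only
print(dispatch("end_backward", 0))                         # flushed buckets for graph 0
```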

Backend Consistency

Both CPU and CUDA implementations call the same functions because:

  • ProcessGroup handles device abstraction
  • NCCL operations work on CUDA tensors regardless of dispatch backend
  • CPU dispatch supports CPU-based testing and debugging

Extension Module Loading

PYBIND11_MODULE creates a Python extension module:

  • Module name determined by TORCH_EXTENSION_NAME macro
  • Compiled as shared library (.so on Linux, .pyd on Windows)
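  • Imported in Python as a regular module; loading the shared library runs the static TORCH_LIBRARY registrations, so torch.ops.dc.* is available once the module is imported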
