Implementation:Deepspeedai DeepSpeed DeepCompile Init
| Knowledge Sources | |
|---|---|
| Domains | Graph_Compilation, PyTorch_Extension, Operator_Registration |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
DeepCompile Init registers the custom PyTorch operators and Python bindings used by all ZeRO stages, so that graphs compiled with torch.compile can call DeepSpeed's communication and memory-management routines.
Description
The init.cpp file uses PyTorch's TORCH_LIBRARY system to register custom operators in the "dc" namespace and provide dispatch implementations for CPU, CUDA, and Meta backends. It includes:
- Operator Definitions: Declares 13 custom operators, including allgather_param, prefetch_params_fused, reduce_grad, wait_allgather, release_param, the offload/reload tensor and parameter operations, and the end_backward lifecycle hook
- Multi-Backend Dispatch: Registers implementations for CPU, CUDA, and Meta (shape inference) backends via TORCH_LIBRARY_IMPL
- Meta Kernel Support: Provides shape-only implementations for torch.compile's symbolic tracing without executing communication
- Python Bindings: Exposes registration functions and lifecycle management via pybind11 (PYBIND11_MODULE)
- Lifecycle Management: Registers end_backward in the "Undefined" dispatch key since it operates without tensor arguments
- ZeRO Stage Integration: Connects Python registration calls to C++ implementations in z1.cpp, z2.cpp, and z3.cpp
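The operator schemas use PyTorch's custom operator syntax, including alias annotations such as Tensor(a). When an input and the output carry the same annotation, the output may share storage with that input, which tells the compiler about memory dependencies without declaring a mutation (a mutating argument would be written Tensor(a!)).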
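To make the annotation concrete, the sketch below reads alias sets out of schema strings like the ones on this page. It is a toy parser for only the subset of the syntax used here (no keyword-only `*`, overload names, or nested types) and is not PyTorch's own schema parser. The `scale_` schema in the last line is hypothetical and only illustrates the mutating `(a!)` form.

```python
import re

# Matches an alias annotation such as "(a)" or, for a mutating argument, "(a!)".
ALIAS = re.compile(r"\((\w)(!?)\)")


def _split(items: str) -> list:
    # Good enough for the schemas on this page: no nested commas inside types.
    return [item.strip() for item in items.split(",") if item.strip()]


def alias_sets(schema: str) -> dict:
    """Toy reader for alias annotations; not PyTorch's FunctionSchema parser."""
    match = re.fullmatch(r"(\w+)\((.*)\)\s*->\s*(.*)", schema.strip())
    if match is None:
        raise ValueError(f"unsupported schema: {schema!r}")
    _, arg_str, ret_str = match.groups()
    if ret_str.startswith("(") and ret_str.endswith(")"):
        ret_str = ret_str[1:-1]  # "()" or "(A, B)" -> inner list
    sets = {}

    def record(type_str, role, label):
        hit = ALIAS.search(type_str)
        if hit:
            entry = sets.setdefault(hit.group(1), {"inputs": [], "outputs": [], "mutable": False})
            entry[role].append(label)
            entry["mutable"] = entry["mutable"] or hit.group(2) == "!"

    for arg in _split(arg_str):
        type_str, name = arg.split()[:2]  # e.g. "Tensor(a)", "a"
        record(type_str, "inputs", name)
    for index, ret in enumerate(_split(ret_str)):
        record(ret.split()[0], "outputs", index)
    return sets


print(alias_sets("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)"))
# {'a': {'inputs': ['a'], 'outputs': [0], 'mutable': False}}
print(alias_sets("reduce_grad(Tensor a, int graph_id, int id) -> Tensor"))
# {}
print(alias_sets("scale_(Tensor(a!) x, float s) -> Tensor(a!)"))
# {'a': {'inputs': ['x'], 'outputs': [0], 'mutable': True}}
```

An empty result, as for reduce_grad, means the output is a new tensor as far as the schema is concerned.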
Usage
This file is compiled into a PyTorch extension module that Python imports as deepspeed.ops.dc, making custom operators available via torch.ops.dc.* and providing registration functions accessible from Python.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/compile/init.cpp
Signature
// Operator registrations in TORCH_LIBRARY(dc, m)
TORCH_LIBRARY(dc, m) {
m.def("allgather_param(Tensor a, int graph_id, int id, ScalarType? dtype = None) -> Tensor");
m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids, ScalarType[]? dtypes = None) -> ()");
m.def("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)");
m.def("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)");
m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
m.def("free_tensors(Tensor[] a) -> ()");
m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
m.def("reload_tensor(Tensor a, int id, int id) -> Tensor");
m.def("wait_offload(Tensor a, int id, int id) -> Tensor");
m.def("wait_reload(Tensor a, int id, int id) -> Tensor");
m.def("offload_parameter(Tensor a, int id, int id) -> ()");
m.def("reload_parameter(Tensor a, int id, int id) -> ()");
m.def("end_backward(int graph_id) -> ()");
}
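// Backend implementations (bodies abbreviated)
TORCH_LIBRARY_IMPL(dc, CPU, m) { /* m.impl(...) for each op; same functions as CUDA (handled via ProcessGroup) */ }
TORCH_LIBRARY_IMPL(dc, CUDA, m) { /* m.impl(...) mapping to the z1/z2/z3 implementations */ }
TORCH_LIBRARY_IMPL(dc, Meta, m) { /* m.impl(...) mapping to the _meta shape-inference variants */ }
TORCH_LIBRARY_IMPL(dc, Undefined, m) { /* end_backward only: it has no tensor arguments to dispatch on */ }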
// Python bindings in PYBIND11_MODULE
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("init", &dc::init);
m.def("cleanup", &dc::cleanup);
m.def("register_param", &dc::register_param); // Z1
m.def("register_graph_z1", &dc::register_graph_z1); // Z1
m.def("register_graph_z2", &dc::register_graph_z2); // Z2
m.def("register_z3_param", &dc::register_z3_param); // Z3
m.def("register_graph_z3", &dc::register_graph_z3); // Z3
m.def("set_persistent", &dc::set_persistent); // Z3
m.def("start_forward", &dc::start_forward);
m.def("end_forward", &dc::end_forward);
m.def("start_backward", &dc::start_backward);
m.def("enable_profiling", &dc::enable_profiling);
m.def("is_profiling", &dc::is_profiling);
m.def("reset", &dc::reset);
m.def("invalidate_gathered_param", &dc::invalidate_gathered_param); // Z3
m.def("clear_all_gathered_params", &dc::clear_all_gathered_params); // Z3
}
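The split between the two macros can be modeled in a few lines of Python: one call declares the schema, and separate calls attach one kernel per dispatch key. `ToyLibrary` and its methods are hypothetical stand-ins for `m.def` and `m.impl`, not a PyTorch API, and the kernels simply return strings.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict


@dataclass
class OpEntry:
    schema: str
    kernels: Dict[str, Callable] = field(default_factory=dict)  # dispatch key -> kernel


class ToyLibrary:
    """Toy model of TORCH_LIBRARY (define) plus TORCH_LIBRARY_IMPL (impl); not PyTorch's API."""

    def __init__(self, namespace: str):
        self.namespace = namespace
        self.ops: Dict[str, OpEntry] = {}

    def define(self, schema: str) -> None:
        # Like m.def(...): the schema is declared once, independent of any backend.
        name = schema.split("(", 1)[0]
        if name in self.ops:
            raise ValueError(f"{self.namespace}::{name} is already defined")
        self.ops[name] = OpEntry(schema)

    def impl(self, name: str, key: str, kernel: Callable) -> None:
        # Like m.impl(...) inside TORCH_LIBRARY_IMPL(namespace, key, m).
        if name not in self.ops:
            raise KeyError(f"{self.namespace}::{name} has no schema")
        self.ops[name].kernels[key] = kernel

    def call(self, name: str, key: str, *args):
        kernels = self.ops[name].kernels
        if key not in kernels:
            raise NotImplementedError(f"{self.namespace}::{name} has no kernel for {key}")
        return kernels[key](*args)


def reduce_grad(grad, graph_id, ds_id):       # stand-in for the real kernel
    return f"reduced {grad} (graph {graph_id}, param {ds_id})"


def reduce_grad_meta(grad, graph_id, ds_id):  # stand-in for the shape-only variant
    return f"empty result shaped like {grad}"


lib = ToyLibrary("dc")
lib.define("reduce_grad(Tensor a, int graph_id, int id) -> Tensor")
lib.impl("reduce_grad", "CPU", reduce_grad)   # CPU and CUDA share one function
lib.impl("reduce_grad", "CUDA", reduce_grad)
lib.impl("reduce_grad", "Meta", reduce_grad_meta)
print(lib.call("reduce_grad", "CUDA", "g0", 0, 3))  # reduced g0 (graph 0, param 3)
print(lib.call("reduce_grad", "Meta", "g0", 0, 3))  # empty result shaped like g0
```

Registering the same function under both CPU and CUDA mirrors the arrangement described under Backend Consistency.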
Import
import torch
from deepspeed.ops import dc # Imports compiled extension
# Use registered custom operators
result = torch.ops.dc.allgather_param(tensor, graph_id=0, id=0, dtype=None)
result = torch.ops.dc.wait_allgather(result, graph_id=0, id=0)
torch.ops.dc.prefetch_params_fused(
graph_id=0,
params=[p1, p2],
ids=[0, 1],
dtypes=[torch.float16, torch.bfloat16]
)
# Reduce gradients
torch.ops.dc.reduce_grad(grad_tensor, graph_id=0, id=0)
# Lifecycle hooks
torch.ops.dc.end_backward(graph_id=0)
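# Use Python bindings for registration. The bindings are declared without
# py::arg names, so arguments are passed positionally.
dc.init(process_group, config, bucket_size)
dc.register_z3_param(0, [1024, 768], shard, grad, False)  # ds_id, ds_shape, ds_tensor, grad_buffer, persistent
dc.register_graph_z3(0, [0, 1, 2])  # graph_id, ds_ids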
I/O Contract
Custom Operators
| Operator | Schema | Description |
|---|---|---|
| allgather_param | (Tensor, int, int, ScalarType?) -> Tensor | Gathers partitioned parameter from all ranks |
| prefetch_params_fused | (int, Tensor[], int[], ScalarType[]?) -> () | Batches multiple allgathers |
| wait_allgather | (Tensor(a), int, int) -> Tensor(a) | Synchronizes allgather stream |
| release_param | (Tensor(a), int, int, int) -> Tensor(a) | Releases gathered parameter memory |
| reduce_grad | (Tensor, int, int) -> Tensor | Reduces gradient across ranks |
| free_tensors | (Tensor[]) -> () | Frees large activation tensors |
| offload_tensor | (Tensor, int, int) -> Tensor | Offloads tensor to CPU pinned memory |
| reload_tensor | (Tensor, int, int) -> Tensor | Reloads tensor from CPU to GPU |
| wait_offload | (Tensor, int, int) -> Tensor | Waits for offload completion |
| wait_reload | (Tensor, int, int) -> Tensor | Waits for reload completion |
| offload_parameter | (Tensor, int, int) -> () | Offloads parameter to CPU |
| reload_parameter | (Tensor, int, int) -> () | Reloads parameter from CPU |
| end_backward | (int) -> () | Flushes reduce buckets at end of backward |
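The end_backward row implies that reduce_grad batches gradients into buckets instead of reducing each one immediately, so the last partially filled bucket needs an explicit flush. A toy model of that idea, assuming a capacity counted in elements and averaging across ranks; `ToyReduceBucket` is hypothetical, plain lists stand in for tensors, and a local average stands in for the collective.

```python
from typing import Dict, List


class ToyReduceBucket:
    """Toy model of bucketed gradient reduction; not DeepSpeed's implementation.

    Gradients queue up until adding the next one would exceed `capacity`
    elements; the queued gradients are then reduced together in one flush.
    """

    def __init__(self, capacity: int, world_size: int):
        self.capacity = capacity
        self.world_size = world_size
        self.pending: Dict[int, List[List[float]]] = {}  # ds_id -> one gradient per rank
        self.pending_numel = 0
        self.reduced: Dict[int, List[float]] = {}
        self.flush_count = 0

    def reduce_grad(self, ds_id: int, per_rank_grads: List[List[float]]) -> None:
        numel = len(per_rank_grads[0])
        if self.pending and self.pending_numel + numel > self.capacity:
            self.flush()  # bucket full: reduce what is queued before accepting more
        self.pending[ds_id] = per_rank_grads
        self.pending_numel += numel

    def flush(self) -> None:
        if not self.pending:
            return
        # Stand-in for one collective over the whole bucket: element-wise average across ranks.
        for ds_id, grads in self.pending.items():
            self.reduced[ds_id] = [sum(vals) / self.world_size for vals in zip(*grads)]
        self.pending.clear()
        self.pending_numel = 0
        self.flush_count += 1

    def end_backward(self) -> None:
        self.flush()  # the last, partially filled bucket would otherwise never be reduced


bucket = ToyReduceBucket(capacity=4, world_size=2)
bucket.reduce_grad(0, [[1.0, 2.0], [3.0, 4.0]])  # 2 elements queued
bucket.reduce_grad(1, [[0.0, 0.0], [2.0, 2.0]])  # 4 elements queued, still fits
bucket.reduce_grad(2, [[5.0], [7.0]])            # 5 > 4, so params 0 and 1 flush first
bucket.end_backward()                            # flushes param 2
print(bucket.reduced, bucket.flush_count)        # {0: [2.0, 3.0], 1: [1.0, 1.0], 2: [6.0]} 2
```

Larger buckets mean fewer, bigger collectives; the trade-off is that gradients wait longer before their reduction starts.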
Python Bindings
| Function | Parameters | Description |
|---|---|---|
| init | (ProcessGroup, config, int64_t) | Initializes runtime with process group |
| cleanup | () | Destroys NCCL communicator and clears state |
| register_param | (int, int64_t[], Tensor, Tensor, int64_t) | Registers Z1 parameter |
| register_graph_z1 | (int, int[]) | Registers Z1 graph executor |
| register_graph_z2 | (int, int[]) | Registers Z2 graph executor |
| register_z3_param | (int, int64_t[], Tensor, Tensor, bool) | Registers Z3 partitioned parameter |
| register_graph_z3 | (int, int[]) | Registers Z3 graph executor |
| set_persistent | (int) | Marks Z3 parameter as persistent (keep gathered) |
| start_forward | () | Begins forward pass lifecycle |
| end_forward | () | Ends forward pass lifecycle |
| start_backward | (bool) | Begins backward pass with update flag |
| enable_profiling | (bool) | Enables/disables profiling mode |
| is_profiling | () -> bool | Checks if profiling enabled |
| reset | () | Clears executors |
| invalidate_gathered_param | (int) | Invalidates Z3 gathered parameter (profiling) |
| clear_all_gathered_params | () | Clears all Z3 gathered parameters |
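Read together with the usage example on this page, these bindings imply a call order. The toy checker below encodes that order as a state machine. The states, the single `register` step standing in for the register_* calls, and the permitted transitions (including a forward pass with no backward) are assumptions drawn from this page, not rules DeepSpeed is documented to enforce.

```python
TRANSITIONS = {
    "uninitialized": {"init": "ready"},
    "ready": {"register": "ready", "start_forward": "forward", "cleanup": "uninitialized"},
    "forward": {"end_forward": "forward_done"},
    "forward_done": {"start_backward": "backward", "start_forward": "forward",
                     "cleanup": "uninitialized"},
    "backward": {"end_backward": "ready"},
}


class ToyLifecycle:
    """Toy checker for an assumed call order; DeepSpeed itself may not enforce it."""

    def __init__(self) -> None:
        self.state = "uninitialized"

    def call(self, name: str) -> str:
        allowed = TRANSITIONS[self.state]
        if name not in allowed:
            raise RuntimeError(
                f"{name}() is not expected in state {self.state!r}; expected one of {sorted(allowed)}")
        self.state = allowed[name]
        return self.state


life = ToyLifecycle()
for step in ["init", "register", "start_forward", "end_forward",
             "start_backward", "end_backward", "cleanup"]:
    life.call(step)
print(life.state)  # uninitialized

try:
    ToyLifecycle().call("start_backward")  # backward before init is rejected
except RuntimeError as err:
    print(err)
```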
Usage Examples
import torch
import torch.distributed as dist
from deepspeed.ops import dc
# Placeholders assumed to exist in the caller's code:
#   param_shards: list of (shard, grad_buffer, full_shape) tuples, one per parameter
#   batch, sample_input: input tensors; activations: list of saved activation tensors
# Initialize extension
dist.init_process_group("nccl")
pg = dist.new_group()
config = {"symmetric_memory": False, "double_buffer": True, "free_activation_threshold": 1_000_000,
          "sync_before_reduce": False, "sync_after_reduce": False,
          "sync_before_allgather": False, "sync_after_allgather": False}
# The pybind11 bindings take positional arguments: (process_group, config, initial_reduce_bucket_size)
dc.init(pg, config, 100_000_000)
# Register ZeRO-3 parameters: (ds_id, ds_shape, ds_tensor, grad_buffer, persistent)
for i, (shard, grad_buf, full_shape) in enumerate(param_shards):
    dc.register_z3_param(i, list(full_shape), shard, grad_buf, False)
dc.register_graph_z3(0, list(range(len(param_shards))))  # (graph_id, ds_ids)
# Using custom operators in a compiled function
def forward_with_allgather(x, param_shard, graph_id, param_id):
    # Gather the full parameter, then wait for the communication to finish
    param = torch.ops.dc.allgather_param(param_shard, graph_id, param_id, dtype=torch.float16)
    param = torch.ops.dc.wait_allgather(param, graph_id, param_id)
    # Use parameter
    output = torch.matmul(x, param.T)
    # Release the gathered parameter after its last user; routing `output` through
    # release_param keeps the call on the graph's data path
    output = torch.ops.dc.release_param(output, graph_id, param_id, 1)
    return output
compiled_forward = torch.compile(forward_with_allgather, backend="inductor")
# Training loop with lifecycle management
dc.start_forward()
output = compiled_forward(batch, param_shards[0][0], 0, 0)
dc.end_forward()
dc.start_backward(True)  # update flag, passed positionally
loss = output.sum()
loss.backward()
torch.ops.dc.end_backward(graph_id=0)  # flush pending reduce buckets
# Profiling mode
dc.enable_profiling(True)
traced_graph = torch.compile(forward_with_allgather, backend="eager")
_ = traced_graph(sample_input, param_shards[0][0], 0, 0)
dc.enable_profiling(False)
# Activation memory management (threshold matches free_activation_threshold above)
large_activations = [act for act in activations if act.numel() > 1_000_000]
torch.ops.dc.free_tensors(large_activations)
# Cleanup
dc.cleanup()
Implementation Details
Operator Aliasing Annotations
The schema uses PyTorch's aliasing notation:
- Tensor(a) indicates the output aliases the input (shares underlying memory)
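- Used by wait_allgather and release_param, which hand back their input tensor instead of a copy
- Lets torch.compile track that the returned tensor shares storage with the argument, so the op is not treated as producing a fresh allocation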
Meta Backend Implementation
Meta kernels provide shape inference without execution:
- allgather_param_meta returns empty tensor with correct shape and dtype
- reduce_grad_meta returns empty tensor
- Enables symbolic tracing and graph capture in profiling mode
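A toy illustration of what a shape-only kernel computes, assuming the full shape of each parameter is known from registration (ds_shape is passed to register_z3_param). `FakeTensor`, `FULL_SHAPES`, and `toy_allgather_param_meta` are hypothetical; real meta tensors are torch.Tensor objects on the `meta` device.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Shape and dtype only, with no storage, roughly what a meta tensor carries."""
    shape: Tuple[int, ...]
    dtype: str


# Hypothetical table filled at registration time (compare ds_shape in register_z3_param).
FULL_SHAPES: Dict[int, Tuple[int, ...]] = {}


def toy_allgather_param_meta(shard: FakeTensor, graph_id: int, ds_id: int,
                             dtype: Optional[str] = None) -> FakeTensor:
    # No communication and no data: only report what the gathered parameter would look like.
    return FakeTensor(FULL_SHAPES[ds_id], dtype or shard.dtype)


FULL_SHAPES[0] = (1024, 768)
shard = FakeTensor(shape=(1024 * 768 // 8,), dtype="float32")  # 1/8 of the flattened parameter
print(toy_allgather_param_meta(shard, graph_id=0, ds_id=0, dtype="float16"))
# FakeTensor(shape=(1024, 768), dtype='float16')
```

The optional dtype mirrors the `ScalarType? dtype = None` argument in the allgather_param schema: when it is omitted, the shard's dtype carries through.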
Dispatch Key Selection
- CPU/CUDA dispatch keys route to actual implementations (z1/z2/z3.cpp)
- Meta dispatch key routes to _meta variants for shape inference
- Undefined dispatch key used for end_backward (no tensor arguments)
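The selection rule above can be sketched as a small function: collect the devices of all tensor arguments (including Tensor[] lists) and map them to a key, falling back to Undefined when there are none. This is a toy rule; PyTorch's dispatcher computes a full dispatch-key set and also considers keys such as Autograd, which the sketch ignores. `ToyTensor` is hypothetical, and rejecting mixed devices is a choice of the sketch.

```python
from dataclasses import dataclass


@dataclass
class ToyTensor:
    device: str  # "cpu", "cuda", or "meta"


_KEY_FOR_DEVICE = {"cpu": "CPU", "cuda": "CUDA", "meta": "Meta"}


def toy_dispatch_key(*args) -> str:
    """Toy key selection from tensor arguments; not PyTorch's full dispatcher."""
    devices = set()
    for arg in args:
        items = arg if isinstance(arg, (list, tuple)) else [arg]  # unpack Tensor[] arguments
        devices.update(t.device for t in items if isinstance(t, ToyTensor))
    if not devices:
        return "Undefined"  # e.g. end_backward(int graph_id): nothing to dispatch on
    if len(devices) > 1:
        raise ValueError(f"mixed devices in one call: {sorted(devices)}")
    return _KEY_FOR_DEVICE[devices.pop()]


print(toy_dispatch_key(ToyTensor("cuda"), 0, 3))  # CUDA (like reduce_grad on a GPU tensor)
print(toy_dispatch_key(ToyTensor("meta"), 0, 3))  # Meta (shape-only tracing)
print(toy_dispatch_key(0))                         # Undefined (like end_backward)
```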
Backend Consistency
Both CPU and CUDA implementations call the same functions because:
- ProcessGroup handles device abstraction
- NCCL operations work on CUDA tensors regardless of dispatch backend
- CPU dispatch supports CPU-based testing and debugging
Extension Module Loading
PYBIND11_MODULE creates a Python extension module:
- Module name determined by TORCH_EXTENSION_NAME macro
- Compiled as shared library (.so on Linux, .pyd on Windows)
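- Imported as a regular Python module; loading the shared library runs the static initializers generated by TORCH_LIBRARY and TORCH_LIBRARY_IMPL, which register the torch.ops.dc.* operators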
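The platform-specific suffix can be checked from the standard library: importlib.machinery.EXTENSION_SUFFIXES lists the filename endings the running interpreter accepts for compiled modules. The snippet does not load the DeepCompile extension, and the file name `dc_ext` is hypothetical.

```python
import importlib.machinery
import sys

# Filename endings this interpreter accepts for compiled extension modules,
# for example a ".so" variant on Linux or ".pyd" on Windows.
suffixes = importlib.machinery.EXTENSION_SUFFIXES


def looks_like_extension(filename: str) -> bool:
    return filename.endswith(tuple(suffixes))


print(sys.platform, suffixes)
print(looks_like_extension("dc_ext" + suffixes[0]))  # True
print(looks_like_extension("init.cpp"))              # False
```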
Related Pages
- Environment:Deepspeedai_DeepSpeed_CUDA_GPU_Environment
- Implementation:Deepspeedai_DeepSpeed_DeepCompile_Runtime
- Implementation:Deepspeedai_DeepSpeed_DeepCompile_Header
- Implementation:Deepspeedai_DeepSpeed_ZeRO3_DeepCompile
- Implementation:Deepspeedai_DeepSpeed_ZeRO1_DeepCompile
- Implementation:Deepspeedai_DeepSpeed_ZeRO2_DeepCompile
- Concept:PyTorch_Custom_Operators
- Concept:Torch_Compile