Implementation: deepspeedai/DeepSpeed SHM Interface
| Knowledge Sources | |
|---|---|
| Domains | Communication, PyTorch, CPU_Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
PyTorch TORCH_LIBRARY interface for shared memory allreduce operations with functionalization support for in-place mutations.
Description
The SHM Interface provides the PyTorch integration layer for DeepSpeed's shared-memory-based allreduce implementation, exposing it as native PyTorch operators through the TORCH_LIBRARY mechanism. It defines two operator variants: inference_all_reduce (out-of-place) and inference_all_reduce_ (in-place), both of which use the optimized SHM communication primitives for low-latency CPU-based distributed inference.
The implementation includes full support for PyTorch's functionalization system, which is critical for compatibility with torch.compile and other graph-based optimizations. The functionalization glue code transforms the in-place mutation (inference_all_reduce_) into a functional operation by unwrapping functional tensors, calling the out-of-place version, and then replacing the input with the result. This allows the operator to work correctly in contexts that require functional semantics while maintaining the performance benefits of in-place operations at the implementation level.
The interface handles three dispatch keys: CPU (actual implementation), Meta (shape inference for graph compilation), and Functionalize (mutation handling). It performs device type validation, ensures tensor contiguity for optimal memory access patterns, and delegates to the underlying SHM allreduce implementation when all ranks are local. The initialization function sets up shared memory based on environment variables (LOCAL_SIZE, MASTER_ADDR, MASTER_PORT) and is exposed via PyBind11 for Python access.
Usage
Use the SHM Interface when performing distributed inference on CPU with PyTorch, especially when integrating with torch.compile or other PyTorch compiler features that require functional operator semantics. The operators are automatically registered in the deepspeed namespace and can be called directly or through DeepSpeed's communication API.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/cpu/comm/shm_interface.cpp
Signature
// PyBind11 exported function
void initialize(int size, int rank);
// TORCH_LIBRARY operator definitions (deepspeed namespace)
torch::Tensor inference_all_reduce(const torch::Tensor& self);
torch::Tensor& inference_all_reduce_(torch::Tensor& self);
// CPU dispatch implementations
torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_);
torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_);
// Meta dispatch for shape inference
torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_);
torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_);
// Functionalization glue
at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x);
// Internal implementation (delegates to SHM)
void inference_all_reduce_(torch::Tensor& data, int op);
Import
import torch
import deepspeed
# Initialize the SHM interface
# Typically done internally by DeepSpeed initialization
deepspeed.init_distributed()
# Option 1: Use through DeepSpeed communication API
tensor = torch.randn(1024, 1024, dtype=torch.bfloat16)
deepspeed.comm.inference_all_reduce(tensor)
# Option 2: Call PyTorch operator directly
result = torch.ops.deepspeed.inference_all_reduce(tensor)
# Option 3: In-place variant (modifies tensor)
torch.ops.deepspeed.inference_all_reduce_(tensor)
# Works with torch.compile
@torch.compile
def my_inference_step(x):
    # Operator is fully functionalized for graph compilation
    return torch.ops.deepspeed.inference_all_reduce(x)
I/O Contract
initialize(size, rank)
| Parameter | Type | Description |
|---|---|---|
| size | int | Total number of local ranks (from LOCAL_SIZE env) |
| rank | int | Current rank identifier |
| Effect | | Initializes shared memory buffers if LOCAL_SIZE == size |

inference_all_reduce (out-of-place)
| Parameter | Type | Description |
|---|---|---|
| self | const Tensor& | Input tensor (not modified) |
| Returns | Tensor | New tensor with allreduced values |
| Constraints | | Device must be CPU; supports BF16/FP16/FP32 |

inference_all_reduce_ (in-place)
| Parameter | Type | Description |
|---|---|---|
| self | Tensor& | Input tensor (modified in-place) |
| Returns | Tensor& | Reference to the modified input tensor |
| Constraints | | Device must be CPU; tensor is made contiguous |

Dispatch keys
| Dispatch Key | Implementation | Purpose |
|---|---|---|
| CPU | inference_all_reduce__cpu | Actual SHM allreduce computation |
| Meta | inference_all_reduce__meta | Shape inference (returns empty_like) |
| Functionalize | inference_all_reduce__functionalization_glue | Handles mutations for torch.compile |

Supported dtypes
| Type | Support | Behavior |
|---|---|---|
| BFloat16 | Supported | SHM allreduce |
| Half (FP16) | Supported | SHM allreduce |
| Float (FP32) | Supported | SHM allreduce |
| Other types | Not supported | Returns without modification |
Usage Examples
import os
import torch
import deepspeed
# Setup environment
os.environ['LOCAL_SIZE'] = '8'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
# Initialize DeepSpeed (rank is typically provided by the launcher)
rank = int(os.environ.get('RANK', '0'))
deepspeed.init_distributed(rank=rank, world_size=8)
# Example 1: Out-of-place allreduce
input_tensor = torch.randn(512, 512, dtype=torch.bfloat16)
output_tensor = torch.ops.deepspeed.inference_all_reduce(input_tensor)
# input_tensor remains unchanged, output_tensor has reduced values
# Example 2: In-place allreduce (more efficient)
tensor = torch.randn(1024, 1024, dtype=torch.bfloat16)
torch.ops.deepspeed.inference_all_reduce_(tensor)
# tensor now contains the allreduced values
# Example 3: With torch.compile
@torch.compile
def inference_with_allreduce(activations):
    # Functionalization converts the in-place op to out-of-place automatically
    output = torch.ops.deepspeed.inference_all_reduce(activations)
    return output * 2.0
result = inference_with_allreduce(torch.randn(256, 256, dtype=torch.float32))
# Example 4: Mixed with standard PyTorch operations
class DistributedInferenceLayer(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        # Standard PyTorch operation
        x = self.linear(x)
        # DeepSpeed SHM allreduce
        x = torch.ops.deepspeed.inference_all_reduce(x)
        return torch.relu(x)

layer = DistributedInferenceLayer(512).to(torch.bfloat16)
output = layer(torch.randn(32, 512, dtype=torch.bfloat16))
# Example 5: Handling non-contiguous tensors
non_contiguous = torch.randn(128, 128).t() # Transpose creates non-contiguous
# The operator automatically makes it contiguous before processing
torch.ops.deepspeed.inference_all_reduce_(non_contiguous)
# Example 6: Type checking and fallback
def safe_allreduce(tensor):
    if tensor.dtype in (torch.bfloat16, torch.float16, torch.float32):
        if tensor.device.type == 'cpu':
            return torch.ops.deepspeed.inference_all_reduce(tensor)
    # Fall back to the standard communication path (reduces in place)
    deepspeed.comm.all_reduce(tensor)
    return tensor
Implementation Details
TORCH_LIBRARY Registration
// Define operators in the deepspeed namespace
TORCH_LIBRARY(deepspeed, m) {
    m.def("inference_all_reduce(Tensor self) -> Tensor");
    m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)");
}

// Register CPU implementations
TORCH_LIBRARY_IMPL(deepspeed, CPU, m) {
    m.impl("inference_all_reduce", inference_all_reduce_cpu);
    m.impl("inference_all_reduce_", inference_all_reduce__cpu);
}

// Register Meta implementations for shape inference
TORCH_LIBRARY_IMPL(deepspeed, Meta, m) {
    m.impl("inference_all_reduce", inference_all_reduce_meta);
    m.impl("inference_all_reduce_", inference_all_reduce__meta);
}

// Register functionalization handler
TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m) {
    m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue);
}
Functionalization Pattern
The functionalization glue follows this pattern:
- Unwrap: Extract raw tensor from FunctionalTensorWrapper
- Call: Invoke out-of-place version (inference_all_reduce)
- Replace: Update functional tensor with result via replace_()
- Commit: Finalize mutation and synchronize
This allows in-place operators to work with PyTorch's functional transformation system.
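As a language-agnostic illustration only (not the actual C++ glue, which operates on PyTorch's FunctionalTensorWrapper), the four steps can be sketched with a toy Python wrapper; `FunctionalWrapper`, `out_of_place_all_reduce`, and `in_place_all_reduce_` are hypothetical names for this sketch:

```python
# Toy stand-in for a functional tensor wrapper, illustrating the
# unwrap -> call -> replace -> commit sequence of the glue code.
class FunctionalWrapper:
    def __init__(self, value):
        self.value = value       # wrapped "tensor" payload
        self.pending = None      # staged replacement, not yet visible

    def unwrap(self):
        return self.value

    def replace(self, new_value):
        self.pending = new_value  # stage the mutation

    def commit(self):
        self.value = self.pending  # make the mutation visible
        self.pending = None

def out_of_place_all_reduce(x):
    # Stand-in for the functional (out-of-place) operator: returns a
    # new value, e.g. a sum across two ranks holding identical data.
    return [v * 2 for v in x]

def in_place_all_reduce_(wrapped):
    # The "glue": express an in-place mutation as a functional call.
    raw = wrapped.unwrap()                 # 1. unwrap
    result = out_of_place_all_reduce(raw)  # 2. call out-of-place version
    wrapped.replace(result)                # 3. replace input with result
    wrapped.commit()                       # 4. commit the mutation
    return wrapped

t = FunctionalWrapper([1.0, 2.0, 3.0])
in_place_all_reduce_(t)
print(t.unwrap())  # [2.0, 4.0, 6.0]
```

The caller still sees in-place semantics (`t` is mutated), while the graph only ever records a pure functional call.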
Device and Type Validation
torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_) {
    // Validate device type
    TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU);
    // Ensure contiguity for optimal memory access
    torch::Tensor self_tensor = self_.contiguous();
    // Delegate to the internal implementation (op = 0 means SUM)
    inference_all_reduce_(self_tensor, 0);
    return self_;
}
Meta Dispatch for Shape Inference
// Out-of-place: return a new meta tensor with the same shape
torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_) {
    return torch::empty_like(self_);
}

// In-place: return a reference to self
torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) {
    return self_;
}
Environment Variable Handling
- LOCAL_SIZE: Number of ranks on local machine
- MASTER_ADDR: Master node address for coordination
- MASTER_PORT: Master node port for coordination
These are checked during initialize() to determine if SHM optimization is applicable.
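A hedged sketch of the decision this implies, in Python rather than the actual C++ (`shm_is_applicable` is an illustrative name, not a real DeepSpeed function):

```python
import os

def shm_is_applicable(world_size):
    """Illustrative check only: SHM allreduce applies when all ranks
    live on the local machine, i.e. LOCAL_SIZE equals the world size,
    and a coordination endpoint is configured."""
    local_size = int(os.environ.get("LOCAL_SIZE", "0"))
    addr = os.environ.get("MASTER_ADDR")
    port = os.environ.get("MASTER_PORT")
    if addr is None or port is None:
        return False  # no coordination endpoint configured
    return local_size == world_size

os.environ.update(LOCAL_SIZE="8", MASTER_ADDR="localhost", MASTER_PORT="29500")
print(shm_is_applicable(8))   # True: all 8 ranks are local
print(shm_is_applicable(16))  # False: ranks span multiple machines
```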
Integration with SHM Backend
The interface layer checks for local rank configuration and delegates to:
- shm_initialize(): Sets up shared memory regions
- all_reduce_outer_loop(): Performs actual SHM-based reduction
- Supports BFloat16, Float16, and Float32 tensors only
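For intuition only, here is a naive pure-Python stand-in for the contract a sum allreduce satisfies; the real all_reduce_outer_loop is a tuned, tiled C++ kernel operating on shared-memory buffers, and `naive_all_reduce_sum` is a hypothetical name:

```python
def naive_all_reduce_sum(rank_buffers):
    """Sum element-wise across all ranks' buffers and write the result
    back into every buffer -- the guarantee an allreduce provides."""
    n = len(rank_buffers[0])
    total = [0.0] * n
    for buf in rank_buffers:           # reduce phase: accumulate
        for i, v in enumerate(buf):
            total[i] += v
    for buf in rank_buffers:           # broadcast phase: write back
        buf[:] = total                 # every rank sees identical values
    return rank_buffers

ranks = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
naive_all_reduce_sum(ranks)
print(ranks[0])  # [9.0, 12.0]
```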
PyBind11 Exposure
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("initialize", &initialize, "shm initialize");
}
This allows Python code to explicitly initialize the SHM backend if needed.