Implementation:Deepspeedai DeepSpeed SHM Interface

From Leeroopedia


Knowledge Sources
Domains Communication, PyTorch, CPU_Optimization
Last Updated 2026-02-09 00:00 GMT

Overview

PyTorch TORCH_LIBRARY interface for shared memory allreduce operations with functionalization support for in-place mutations.

Description

The SHM Interface provides the PyTorch integration layer for DeepSpeed's shared memory based allreduce implementation, exposing it as native PyTorch operators through the TORCH_LIBRARY mechanism. It defines two operator variants: inference_all_reduce (out-of-place) and inference_all_reduce_ (in-place), both of which leverage the optimized SHM communication primitives for low-latency CPU-based distributed inference.

The implementation includes full support for PyTorch's functionalization system, which is critical for compatibility with torch.compile and other graph-based optimizations. The functionalization glue code transforms the in-place mutation (inference_all_reduce_) into a functional operation by unwrapping functional tensors, calling the out-of-place version, and then replacing the input with the result. This allows the operator to work correctly in contexts that require functional semantics while maintaining the performance benefits of in-place operations at the implementation level.

The interface handles three dispatch keys: CPU (actual implementation), Meta (shape inference for graph compilation), and Functionalize (mutation handling). It performs device type validation, ensures tensor contiguity for optimal memory access patterns, and delegates to the underlying SHM allreduce implementation when all ranks are local. The initialization function sets up shared memory based on environment variables (LOCAL_SIZE, MASTER_ADDR, MASTER_PORT) and is exposed via PyBind11 for Python access.

Usage

Use the SHM Interface when performing distributed inference on CPU with PyTorch, especially when integrating with torch.compile or other PyTorch compiler features that require functional operator semantics. The operators are automatically registered in the deepspeed namespace and can be called directly or through DeepSpeed's communication API.

Code Reference

Source Location

Signature

// PyBind11 exported function
void initialize(int size, int rank);

// TORCH_LIBRARY operator definitions (deepspeed namespace)
torch::Tensor inference_all_reduce(const torch::Tensor& self);
torch::Tensor& inference_all_reduce_(torch::Tensor& self);

// CPU dispatch implementations
torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_);
torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_);

// Meta dispatch for shape inference
torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_);
torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_);

// Functionalization glue
at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x);

// Internal implementation (delegates to SHM)
void inference_all_reduce_(torch::Tensor& data, int op);

Import

import torch
import deepspeed

# Initialize the SHM interface
# Typically done internally by DeepSpeed initialization
deepspeed.init_distributed()

# Option 1: Use through DeepSpeed communication API
tensor = torch.randn(1024, 1024, dtype=torch.bfloat16)
deepspeed.comm.inference_all_reduce(tensor)

# Option 2: Call PyTorch operator directly
result = torch.ops.deepspeed.inference_all_reduce(tensor)

# Option 3: In-place variant (modifies tensor)
torch.ops.deepspeed.inference_all_reduce_(tensor)

# Works with torch.compile
@torch.compile
def my_inference_step(x):
    # Operator is fully functionalized for graph compilation
    return torch.ops.deepspeed.inference_all_reduce(x)

I/O Contract

initialize(size, rank)

  • size (int): total number of local ranks (from the LOCAL_SIZE environment variable)
  • rank (int): current rank identifier
  • Effect: initializes shared memory buffers if LOCAL_SIZE == size

inference_all_reduce(self) -> Tensor

  • self (const Tensor&): input tensor (not modified)
  • Returns (Tensor): new tensor containing the allreduced values
  • Constraints: device must be CPU; supports BF16/FP16/FP32

inference_all_reduce_(self) -> Tensor

  • self (Tensor&): input tensor (modified in place)
  • Returns (Tensor&): reference to the modified input tensor
  • Constraints: device must be CPU; the tensor is made contiguous internally

Dispatch Keys

  • CPU: inference_all_reduce_cpu / inference_all_reduce__cpu (actual SHM allreduce computation)
  • Meta: inference_all_reduce_meta / inference_all_reduce__meta (shape inference; returns empty_like)
  • Functionalize: inference_all_reduce__functionalization_glue (handles mutation for torch.compile)

Supported Data Types

  • BFloat16: full support (SHM allreduce)
  • Half (FP16): full support (SHM allreduce)
  • Float (FP32): full support (SHM allreduce)
  • Other types: not supported; the tensor is returned without modification
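The support table can be encoded as a small guard. This is an illustrative sketch only; shm_supported is a hypothetical helper, not part of DeepSpeed's API:

```python
# Dtypes eligible for the SHM allreduce path, per the table above.
SHM_DTYPES = {"bfloat16", "float16", "float32"}

def shm_supported(dtype_name, device_type):
    """True when the SHM path applies: CPU device and one of the
    three supported floating-point dtypes; everything else falls
    through unchanged."""
    return device_type == "cpu" and dtype_name in SHM_DTYPES
```
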

Usage Examples

import os
import torch
import deepspeed

# Setup environment
os.environ['LOCAL_SIZE'] = '8'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

# Initialize DeepSpeed (rank normally comes from the launcher environment)
rank = int(os.environ.get('RANK', '0'))
deepspeed.init_distributed(rank=rank, world_size=8)

# Example 1: Out-of-place allreduce
input_tensor = torch.randn(512, 512, dtype=torch.bfloat16)
output_tensor = torch.ops.deepspeed.inference_all_reduce(input_tensor)
# input_tensor remains unchanged, output_tensor has reduced values

# Example 2: In-place allreduce (more efficient)
tensor = torch.randn(1024, 1024, dtype=torch.bfloat16)
torch.ops.deepspeed.inference_all_reduce_(tensor)
# tensor now contains the allreduced values

# Example 3: With torch.compile
@torch.compile
def inference_with_allreduce(activations):
    # Functionalization converts in-place to out-of-place automatically
    output = torch.ops.deepspeed.inference_all_reduce(activations)
    return output * 2.0

compiled_fn = inference_with_allreduce
result = compiled_fn(torch.randn(256, 256, dtype=torch.float32))

# Example 4: Mixed with standard PyTorch operations
class DistributedInferenceLayer(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        # Standard PyTorch operation
        x = self.linear(x)
        # DeepSpeed SHM allreduce
        x = torch.ops.deepspeed.inference_all_reduce(x)
        return torch.relu(x)

layer = DistributedInferenceLayer(512)
output = layer(torch.randn(32, 512, dtype=torch.bfloat16))

# Example 5: Non-contiguous tensors
non_contiguous = torch.randn(128, 128).t()  # transpose creates a non-contiguous view
# The operator makes the tensor contiguous internally; note that .contiguous()
# returns a copy for a non-contiguous input, so prefer passing contiguous
# tensors to the in-place variant to be sure the input itself is updated.
torch.ops.deepspeed.inference_all_reduce_(non_contiguous)

# Example 6: Type checking and fallback
def safe_allreduce(tensor):
    if tensor.dtype in [torch.bfloat16, torch.float16, torch.float32]:
        if tensor.device.type == 'cpu':
            return torch.ops.deepspeed.inference_all_reduce(tensor)
    # Fallback to standard communication
    return deepspeed.comm.all_reduce(tensor)

Implementation Details

TORCH_LIBRARY Registration

// Define operators in deepspeed namespace
TORCH_LIBRARY(deepspeed, m) {
    m.def("inference_all_reduce(Tensor self) -> Tensor");
    m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)");
}

// Register CPU implementations
TORCH_LIBRARY_IMPL(deepspeed, CPU, m) {
    m.impl("inference_all_reduce", inference_all_reduce_cpu);
    m.impl("inference_all_reduce_", inference_all_reduce__cpu);
}

// Register Meta implementations for shape inference
TORCH_LIBRARY_IMPL(deepspeed, Meta, m) {
    m.impl("inference_all_reduce", inference_all_reduce_meta);
    m.impl("inference_all_reduce_", inference_all_reduce__meta);
}

// Register functionalization handler
TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m) {
    m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue);
}
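The registration pattern above amounts to a per-key dispatch table: one implementation per (operator, dispatch key) pair, selected at call time. A toy model in plain Python (illustrative names only; PyTorch's real dispatcher is far more involved):

```python
# Registry mapping (operator name, dispatch key) -> implementation,
# mimicking what TORCH_LIBRARY_IMPL populates.
REGISTRY = {}

def impl(op, key):
    """Decorator that registers fn for the given op under the given key."""
    def register(fn):
        REGISTRY[(op, key)] = fn
        return fn
    return register

@impl("inference_all_reduce", "CPU")
def _cpu(x):
    # Stand-in for the real kernel: returns a new buffer.
    return list(x)

@impl("inference_all_reduce", "Meta")
def _meta(x):
    # Metadata only: shape is preserved, values are irrelevant
    # (analogous to empty_like in the C++ Meta kernel).
    return [0] * len(x)

def dispatch(op, key, x):
    """Look up and invoke the implementation for this dispatch key."""
    return REGISTRY[(op, key)](x)
```
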

Functionalization Pattern

The functionalization glue follows this pattern:

  1. Unwrap: Extract raw tensor from FunctionalTensorWrapper
  2. Call: Invoke out-of-place version (inference_all_reduce)
  3. Replace: Update functional tensor with result via replace_()
  4. Commit: Finalize mutation and synchronize

This allows in-place operators to work with PyTorch's functional transformation system.
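The four steps can be modeled in plain Python. This is a toy sketch, not DeepSpeed code; op_out_of_place and op_inplace_functionalized are illustrative names, and the unwrap/commit steps are implicit because no tensor wrappers are involved:

```python
def op_out_of_place(x):
    """Stand-in for the functional operator: returns a new list."""
    return [v + v for v in x]  # toy "reduction"

def op_inplace_functionalized(x):
    """Mirrors the glue: call the out-of-place version, write the
    result back into the input, and return the same object."""
    result = op_out_of_place(x)  # step 2: call functional version
    x[:] = result                # step 3: replace contents (replace_())
    return x                     # step 4: commit; caller sees mutation
```
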

Device and Type Validation

torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_) {
    // Validate device type
    TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU);

    // Ensure contiguity for optimal memory access
    torch::Tensor self_tensor = self_.contiguous();

    // Delegate to implementation (op=0 means SUM)
    inference_all_reduce_(self_tensor, 0);

    return self_;
}

Meta Dispatch for Shape Inference

// Out-of-place: return new tensor with same shape
torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_) {
    return torch::empty_like(self_);
}

// In-place: return reference to self
torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) {
    return self_;
}

Environment Variable Handling

  • LOCAL_SIZE: Number of ranks on local machine
  • MASTER_ADDR: Master node address for coordination
  • MASTER_PORT: Master node port for coordination

These are checked during initialize() to determine if SHM optimization is applicable.
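A minimal sketch of that check (illustrative Python; the real check happens in C++ inside initialize(), and shm_applicable is a hypothetical name):

```python
import os

def shm_applicable(world_size):
    """The SHM allreduce only applies when every rank is on one machine,
    i.e. LOCAL_SIZE covers the whole world. The variable name mirrors
    the environment variable listed above."""
    local_size = int(os.environ.get("LOCAL_SIZE", "1"))
    return local_size == world_size
```
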

Integration with SHM Backend

The interface layer checks for local rank configuration and delegates to:

  • shm_initialize(): Sets up shared memory regions
  • all_reduce_outer_loop(): Performs actual SHM-based reduction
  • Supports BFloat16, Float16, and Float32 tensors only
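The reduction itself can be pictured with a toy single-process model (pure Python, assuming the SUM op; the real all_reduce_outer_loop operates on shared-memory chunks across processes instead of in-memory lists):

```python
def toy_shm_all_reduce(rank_buffers):
    """Element-wise SUM across local 'ranks', with the result written
    back into every rank's buffer, as an allreduce requires."""
    reduced = [sum(vals) for vals in zip(*rank_buffers)]
    for buf in rank_buffers:
        buf[:] = reduced  # every rank ends up with the same reduced values
    return rank_buffers
```
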

PyBind11 Exposure

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("initialize", &initialize, "shm initialize");
}

This allows Python code to explicitly initialize the SHM backend if needed.
