Implementation:FMInference FlexLLMGen DeepSpeed Inference Engine

From Leeroopedia


Field          Value
Sources        Repo: FlexLLMGen
Domains        Inference, Model_Parallelism, Performance_Optimization
Last Updated   2026-02-09 00:00 GMT

Overview

Vendored DeepSpeed inference engine that wraps a PyTorch model with optimized inference capabilities, including tensor parallelism, kernel injection, CUDA graph support, checkpoint loading, and model profiling.

Description

InferenceEngine is the main entry point for DeepSpeed inference. It extends torch.nn.Module and wraps a user model with a comprehensive set of inference optimizations. The engine is initialized with a DeepSpeedInferenceConfig that controls tensor parallelism, kernel injection, dtype conversion, checkpoint loading, and CUDA graph usage.

Key capabilities:

  • Tensor parallelism -- Creates model-parallel process groups and distributes the model across GPUs. Supports both explicit MPU (model parallel unit) and automatic TP group creation. Synchronizes RNG states across TP ranks for consistent behavior.
  • Kernel injection -- Replaces standard transformer layers with DeepSpeed's optimized inference kernels via replace_transformer_layer. Supports both injection-dict-based replacement (user specifies which modules to replace) and automatic replacement (replace_method="auto").
  • Checkpoint loading -- Loads model weights from directories (with tag-based versioning), JSON-specified checkpoint paths, or state dict files. Supports sharded checkpoints for tensor-parallel loading, with optional INT8 quantization during load. Handles both module and model state dict keys.
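  • CUDA graph capture -- When enabled, the first forward call runs warmup iterations and then captures the forward pass into a CUDA graph, which is replayed on every subsequent call. Replay reuses the captured input and output buffers, so later calls must pass inputs of the same shape; in exchange, it eliminates CPU-side kernel launch overhead.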
  • Dtype conversion -- Converts the model to the configured dtype (float16, bfloat16, float32, or int8) after loading.
  • MoE support -- Detects Mixture of Experts layers and creates expert-parallel and expert-model-parallel groups for distributed expert execution.
  • Model profiling -- Optional timing of forward passes using CUDA events or wall-clock time, accessible via model_times().
  • HuggingFace compatibility -- Includes a workaround for BLOOM that replaces the model's _prepare_attn_mask preprocessing with a pass-through returning the attention mask unchanged, and passes the HF model config to DSPolicy for use by injection policies.

This is AUTO_KEEP vendored code from DeepSpeed.
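
The tensor-parallelism bullet above describes distributing the model across GPUs. One common way to split a single layer is Megatron-style column parallelism: each rank owns a contiguous slice of a linear layer's output features, computes its partial output, and an all-gather concatenates the slices. The pure-Python sketch below simulates that in one process. The helper names are hypothetical and this is not DeepSpeed's slicing code; it only shows why the sharded computation reproduces the unsharded result.

```python
from typing import List

Matrix = List[List[float]]  # row-major: weight[out_feature][in_feature]


def matvec(weight: Matrix, x: List[float]) -> List[float]:
    """y = W x for a row-major weight matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def shard_output_features(weight: Matrix, tp_size: int, rank: int) -> Matrix:
    """Return the contiguous block of output rows owned by `rank`."""
    out_features = len(weight)
    if out_features % tp_size != 0:
        raise ValueError("out_features must be divisible by tp_size")
    per_rank = out_features // tp_size
    return weight[rank * per_rank:(rank + 1) * per_rank]


def tensor_parallel_matvec(weight: Matrix, x: List[float], tp_size: int) -> List[float]:
    """Simulate tp_size ranks: each computes its slice, then the slices are concatenated."""
    partial_outputs = [matvec(shard_output_features(weight, tp_size, r), x) for r in range(tp_size)]
    gathered: List[float] = []
    for part in partial_outputs:  # stands in for an all-gather across the TP group
        gathered.extend(part)
    return gathered


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # 4 output features, 2 input features
x = [1.0, -1.0]
y_tp = tensor_parallel_matvec(W, x, tp_size=2)  # same values as matvec(W, x)
```

Because each output feature depends only on its own weight row, no rank needs another rank's weights during the multiply; communication is needed only to assemble the full output.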
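
Kernel injection, as described above, comes down to walking the module tree and swapping matching submodules for optimized ones. The sketch below shows that recursive replacement pattern in plain Python. It assumes a toy `Module` class with a `children` dict and a hypothetical policy mapping original layer classes to replacement factories; `HFBlock` and `FusedBlock` are illustrative names, and this is not `replace_transformer_layer` itself.

```python
from typing import Callable, Dict, Type


class Module:
    """Toy stand-in for torch.nn.Module: a named tree of child modules."""

    def __init__(self, **children: "Module") -> None:
        self.children: Dict[str, "Module"] = dict(children)


class HFBlock(Module):  # hypothetical original transformer layer
    pass


class FusedBlock(Module):  # hypothetical optimized replacement layer
    def __init__(self, original: Module) -> None:
        super().__init__(**original.children)  # stand-in for copying the original's weights
        self.replaced_from = type(original).__name__


Policy = Dict[Type[Module], Callable[[Module], Module]]


def inject(module: Module, policy: Policy) -> int:
    """Replace children whose exact type is in `policy`; recurse into the rest. Returns the count."""
    replaced = 0
    for name, child in list(module.children.items()):
        factory = policy.get(type(child))
        if factory is not None:
            module.children[name] = factory(child)  # replaced modules are not descended into
            replaced += 1
        else:
            replaced += inject(child, policy)
    return replaced


model = Module(embed=Module(), layers=Module(l0=HFBlock(), l1=HFBlock()), head=Module())
count = inject(model, {HFBlock: FusedBlock})
```

Matching on the exact class (rather than `isinstance`) keeps the replacement from accidentally capturing subclasses that a policy was not written for; an injection dict keyed by class follows the same idea.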
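
The dtype-conversion and checkpoint bullets mention int8. A generic symmetric, per-group int8 scheme is sketched below: each group gets one scale chosen so its largest magnitude maps to 127, and dequantization multiplies back by that scale. DeepSpeed's own quantizer may use a different scale formula and grouping, so treat these details as an assumption; the point is the quantize/dequantize round trip and its error bound.

```python
from typing import List, Tuple


def quantize_int8(values: List[float], groups: int) -> Tuple[List[int], List[float]]:
    """Symmetric per-group quantization: q = round(v / scale), scale = max|v| / 127 per group."""
    if len(values) % groups != 0:
        raise ValueError("len(values) must be divisible by groups")
    size = len(values) // groups
    q: List[int] = []
    scales: List[float] = []
    for g in range(groups):
        chunk = values[g * size:(g + 1) * size]
        max_abs = max(abs(v) for v in chunk)
        scale = max_abs / 127 if max_abs > 0 else 1.0  # avoid division by zero for all-zero groups
        scales.append(scale)
        q.extend(max(-128, min(127, round(v / scale))) for v in chunk)
    return q, scales


def dequantize_int8(q: List[int], scales: List[float]) -> List[float]:
    """Multiply each int8 value by the scale of the group it belongs to."""
    size = len(q) // len(scales)
    return [qi * scales[i // size] for i, qi in enumerate(q)]


weights = [0.5, -1.0, 0.25, 0.0, 10.0, -3.0, 7.5, 2.0]
q, scales = quantize_int8(weights, groups=2)
restored = dequantize_int8(q, scales)
# Each element's error is at most half a quantization step of its group.
```

Grouping matters because one large outlier only coarsens the scale of its own group; here the 10.0 in the second group does not reduce the precision of the small values in the first.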
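
For MoE models, the bullet above says the engine creates expert-parallel and expert-model-parallel groups. The sketch below only computes the rank lists, assuming a common grid layout: each expert-parallel group is a contiguous block of `ep_size` ranks, and each expert-model-parallel group collects the ranks at the same offset in every block. The function name is hypothetical; in a real setup each rank list would be passed to `torch.distributed.new_group`.

```python
from typing import List, Tuple


def moe_groups(world_size: int, ep_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (expert_parallel_groups, expert_model_parallel_groups) as rank lists."""
    if world_size % ep_size != 0:
        raise ValueError("world_size must be divisible by ep_size")
    num_blocks = world_size // ep_size
    # Contiguous blocks: ranks that split the set of experts among themselves.
    ep_groups = [list(range(b * ep_size, (b + 1) * ep_size)) for b in range(num_blocks)]
    # Strided groups: ranks at the same offset in each block, i.e. holding the same experts.
    emp_groups = [[offset + b * ep_size for b in range(num_blocks)] for offset in range(ep_size)]
    return ep_groups, emp_groups


ep, emp = moe_groups(world_size=8, ep_size=4)
# ep  -> [[0, 1, 2, 3], [4, 5, 6, 7]]
# emp -> [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Every rank belongs to exactly one group of each kind, so the two families together tile the world as a `num_blocks x ep_size` grid.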
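
The BLOOM workaround in the HuggingFace-compatibility bullet is a monkeypatch: the wrapped model's `transformer._prepare_attn_mask` is replaced by a function that returns the attention mask untouched. A minimal sketch follows. `FakeBloomTransformer`, `FakeBloomForCausalLM`, and `remove_mask_prepare` are hypothetical stand-ins, not the Hugging Face or DeepSpeed classes.

```python
class FakeBloomTransformer:
    """Hypothetical stand-in for a BLOOM transformer with mask preprocessing."""

    def _prepare_attn_mask(self, attention_mask, input_shape, past_key_values_length):
        # Placeholder for the real mask expansion done on the Hugging Face side.
        return ("expanded", attention_mask)


class FakeBloomForCausalLM:
    def __init__(self):
        self.transformer = FakeBloomTransformer()


def remove_mask_prepare(model) -> bool:
    """Patch _prepare_attn_mask into a pass-through if present; return True if patched."""
    transformer = getattr(model, "transformer", None)
    if transformer is not None and hasattr(transformer, "_prepare_attn_mask"):
        # The instance attribute shadows the class method; extra arguments are ignored.
        transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
        return True
    return False


model = FakeBloomForCausalLM()
patched = remove_mask_prepare(model)
```

Patching the instance rather than the class limits the change to the one wrapped model, and the `hasattr` checks make the patch a no-op for architectures that lack the method.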

Code Reference

Field          Value
Repository     FlexLLMGen
File           benchmark/third_party/DeepSpeed/deepspeed/inference/engine.py
Lines          1-531

Key Class:

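```python
from torch.nn.modules import Module


class InferenceEngine(Module):
    # Class-level process-group attributes
    inference_mp_group = None  # tensor (model) parallel group
    inference_ep_group = None  # MoE expert-parallel group
    expert_mp_group = None     # MoE expert-model-parallel group

    def __init__(self, model, config): ...  # config: DeepSpeedInferenceConfig
    def forward(self, *inputs, **kwargs): ...  # CUDA graph replay or direct call
    def _apply_injection_policy(self, config, client_module=None): ...  # kernel injection
    def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): ...
    def _create_cuda_graph(self, *inputs, **kwargs): ...  # warmup, then capture
    def _graph_replay(self, *inputs, **kwargs): ...  # replay the captured graph
    def profile_model_time(self, use_cuda_events=True): ...  # enable forward timing
    def model_times(self): ...  # return recorded forward times
    def load_model_with_checkpoint(self, r_module): ...
```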
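
`_load_checkpoint(load_dir, load_module_strict=True, tag=None)` accepts a directory with tag-based versioning and checkpoints whose weights sit under either a `module` or a `model` key. The sketch below covers only path resolution and key selection, without `torch.load`. It assumes a DeepSpeed-style layout that is not confirmed by this page: an optional `latest` file in `load_dir` whose contents name the tag, and per-rank files at `<load_dir>/<tag>/mp_rank_XX_model_states.pt`. Both helper names are hypothetical.

```python
import os
import tempfile
from typing import Dict, Optional


def resolve_checkpoint_path(load_dir: str, mp_rank: int = 0, tag: Optional[str] = None) -> str:
    """Return the per-rank checkpoint path under an ASSUMED <load_dir>/<tag>/ layout."""
    if tag is None:
        latest = os.path.join(load_dir, "latest")  # assumed: file whose contents name the tag
        if not os.path.isfile(latest):
            raise FileNotFoundError(f"no tag given and no 'latest' file in {load_dir}")
        with open(latest) as fd:
            tag = fd.read().strip()
    return os.path.join(load_dir, tag, f"mp_rank_{mp_rank:02d}_model_states.pt")


def choose_module_key(state: Dict[str, object]) -> str:
    """Pick the key holding model weights; exactly one of 'module' / 'model' must be present."""
    has_module, has_model = "module" in state, "model" in state
    if has_module and has_model:
        raise KeyError("checkpoint has both 'module' and 'model' keys")
    if not (has_module or has_model):
        raise KeyError("checkpoint has neither 'module' nor 'model' key")
    return "module" if has_module else "model"


# Usage: a throwaway directory whose 'latest' file points at the tag "global_step100".
with tempfile.TemporaryDirectory() as ckpt_dir:
    with open(os.path.join(ckpt_dir, "latest"), "w") as fd:
        fd.write("global_step100\n")
    relative = os.path.relpath(resolve_checkpoint_path(ckpt_dir, mp_rank=1), ckpt_dir)
    explicit = os.path.relpath(resolve_checkpoint_path(ckpt_dir, tag="manual"), ckpt_dir)
```

Refusing a checkpoint that has both keys, instead of silently preferring one, avoids loading the wrong set of weights when a file mixes formats.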
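
`profile_model_time` and `model_times` follow a record-and-read pattern: once profiling is on, each forward call appends its duration to a list. The sketch below reproduces that with wall-clock timing only (`time.perf_counter`). CUDA kernels run asynchronously, so wall-clock timing on a GPU would need a device synchronize before reading the clock, which is one reason to prefer CUDA events there. The class name is hypothetical, and clearing the list on each read is an assumption about the engine's behaviour.

```python
import time
from typing import Any, Callable, List


class ProfiledModel:
    """Hypothetical wrapper that times each forward call once profiling is enabled."""

    def __init__(self, forward_fn: Callable[..., Any]) -> None:
        self._forward_fn = forward_fn
        self._profile_enabled = False
        self._times_ms: List[float] = []

    def profile_model_time(self) -> None:
        self._profile_enabled = True

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        if not self._profile_enabled:
            return self._forward_fn(*args, **kwargs)
        start = time.perf_counter()
        out = self._forward_fn(*args, **kwargs)
        self._times_ms.append((time.perf_counter() - start) * 1000.0)
        return out

    def model_times(self) -> List[float]:
        """Return recorded times in milliseconds and clear the buffer (assumed behaviour)."""
        if not self._profile_enabled:
            raise RuntimeError("model profiling is not enabled")
        times, self._times_ms = self._times_ms, []
        return times


m = ProfiledModel(lambda x: x * 2)
m.profile_model_time()
outputs = [m.forward(i) for i in range(3)]
first = m.model_times()   # three timings
second = m.model_times()  # empty: the previous read cleared the buffer
```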

I/O Contract

Inputs

Parameter   Type                       Required   Description
model       torch.nn.Module            Yes        The PyTorch model to optimize for inference
config      DeepSpeedInferenceConfig   Yes        Configuration specifying TP size, dtype, checkpoint path, injection policy, CUDA graph, etc.

Outputs (forward)

Output    Type     Description
outputs   varies   The model's forward pass output, produced by CUDA graph replay when CUDA graphs are enabled, otherwise by calling the wrapped model directly
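
The CUDA-graph path of `forward` can be pictured with a pure-Python analogy: the first call warms up and "captures" the computation against fixed input and output buffers, and every later call copies its inputs into those buffers and replays. No CUDA graph is created here; the class name and the default of three warmup iterations are assumptions for illustration. With PyTorch, the corresponding pieces are `torch.cuda.CUDAGraph`, the `torch.cuda.graph(...)` capture context, and `CUDAGraph.replay()`.

```python
from typing import Callable, List, Optional


class GraphReplayAnalogy:
    """Capture-once / replay-many dispatch, mimicking CUDA-graph static buffers in plain Python."""

    def __init__(self, fn: Callable[[List[float]], List[float]], warmup_iters: int = 3) -> None:
        self.fn = fn
        self.warmup_iters = warmup_iters
        self.graph_created = False
        self.static_input: Optional[List[float]] = None
        self.static_output: Optional[List[float]] = None

    def _create_graph(self, x: List[float]) -> None:
        self.static_input = list(x)  # fixed-size input buffer
        for _ in range(self.warmup_iters):  # warmup runs before capture
            self.fn(self.static_input)
        self.static_output = list(self.fn(self.static_input))  # fixed-size output buffer
        self.graph_created = True

    def _replay(self, x: List[float]) -> List[float]:
        assert self.static_input is not None and self.static_output is not None
        if len(x) != len(self.static_input):
            raise ValueError("replay requires the captured input shape")
        self.static_input[:] = x  # copy new data into the static input buffer
        self.static_output[:] = self.fn(self.static_input)  # write into the same output buffer
        return self.static_output

    def forward(self, x: List[float]) -> List[float]:
        if not self.graph_created:
            self._create_graph(x)
        return self._replay(x)


engine = GraphReplayAnalogy(lambda v: [2 * e for e in v])
first = engine.forward([1.0, 2.0])
first_copy = list(first)  # copy before the next call overwrites the buffer
second = engine.forward([3.0, 4.0])  # same list object as `first`, now holding new values
```

The aliasing in the last line is the practical consequence of static buffers: a caller that wants to keep an output across calls has to copy it first.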

Initialization Sequence

  1. Extract model config and set HF model config on DSPolicy.
  2. Set up tensor parallelism (MPU or auto TP group creation).
  3. Detect and set up MoE expert parallel groups if present.
  4. Load checkpoint if specified and not using kernel injection.
  5. Convert model to target dtype.
  6. Apply injection policy (replace transformer layers with optimized kernels).
  7. Move model to CUDA device.
  8. Synchronize RNG states across TP ranks.
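
Step 8 keeps random sampling identical on every tensor-parallel rank by giving all ranks the same RNG state. The sketch below simulates a broadcast from rank 0 in one process, using one `random.Random` object per simulated rank and its `getstate`/`setstate` methods. A real deployment would instead broadcast the device RNG state over the process group (for example with `torch.distributed.broadcast`); the helper name is hypothetical.

```python
import random
from typing import List


def broadcast_rng_state(rank_rngs: List[random.Random], src: int = 0) -> None:
    """Simulated broadcast: copy the source rank's generator state to every rank."""
    state = rank_rngs[src].getstate()
    for rng in rank_rngs:
        rng.setstate(state)


tp_size = 4
# Each simulated rank starts from a different seed, so their random streams diverge.
ranks = [random.Random(seed) for seed in range(tp_size)]
before = [rng.random() for rng in ranks]
broadcast_rng_state(ranks, src=0)
after = [rng.random() for rng in ranks]  # identical on every simulated rank
```

Without this step, ranks holding different shards of the same layer could, for example, sample different tokens and drift apart.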
