Implementation:FMInference FlexLLMGen DeepSpeed Inference Engine
| Field | Value |
|---|---|
| Sources | Repo: FlexLLMGen |
| Domains | Inference, Model_Parallelism, Performance_Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Vendored DeepSpeed inference engine that wraps a PyTorch model with optimized inference capabilities, including tensor parallelism, kernel injection, CUDA graph support, checkpoint loading, and model profiling.
Description
InferenceEngine is the main entry point for DeepSpeed inference. It extends torch.nn.Module and wraps a user model, applying the inference optimizations listed below. The engine is initialized with a DeepSpeedInferenceConfig that controls tensor parallelism, kernel injection, dtype conversion, checkpoint loading, and CUDA graph usage.
Key capabilities:
- Tensor parallelism -- Creates model-parallel process groups and distributes the model across GPUs. Supports both an explicit MPU (model parallel unit) and automatic TP group creation. Synchronizes RNG states across TP ranks so that random operations produce the same results on every rank.
- Kernel injection -- Replaces standard transformer layers with DeepSpeed's optimized inference kernels via replace_transformer_layer. Supports both injection-dict-based replacement (user specifies which modules to replace) and automatic replacement (replace_method="auto").
- Checkpoint loading -- Loads model weights from directories (with tag-based versioning), JSON-specified checkpoint paths, or state dict files. Supports sharded checkpoints for tensor-parallel loading, with optional INT8 quantization during load. Handles both module and model state dict keys.
- CUDA graph capture -- When enabled, captures the forward pass into a CUDA graph after warmup iterations, then replays it for subsequent calls. This eliminates CPU-side kernel launch overhead for static computation graphs.
- Dtype conversion -- Converts the model to the configured dtype (float16, bfloat16, float32, or int8) after loading.
- MoE support -- Detects Mixture of Experts layers and creates expert-parallel and expert-model-parallel groups for distributed expert execution.
- Model profiling -- Optional timing of forward passes using CUDA events or wall-clock time, accessible via model_times().
- HuggingFace compatibility -- Includes a workaround that removes BLOOM's _prepare_attn_mask preprocessing, and passes the HF model config through to the injection policies.
This file is vendored from DeepSpeed and carries the AUTO_KEEP label.
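The CUDA graph capability described above follows a capture-once, replay-afterwards control flow. The sketch below illustrates only that control flow, using a plain Python callable as a stand-in for a real CUDA graph so it runs on any machine. The class name `CaptureThenReplay`, the `events` trace, and the default of three warmup iterations are assumptions made for illustration, not the vendored implementation.

```python
from typing import Any, Callable, List, Optional


class CaptureThenReplay:
    """Stand-in for the capture-once, replay-afterwards pattern.

    No real CUDA graph is created: `captured` is simply the wrapped
    callable. Only the dispatch logic is being illustrated.
    """

    def __init__(self, fn: Callable[..., Any], enable_graph: bool,
                 warmup_iters: int = 3):  # warmup count is a hypothetical default
        self.fn = fn
        self.enable_graph = enable_graph
        self.warmup_iters = warmup_iters
        self.captured: Optional[Callable[..., Any]] = None
        self.events: List[str] = []  # trace of what happened, for inspection

    def _capture(self, *args: Any, **kwargs: Any) -> None:
        # Run eagerly a few times so lazy initialization happens before capture.
        for _ in range(self.warmup_iters):
            self.fn(*args, **kwargs)
            self.events.append("warmup")
        # A real engine would record one forward pass into a graph here.
        self.captured = self.fn
        self.events.append("capture")

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if not self.enable_graph:
            self.events.append("eager")
            return self.fn(*args, **kwargs)
        if self.captured is None:
            self._capture(*args, **kwargs)
        self.events.append("replay")
        return self.captured(*args, **kwargs)


runner = CaptureThenReplay(lambda x: x * 2, enable_graph=True)
first = runner(3)   # triggers warmup, capture, then the first replay
second = runner(4)  # replays only
```

A real CUDA graph replays work on fixed buffers, so each call has to copy its inputs into the tensors used during capture, and the input shapes must not change between calls.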
Code Reference
| Field | Value |
|---|---|
| Repository | FlexLLMGen |
| File | benchmark/third_party/DeepSpeed/deepspeed/inference/engine.py |
| Lines | 1-531 |
Key Class:
class InferenceEngine(Module):
    inference_mp_group = None
    inference_ep_group = None
    expert_mp_group = None
    def __init__(self, model, config): ...
    def forward(self, *inputs, **kwargs): ...
    def _apply_injection_policy(self, config, client_module=None): ...
    def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): ...
    def _create_cuda_graph(self, *inputs, **kwargs): ...
    def _graph_replay(self, *inputs, **kwargs): ...
    def profile_model_time(self, use_cuda_events=True): ...
    def model_times(self): ...
    def load_model_with_checkpoint(self, r_module): ...
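The profile_model_time() and model_times() pair can be pictured as a timing wrapper around the forward call. The following is a minimal wall-clock sketch of that idea. The class `TimedForward`, the millisecond unit, and the choice to clear the buffer when timings are read are assumptions for illustration and may differ from the vendored code.

```python
import time
from typing import Any, Callable, List


class TimedForward:
    """Hypothetical helper: records wall-clock time for each wrapped call."""

    def __init__(self, fn: Callable[..., Any]):
        self.fn = fn
        self.enabled = False
        self._times: List[float] = []

    def profile_model_time(self) -> None:
        # Turn timing on; calls made before this are not measured.
        self.enabled = True

    def model_times(self) -> List[float]:
        # Hand back the collected timings and start a fresh buffer (assumption).
        times, self._times = self._times, []
        return times

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if not self.enabled:
            return self.fn(*args, **kwargs)
        start = time.perf_counter()
        out = self.fn(*args, **kwargs)
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        self._times.append(elapsed_ms)
        return out


timed = TimedForward(lambda x: x + 1)
timed.profile_model_time()
timed(1)
timed(2)
collected = timed.model_times()  # two entries, in milliseconds
```

On a GPU, kernel launches are asynchronous, so wall-clock timing is only meaningful if the device is synchronized before each clock read. CUDA events avoid that by timestamping on the device itself.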
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| model | torch.nn.Module | Yes | The PyTorch model to optimize for inference |
| config | DeepSpeedInferenceConfig | Yes | Configuration specifying TP size, dtype, checkpoint path, injection policy, CUDA graph, etc. |
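As a rough illustration of how these two inputs are usually supplied, the sketch below builds a plain options dict and passes it to `deepspeed.init_inference`. The key names (`tensor_parallel`, `tp_size`, `dtype`, `replace_with_kernel_inject`, `enable_cuda_graph`) are assumptions based on upstream DeepSpeed's config schema and should be checked against the vendored copy. The helpers `build_inference_config` and `wrap_for_inference` are hypothetical.

```python
from typing import Any, Dict


def build_inference_config(tp_size: int = 1, dtype: str = "fp16",
                           kernel_inject: bool = True,
                           cuda_graph: bool = False) -> Dict[str, Any]:
    """Assemble inference options as a plain dict.

    Key names are assumed from upstream DeepSpeed and are not verified
    against the vendored version in this repository.
    """
    if tp_size < 1:
        raise ValueError("tp_size must be >= 1")
    return {
        "tensor_parallel": {"tp_size": tp_size},
        "dtype": dtype,
        "replace_with_kernel_inject": kernel_inject,
        "enable_cuda_graph": cuda_graph,
    }


def wrap_for_inference(model: Any, config: Dict[str, Any]) -> Any:
    """Wrap `model` with the inference engine (needs DeepSpeed and a GPU)."""
    import deepspeed  # imported lazily so the dict helper works without it
    return deepspeed.init_inference(model, config=config)


options = build_inference_config(tp_size=2, cuda_graph=False)
```

Keeping the options in a dict makes them easy to log or load from JSON before the engine is constructed. `wrap_for_inference` is not called here because it requires an installed DeepSpeed and CUDA devices.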
Outputs (forward)
| Output | Type | Description |
|---|---|---|
| outputs | varies | The model's forward pass output, executed via CUDA graph replay or direct call |
Initialization Sequence
- Extract model config and set HF model config on DSPolicy.
- Set up tensor parallelism (MPU or auto TP group creation).
- Detect and set up MoE expert parallel groups if present.
- Load checkpoint if specified and not using kernel injection.
- Convert model to target dtype.
- Apply injection policy (replace transformer layers with optimized kernels).
- Move model to CUDA device.
- Synchronize RNG states across TP ranks.
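The ordering above, including the condition on checkpoint loading, can be sketched as a small planning function. This only mirrors the list and does not call into DeepSpeed. The function name `plan_init_steps` and the step labels are hypothetical.

```python
from typing import List, Optional


def plan_init_steps(has_mpu: bool, has_moe: bool,
                    checkpoint: Optional[str],
                    kernel_inject: bool) -> List[str]:
    """Return the initialization steps, in order, as descriptive labels."""
    steps = ["set_hf_model_config"]
    # Tensor parallelism: reuse the caller's MPU or create a TP group.
    steps.append("use_mpu_group" if has_mpu else "create_tp_group")
    if has_moe:
        steps.append("create_expert_parallel_groups")
    # Checkpoint is loaded at this point only when kernel injection is off.
    if checkpoint is not None and not kernel_inject:
        steps.append("load_checkpoint")
    steps.append("convert_dtype")
    steps.append("apply_injection_policy")
    steps.append("move_to_cuda")
    steps.append("sync_rng_across_tp_ranks")
    return steps


plain = plan_init_steps(has_mpu=False, has_moe=False,
                        checkpoint="ckpt_dir", kernel_inject=False)
injected = plan_init_steps(has_mpu=True, has_moe=True,
                           checkpoint="ckpt_dir", kernel_inject=True)
```

With kernel injection enabled, the plan omits the standalone checkpoint-loading step, matching the condition stated in the list.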