

Implementation:FMInference FlexLLMGen DeepSpeed Inference PT Binding



Knowledge Sources
Domains CUDA, PyTorch, Deep Learning Inference, C++ Bindings
Last Updated 2026-02-09 12:00 GMT

Overview

A comprehensive PyTorch C++ extension binding file that exposes over 40 DeepSpeed inference CUDA kernels to Python, covering softmax, attention, GEMM, activation functions, layer normalization, quantization, and residual operations.

Description

This file serves as the bridge layer between DeepSpeed's optimized CUDA inference kernels and PyTorch's Python-facing tensor API. It implements template functions (parameterized on float or __half) that:

  • Accept at::Tensor inputs from PyTorch.
  • Extract raw data pointers and dimension information.
  • Invoke the appropriate CUDA kernel launchers or cuBLAS GEMM wrappers.
  • Return results as at::Tensor outputs, often using at::from_blob to wrap pre-allocated workspace memory as tensors without copying.
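The zero-copy step can be illustrated on the CPU with Python's `memoryview`. This is a loose analogy only: `at::from_blob` wraps an existing (device) pointer in a tensor that does not own the memory, much as a `memoryview` slice exposes part of a buffer without copying it. The workspace size and the `carve_view` helper are hypothetical.

```python
from array import array

# Hypothetical workspace: 16 float32 slots allocated once up front.
# (In the real binding the workspace lives in GPU memory.)
workspace = array("f", [0.0] * 16)


def carve_view(buf: array, offset: int, length: int) -> memoryview:
    """Return a view of buf[offset:offset+length] that shares memory (no copy).

    Hypothetical helper playing the role at::from_blob plays for a raw
    workspace pointer in the C++ binding.
    """
    if offset < 0 or offset + length > len(buf):
        raise ValueError("view exceeds workspace bounds")
    return memoryview(buf)[offset:offset + length]


# Two "output tensors" carved from disjoint regions of the same workspace.
out_a = carve_view(workspace, 0, 4)
out_b = carve_view(workspace, 4, 4)

# Writing through a view mutates the underlying workspace directly.
out_a[0] = 1.5
out_b[3] = -2.0
```

Because nothing is copied, results written by one kernel into its view are immediately visible to the next consumer of that workspace region.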

Key function groups include:

  • Softmax and attention: ds_softmax<T> (standalone softmax with mask and ALiBi support), ds_softmax_context<T> (fused attention with KV cache, rotary position embeddings, and autoregressive generation support), and ds_softmax_context1<T> (simpler attention without KV cache).
  • GEMM operations: ds_qkv_gemm<T> (fused layer norm + QKV projection), ds_mlp_gemm<T> (fused residual-LN + MLP with GELU/ReLU), ds_vector_matmul<T> (bias-free projection GEMM over all tokens, optionally INT8-quantized), ds_linear_layer<T> (with optional flash attention reshape), and fused_gemm_gelu<T> (GEMM + GELU fusion).
  • Activations: ds_bias_gelu<T>, ds_bias_relu<T>, ds_bias_geglu (for Stable Diffusion).
  • Layer normalization: ds_layer_norm, ds_layer_norm_residual, ds_layer_norm_residual_store.
  • Quantization support: quantized_gemm<T> with dequantize-then-GEMM pattern for INT8 weights.
  • Residual operations: residual_add_bias<T> with GPT-J style variant.
  • Rotary position embeddings: apply_rotary_pos_emb supporting both rotate-half and rotate-every-two modes.
  • Workspace management: allocate_workspace<T> pre-allocates GPU memory for the inference context.
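For the softmax and attention group, the following plain-Python sketch shows the math a masked, ALiBi-biased softmax computes for one head's score matrix: an additive attention mask, an optional ALiBi distance penalty and an optional causal (`triangular`) mask, followed by a numerically stable row-wise softmax. It assumes the common conventions (mask values are added to the scores; ALiBi subtracts `slope * (query_pos - key_pos)`). It does not model the kernel's `recompute`, `local_attention` or `layer_scale` handling, and `masked_softmax` is a hypothetical name.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def masked_softmax(scores: Matrix, mask: Optional[Matrix] = None,
                   alibi_slope: float = 0.0, triangular: bool = False) -> Matrix:
    """Row-wise softmax over one head's (q_len x k_len) score matrix.

    Assumption: additive mask and the standard ALiBi linear distance penalty.
    """
    q_len, k_len = len(scores), len(scores[0])
    # Queries are aligned with the end of the key sequence (KV-cache style).
    offset = k_len - q_len
    out: Matrix = []
    for i in range(q_len):
        qpos = i + offset
        row = []
        for j in range(k_len):
            if triangular and j > qpos:
                row.append(-math.inf)  # causal mask: no attending to future keys
                continue
            s = scores[i][j]
            if mask is not None:
                s += mask[i][j]  # 0 keeps a position, a large negative value drops it
            s -= alibi_slope * (qpos - j)  # ALiBi penalty grows with distance
            row.append(s)
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out
```

For example, `masked_softmax([[0.0, 0.0], [0.0, 0.0]], triangular=True)` yields `[[1.0, 0.0], [0.5, 0.5]]`: the first query can only see the first key.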
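The activation bindings fuse the bias add into the activation itself. A scalar reference of `gelu(x + bias)` and `relu(x + bias)` follows; it uses the tanh approximation of GELU, which is common in fused transformer kernels, but whether it matches this kernel's exact GELU form is an assumption. The function names are hypothetical.

```python
import math
from typing import List


def gelu_tanh(x: float) -> float:
    """tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).

    Assumption: this approximation, not the erf form, is what the kernel uses.
    """
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def bias_gelu(activations: List[List[float]], bias: List[float]) -> List[List[float]]:
    """Add a per-column bias to every row and apply GELU in the same pass."""
    return [[gelu_tanh(v + b) for v, b in zip(row, bias)] for row in activations]


def bias_relu(activations: List[List[float]], bias: List[float]) -> List[List[float]]:
    """Same fusion with ReLU instead of GELU."""
    return [[max(0.0, v + b) for v, b in zip(row, bias)] for row in activations]
```

Fusing the two steps means each element is read and written once, instead of materializing `x + bias` as a separate tensor.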
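The quantization path follows a dequantize-then-GEMM pattern. The sketch below illustrates it with symmetric INT8 quantization using one scale per stored weight row (`scale = max|w| / 127`). The per-row grouping, the rounding mode and all function names are assumptions for illustration; the binding's actual `q_scale` layout may group weights differently.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_rows(weight: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric INT8 quantization.

    Assumption: one scale per stored weight row.
    """
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def dequantize_rows(q_rows: List[List[int]], scales: List[float]) -> Matrix:
    """Map INT8 values back to floating point with their row's scale."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (m x k) @ (k x n) matrix product."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def quantized_gemm(x: Matrix, q_weight: List[List[int]], scales: List[float]) -> Matrix:
    """Dequantize the INT8 weight first, then run an ordinary GEMM."""
    return matmul(x, dequantize_rows(q_weight, scales))
```

Each dequantized weight differs from the original by at most half of its row's scale, so the GEMM result stays close to the full-precision product while the weights occupy one byte each in memory.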
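The two rotary modes differ only in how dimensions are paired: rotate-half (the GPT-NeoX layout) pairs dimension `i` with `i + rotary_dim/2`, while rotate-every-two (the GPT-J layout) pairs adjacent dimensions `2i` and `2i+1`. The sketch rotates one head vector at one position; the function name and the inverse frequencies `base**(-2i/rotary_dim)` with `base = 10000` are assumptions, and dimensions beyond `rotary_dim` pass through unchanged.

```python
import math
from typing import List


def apply_rotary(x: List[float], position: int, rotary_dim: int,
                 rotate_half: bool, base: float = 10000.0) -> List[float]:
    """Rotate the first rotary_dim entries of one head vector.

    Assumption: inverse frequencies base**(-2i/rotary_dim) with base=10000.
    """
    assert rotary_dim % 2 == 0 and rotary_dim <= len(x)
    half = rotary_dim // 2
    out = list(x)  # entries beyond rotary_dim stay unchanged
    for i in range(half):
        theta = position * base ** (-2.0 * i / rotary_dim)
        c, s = math.cos(theta), math.sin(theta)
        if rotate_half:
            a, b = i, i + half          # pair i with i + rotary_dim/2
        else:
            a, b = 2 * i, 2 * i + 1     # pair adjacent dimensions
        # 2-D rotation of the pair (x[a], x[b]) by angle theta.
        out[a] = x[a] * c - x[b] * s
        out[b] = x[b] * c + x[a] * s
    return out
```

Since both modes apply the same per-pair rotation, they agree up to a permutation of dimensions; the mode flag must simply match the layout the model's weights were trained with.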

The file distinguishes GPT-type from BERT-type models by the dimensionality of the attention mask: a mask with more than two dimensions indicates a GPT-style autoregressive model, and a two-dimensional mask indicates a BERT-style bidirectional model.
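A minimal sketch of that dispatch rule, written against a mask shape given as a tuple. `infer_model_type` is a hypothetical helper, not a function exported by the binding, and the example shapes in the comments are typical layouts assumed for illustration.

```python
from typing import Sequence


def infer_model_type(mask_shape: Sequence[int]) -> str:
    """Classify a model from its attention-mask rank (hypothetical helper)."""
    ndim = len(mask_shape)
    if ndim > 2:
        return "gpt"   # assumed typical shape: (batch, 1, q_len, k_len), autoregressive
    if ndim == 2:
        return "bert"  # assumed typical shape: (batch, seq_len), bidirectional padding mask
    raise ValueError(f"unexpected attention-mask rank {ndim}")
```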

All functions are registered via PYBIND11_MODULE with separate FP32 and FP16 entry points (e.g., softmax_fp32, softmax_fp16).
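On the Python side, a caller selects the FP32 or FP16 entry point by name suffix. A minimal sketch, assuming the loaded extension is any object exposing attributes such as `softmax_fp16` and `softmax_fp32`; `get_kernel`, the dtype strings and the stand-in object are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Callable

# Hypothetical mapping from a dtype name to the binding-name suffix.
_SUFFIX = {"float32": "fp32", "float16": "fp16"}


def get_kernel(module: Any, op: str, dtype: str) -> Callable:
    """Return e.g. module.softmax_fp16 for op="softmax", dtype="float16"."""
    if dtype not in _SUFFIX:
        raise TypeError(f"no FP32/FP16 entry point for dtype {dtype!r}")
    name = f"{op}_{_SUFFIX[dtype]}"
    if not hasattr(module, name):
        raise AttributeError(f"extension has no binding named {name!r}")
    return getattr(module, name)


# Usage with a stand-in object in place of the compiled extension.
fake_ext = SimpleNamespace(softmax_fp16=lambda *args: "fp16 path",
                           softmax_fp32=lambda *args: "fp32 path")
kernel = get_kernel(fake_ext, "softmax", "float16")
```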

Usage

This module is compiled as a PyTorch C++ extension and loaded by DeepSpeed's inference engine to execute optimized transformer operations. It replaces sequences of native PyTorch operations with fused CUDA kernels, which reduces kernel-launch overhead and avoids writing intermediate tensors back to GPU memory, speeding up inference.

Code Reference

Source Location

Signature

// Core attention operations
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores, at::Tensor& attn_mask, at::Tensor& alibi,
                      bool triangular, bool recompute, bool local_attention,
                      int window_size, bool async_op, float layer_scale,
                      int head_offset, int mp_size);

template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value, at::Tensor& attn_mask,
                                           int rotary_dim, bool rotate_half,
                                           bool rotate_every_two, int heads,
                                           float norm_factor, ...);

// Fused GEMM operations
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input, at::Tensor& weight, ...);

template <typename T>
std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input, at::Tensor& residual, ...);

// PYBIND11 module with 40+ function bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ... }

Import

# Loaded as a PyTorch C++ extension by DeepSpeed inference engine
import deepspeed_transformer_inference

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
input | at::Tensor | Yes | Input tensor, typically of shape (batch, seq_len, hidden_dim).
weight | at::Tensor | Yes | Weight matrix for linear projections.
bias | at::Tensor | Conditional | Bias vector; required when add_bias is true.
gamma, beta | at::Tensor | Conditional | Layer normalization parameters; required for fused LN+GEMM operations.
attn_mask | at::Tensor | Conditional | Attention mask; its dimensionality determines GPT vs BERT model type.
q_scale | at::Tensor | Conditional | Quantization scale for INT8 weight dequantization.
epsilon | float | Conditional | Layer normalization epsilon value.
q_int8 | bool | No | Whether to use the INT8 quantized GEMM path.
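The "Conditional" rows amount to simple preconditions. A sketch of such a check, with `None` standing in for an absent tensor; `check_inputs` and its `fused_layer_norm` flag are hypothetical, and the binding itself may not validate its arguments this way.

```python
from typing import Any, Optional


def check_inputs(*, add_bias: bool, fused_layer_norm: bool, q_int8: bool,
                 bias: Optional[Any] = None, gamma: Optional[Any] = None,
                 beta: Optional[Any] = None, q_scale: Optional[Any] = None,
                 epsilon: Optional[float] = None) -> None:
    """Hypothetical validator: raise ValueError if a conditional input is missing."""
    if add_bias and bias is None:
        raise ValueError("bias is required when add_bias is true")
    if fused_layer_norm:
        if gamma is None or beta is None:
            raise ValueError("gamma and beta are required for fused LN+GEMM")
        if epsilon is None:
            raise ValueError("epsilon is required for fused LN+GEMM")
    if q_int8 and q_scale is None:
        raise ValueError("q_scale is required on the INT8 path")
```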

Outputs

Name | Type | Description
--- | --- | ---
output | at::Tensor | Result tensor; its shape depends on the operation.
intermediate | std::vector<at::Tensor> | Returned by composite operations (e.g., ds_qkv_gemm returns [output, inp_norm]).
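For the composite case, here is a plain-Python reference of what a fused layer-norm + QKV projection returns: the projected output together with the normalized input it was computed from. It assumes standard LayerNorm with biased variance, a weight laid out as (hidden_dim, 3 * hidden_dim) and an optional bias; `qkv_gemm_reference` is a hypothetical name, and the kernel's actual memory layout (including any transposed weights) is not modelled.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def layer_norm(row: List[float], gamma: List[float], beta: List[float],
               eps: float) -> List[float]:
    """Standard LayerNorm over one token's hidden vector."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)  # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(row, gamma, beta)]


def qkv_gemm_reference(x: Matrix, weight: Matrix, bias: Optional[List[float]],
                       gamma: List[float], beta: List[float],
                       eps: float = 1e-5) -> Tuple[Matrix, Matrix]:
    """Return (output, inp_norm) for x of shape (tokens, hidden_dim).

    Assumption: weight is (hidden_dim, 3 * hidden_dim), not transposed.
    """
    inp_norm = [layer_norm(row, gamma, beta, eps) for row in x]
    out_dim = len(weight[0])
    output = []
    for row in inp_norm:
        proj = [sum(row[k] * weight[k][j] for k in range(len(row))) for j in range(out_dim)]
        if bias is not None:
            proj = [p + b for p, b in zip(proj, bias)]
        output.append(proj)
    return output, inp_norm
```

Returning `inp_norm` alongside the projection lets later stages reuse the normalized activations without recomputing the layer norm.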

Usage Examples

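# Using the inference bindings from Python (normally invoked by DeepSpeed's inference engine).
# Arguments are passed positionally: pybind11 accepts keyword arguments only for
# parameters declared with py::arg. Comments name each position.
import deepspeed_transformer_inference as ds_inf

# attn_scores, attn_mask, alibi, input, weight, q_scale, bias, gamma and beta are
# assumed to be existing CUDA tensors of matching dtype (fp16 here).

# Allocate workspace for inference:
# hidden_dim, num_heads, prompt_length, batch_size, num_layers, mp_size
ds_inf.allocate_workspace_fp16(1024, 16, 512, 1, 24, 1)

# Run softmax with attention mask:
# attn_scores, attn_mask, alibi, triangular, recompute, local_attention,
# window_size, async_op, layer_scale, head_offset, mp_size
output = ds_inf.softmax_fp16(attn_scores, attn_mask, alibi,
                             True, False, False,
                             256, False, 1.0, 0, 1)

# Run fused QKV GEMM with layer norm; returns [output, inp_norm]:
# input, weight, q_scale, bias, gamma, beta, epsilon, add_bias,
# num_layers, external_cache, mp_size, rank, q_int8
qkv_output, norm_output = ds_inf.qkv_gemm_fp16(
    input, weight, q_scale, bias, gamma, beta,
    1e-5, True, 24, False, 1, 0, False)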
