Implementation:FMInference FlexLLMGen DeepSpeed Transformer CUDA
| Knowledge Sources | |
|---|---|
| Domains | CUDA, Transformer, Deep Learning Training |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
A CUDA C++ implementation of DeepSpeed's fused BERT transformer layer with complete forward and backward pass logic, exposed to Python via pybind11/PyTorch C++ extensions.
Description
This file implements BertTransformerLayer<T>, a class template (instantiated for float and __half) that implements an entire transformer encoder layer as a single fused unit. The layer includes:
- QKV linear projection via a single fused GEMM (3 * hidden_size output) using cuBLAS.
- Multi-head self-attention with bias-add, reshape to (batch, heads, seq, head_dim), strided batched GEMM for attention scores, softmax with masking, attention probability dropout, context computation via another strided batched GEMM, and reshape back.
- Attention output projection with dropout and residual addition.
- Layer normalization in either a pre-LN or post-LN configuration, with an optional variant that saves the mean for use in the backward pass.
- Feed-forward network with two linear layers, GELU activation, and optional GELU checkpointing.
- Output dropout and residual addition.
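The stages listed above can be followed by tracking tensor shapes alone. The snippet below is a minimal, illustrative sketch in plain Python, not code from the source file: the function name layer_shapes, the stage labels, and the sequence length of 128 are assumptions chosen for the example, while the remaining sizes match the usage example on this page.

```python
def layer_shapes(batch, seq, hidden, heads, intermediate):
    """Return the tensor shape produced by each stage of one encoder layer."""
    if hidden % heads != 0:
        raise ValueError("hidden size must be divisible by the number of heads")
    head_dim = hidden // heads
    return {
        "input": (batch, seq, hidden),
        # One fused GEMM produces Q, K and V side by side.
        "qkv_projection": (batch, seq, 3 * hidden),
        # After bias-add and reshape, each of Q, K, V has this layout.
        "q_k_v_each": (batch, heads, seq, head_dim),
        # Strided batched GEMM of Q against K gives one score matrix per head.
        "attention_scores": (batch, heads, seq, seq),
        # Second strided batched GEMM: attention probabilities times V.
        "context": (batch, heads, seq, head_dim),
        # Heads are merged back before the output projection.
        "context_merged": (batch, seq, hidden),
        "attention_output": (batch, seq, hidden),
        # First feed-forward linear layer, followed by GELU.
        "ffn_intermediate": (batch, seq, intermediate),
        "output": (batch, seq, hidden),
    }


# Illustrative sizes; seq=128 is an assumption, not a value from the source.
shapes = layer_shapes(batch=8, seq=128, hidden=1024, heads=16, intermediate=4096)
for stage, shape in shapes.items():
    print(f"{stage:18s} {shape}")
```

The attention score tensor grows with the square of the sequence length, which is one reason the attention probabilities are a natural candidate for checkpointing.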
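Created layer instances are stored in a static unordered map (s_transformer_layers) keyed by layer ID, so later forward and backward calls can look up a layer by the ID it was created with. The helper function templates create_transformer_layer<T>, ds_transformer_forward<T>, and ds_transformer_backward<T> are the entry points bound via PYBIND11_MODULE; their FP16 instantiations appear in Python under names such as create_transformer_layer_fp16 and forward_fp16.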
The implementation supports four optional modes: stochastic mode, which skips stream synchronization for speed; normalize_invertible mode, which shares buffers; and attention dropout checkpointing and GELU checkpointing, which trade extra compute for lower memory use.
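The compute-for-memory trade behind GELU checkpointing can be illustrated with a small pure-Python sketch. It assumes the tanh approximation of GELU and uses hypothetical names (gelu, ffn_activation_forward, activation_for_backward); it shows the general idea of recomputing an activation instead of storing it, not the kernels in the source file.

```python
import math


def gelu(x):
    """GELU via the tanh approximation (an assumption made for this sketch)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def ffn_activation_forward(pre_act, checkpoint):
    """Apply GELU and decide what to keep for the backward pass."""
    act = [gelu(v) for v in pre_act]
    saved = {"pre_act": list(pre_act)}
    if not checkpoint:
        # Without checkpointing, the activation output is stored as well.
        saved["act"] = act
    return act, saved


def activation_for_backward(saved):
    """Return the GELU output needed by backward, recomputing it if absent."""
    if "act" in saved:
        return saved["act"]
    return [gelu(v) for v in saved["pre_act"]]


pre_activation = [-1.0, 0.0, 0.5, 2.0]
_, saved_full = ffn_activation_forward(pre_activation, checkpoint=False)
_, saved_ckpt = ffn_activation_forward(pre_activation, checkpoint=True)
```

Both paths hand the same values to the backward pass; the checkpointed path stores one buffer fewer and pays for it with a second GELU evaluation.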
Usage
This module is compiled as a PyTorch C++ extension and called from the DeepSpeed Python runtime to execute fused transformer layers during training, in either FP32 or FP16 precision.
Code Reference
Source Location
- Repository: FMInference_FlexLLMGen
- File: benchmark/third_party/DeepSpeed/csrc/transformer/ds_transformer_cuda.cpp
- Lines: 1-1049
Signature
template <typename T>
class BertTransformerLayer {
public:
    BertTransformerLayer(unsigned layer_id, unsigned batch_size, unsigned hidden_size,
                         unsigned num_heads, unsigned intermediate_size, unsigned seq_length,
                         float attn_prob_dropout_ratio, float hidden_output_dropout_ratio,
                         float layer_norm_eps, bool pre_or_postLayerNorm,
                         const std::vector<std::array<int, 3>>& gemm_algos,
                         bool attn_dropout_checkpoint, bool normalize_invertible,
                         bool gelu_checkpoint, bool stochastic_mode);

    void Forward(unsigned bsz, const T* input_ptr, const T* input_mask_ptr, ...);
    void Backward(unsigned bsz, const T* grad_output_ptr, ...);
};

// PyTorch-facing functions
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id, ...);

template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id, ...);
Import
# From Python, loaded as a PyTorch C++ extension:
import deepspeed_transformer_cuda
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| layer_id | unsigned | Yes | Unique identifier for the transformer layer instance. |
| input | torch::Tensor | Yes | Input tensor of shape (batch_size, seq_length, hidden_size). |
| input_mask | torch::Tensor | Yes | Attention mask tensor. |
| attn_qkvw, attn_qkvb | torch::Tensor | Yes | Fused QKV weight and bias parameters. |
| attn_ow, attn_ob | torch::Tensor | Yes | Attention output projection weight and bias. |
| attn_nw, attn_nb | torch::Tensor | Yes | Attention layer norm weight and bias. |
| inter_w, inter_b | torch::Tensor | Yes | Feed-forward intermediate weight and bias. |
| output_w, output_b | torch::Tensor | Yes | Feed-forward output weight and bias. |
| norm_w, norm_b | torch::Tensor | Yes | Final layer norm weight and bias. |
| training_mode | bool | Yes | Whether to apply dropout. |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch::Tensor | Transformer layer output of shape (batch_size, seq_length, hidden_size). |
| intermediate_tensors | std::vector<torch::Tensor> | Saved tensors for backward pass (17 tensors total including dropout masks, layer norm stats, and intermediate activations). |
Usage Examples
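# Creating and running a fused transformer layer from Python
import deepspeed_transformer_cuda as ds_cuda

# Create the layer. Arguments are passed positionally, on the assumption that
# the pybind11 binding does not declare keyword names for them.
ds_cuda.create_transformer_layer_fp16(
    0, 8, 1024, 16,   # layer_id, batch_size, hidden_dim, num_heads
    4096, 0.1,        # intermediate_size, attn_dropout_ratio
    0.1, 1e-5, 42,    # hidden_dropout_ratio, layer_norm_eps, seed
    True, False,      # pre_or_postLayerNorm, test_gemm
    False, False,     # attn_dropout_checkpoint, normalize_invertible
    False, True,      # gelu_checkpoint, stochastic_mode
)

# Forward pass. input_tensor, mask, and the twelve parameter tensors are
# assumed to be FP16 CUDA tensors created elsewhere.
outputs = ds_cuda.forward_fp16(
    0, input_tensor, mask, qkvw, qkvb, ow, ob,
    nw, nb, iw, ib, outw, outb, normw, normb,
    True,                        # training_mode
    True, False, False, False,   # presumably pre-LN and the three checkpoint/invertible
                                 # flags, mirroring the values used at creation
)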