Implementation: Deepspeedai DeepSpeed Transformer CUDA
| Knowledge Sources | |
|---|---|
| Domains | Transformer, Training, CUDA_Kernels, Neural_Networks |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Complete BERT-style transformer layer implementation in CUDA with fused operations, supporting both training and inference with configurable layer normalization modes.
Description
The BertTransformerLayer class implements a full transformer encoder layer on top of highly optimized CUDA kernels for training. It orchestrates multi-head self-attention, the feed-forward network, layer normalization, dropout, and residual connections through a sequence of fused kernel launches. Both pre-layer-normalization (normalize before attention/FFN) and post-layer-normalization (normalize after the residual addition) configurations are supported. Scratch memory comes from a shared workspace that the caller allocates and registers through TrainingContext::Instance().SetWorkSpace. Optional activation checkpointing reduces the memory footprint by storing fewer of the intermediate values needed for the backward pass and recomputing the rest. The class is composed of specialized sub-modules (FeedForward, Normalize_Layer, Softmax, Gelu, Dropout, StridedBatchGemm). Its Forward method chains these sub-modules to evaluate the complete layer computation graph, while saving the activations that gradient computation requires.
Usage
Use this class when implementing high-performance transformer training in CUDA. It is designed to outperform the equivalent native PyTorch operations through kernel fusion and workspace reuse. The gains come from fewer kernel launches and less memory traffic, so they are largest in workloads, including large models, where memory bandwidth and kernel-launch overhead dominate runtime.
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/transformer/ds_transformer_cuda.cpp
Signature
// Fused BERT-style encoder layer (declarations only; the definitions live in
// csrc/transformer/ds_transformer_cuda.cpp). T is the element type used for
// activations and parameters; the examples on this page use __half.
template <typename T>
class BertTransformerLayer {
public:
    // Configures one layer.
    //   layer_id              position of the layer in the stack (the
    //                         multi-layer example passes its loop index)
    //   batch_size, hidden_size, num_heads, intermediate_size, seq_length
    //                         layer dimensions
    //   attn_prob_dropout_ratio, hidden_output_dropout_ratio
    //                         dropout ratios for attention probabilities and
    //                         for hidden outputs
    //   layer_norm_eps        LayerNorm epsilon
    //   pre_or_postLayerNorm  true selects pre-LayerNorm (as in the examples),
    //                         false selects post-LayerNorm
    //   gemm_algos            per-GEMM algorithm ids; the examples supply five
    //                         entries (QKV, FF1, FF2, attention scores,
    //                         attention context) -- confirm the expected order
    //   attn_dropout_checkpoint, normalize_invertible, gelu_checkpoint
    //                         memory-saving switches (see the activation
    //                         checkpointing example)
    //   stochastic_mode       NOTE(review): semantics not shown on this page
    BertTransformerLayer(unsigned layer_id,
                         unsigned batch_size,
                         unsigned hidden_size,
                         unsigned num_heads,
                         unsigned intermediate_size,
                         unsigned seq_length,
                         float attn_prob_dropout_ratio,
                         float hidden_output_dropout_ratio,
                         float layer_norm_eps,
                         bool pre_or_postLayerNorm,
                         const std::vector<std::array<int, 3>>& gemm_algos,
                         bool attn_dropout_checkpoint,
                         bool normalize_invertible,
                         bool gelu_checkpoint,
                         bool stochastic_mode);

    // Evaluates the layer for `bsz` sequences.
    // The const T* arguments are read-only inputs: activations, attention
    // mask, and the parameters of the QKV projection, attention output
    // projection, attention LayerNorm, FF1, FF2 and final LayerNorm.
    // The T* arguments are outputs: out_ptr receives the layer output and the
    // rest receive intermediate activations kept for backprop.
    void Forward(unsigned bsz,
                 // --- inputs ---
                 const T* input_ptr,
                 const T* input_mask_ptr,
                 // --- attention parameters ---
                 const T* attn_qkvw_ptr,
                 const T* attn_qkvb_ptr,
                 const T* attn_ow_ptr,
                 const T* attn_ob_ptr,
                 const T* attn_nw_ptr,
                 const T* attn_nb_ptr,
                 // --- feed-forward and final norm parameters ---
                 const T* inter_w_ptr,
                 const T* inter_b_ptr,
                 const T* output_w_ptr,
                 const T* output_b_ptr,
                 const T* norm_w_ptr,
                 const T* norm_b_ptr,
                 // --- outputs: layer result, then saved activations ---
                 T* out_ptr,
                 T* inp_norm_ptr,
                 T* q_tf_ptr,
                 T* k_tf_ptr,
                 T* v_tf_ptr,
                 T* soft_out_ptr,
                 T* ctx_bufB_ptr,
                 T* attn_o_inp_ptr,
                 T* add_res_ptr,
                 T* ff1_inp_ptr,
                 T* gelu_inp_ptr,
                 T* ff2_inp_ptr);
};
Import
#include "ds_transformer_cuda.h"
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| input_ptr | const T* | Input activations [batch×seq×hidden] |
| input_mask_ptr | const T* | Attention mask [batch×1×seq×seq] |
| attn_qkvw_ptr | const T* | QKV projection weights [hidden×3×hidden] |
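| attn_qkvb_ptr | const T* | QKV projection biases [3×hidden] |
| attn_ow_ptr | const T* | Attention output projection weights [hidden×hidden] |
| attn_ob_ptr | const T* | Attention output projection biases [hidden] |
| attn_nw_ptr | const T* | Attention layer norm weights [hidden] |
| attn_nb_ptr | const T* | Attention layer norm biases [hidden] |
| inter_w_ptr | const T* | FF intermediate weights [hidden×inter] |
| inter_b_ptr | const T* | FF intermediate biases [inter] |
| output_w_ptr | const T* | FF output weights [inter×hidden] |
| output_b_ptr | const T* | FF output biases [hidden] |
| norm_w_ptr | const T* | Final layer norm weights [hidden] |
| norm_b_ptr | const T* | Final layer norm biases [hidden] |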
| Output | Type | Description |
|---|---|---|
| out_ptr | T* | Layer output activations [batch×seq×hidden] |
| Intermediate buffers | T* | Various intermediate activations for backprop |
Usage Examples
Basic Transformer Layer Setup:
#include <memory>
#include <stdexcept>
#include <string>

#include "ds_transformer_cuda.h"
// Configuration (BERT-base sized layer)
unsigned layer_id = 0;
unsigned batch_size = 32;
unsigned seq_length = 512;
unsigned hidden_size = 768;
unsigned num_heads = 12;
unsigned intermediate_size = 3072;
float dropout = 0.1f;         // used for both attention-prob and hidden-output dropout
float layer_norm_eps = 1e-5f;
bool pre_layer_norm = true;   // true: pre-LayerNorm, false: post-LayerNorm

// One {int, int, int} entry per GEMM. 99 is the numeric value of
// CUBLAS_GEMM_DEFAULT_TENSOR_OP (cuBLAS heuristic selection); replace these
// with the algorithm IDs produced by tuning.
std::vector<std::array<int, 3>> gemm_algos = {
    {99, 99, 99},  // QKV projection
    {99, 99, 99},  // FF1
    {99, 99, 99},  // FF2
    {99, 99, 99},  // Attention scores
    {99, 99, 99}   // Attention context
};

// Create layer
BertTransformerLayer<__half> layer(
    layer_id, batch_size, hidden_size, num_heads,
    intermediate_size, seq_length,
    /*attn_prob_dropout_ratio=*/dropout,
    /*hidden_output_dropout_ratio=*/dropout,
    layer_norm_eps,
    /*pre_or_postLayerNorm=*/pre_layer_norm,
    gemm_algos,
    /*attn_dropout_checkpoint=*/true,
    /*normalize_invertible=*/false,
    /*gelu_checkpoint=*/false,
    /*stochastic_mode=*/false);
Forward Pass Execution:
// Allocate buffers (device memory). Every pointer starts as nullptr so that a
// buffer which was never allocated is a detectable null rather than an
// indeterminate value.
__half* input = nullptr;    // layer input  [batch x seq x hidden]
__half* mask = nullptr;     // attention mask
__half* output = nullptr;   // layer output [batch x seq x hidden]

// Parameters (local name -> Forward parameter)
__half* qkv_w = nullptr;    // attn_qkvw_ptr
__half* qkv_b = nullptr;    // attn_qkvb_ptr
__half* attn_w = nullptr;   // attn_ow_ptr
__half* attn_b = nullptr;   // attn_ob_ptr
__half* ff_w1 = nullptr;    // inter_w_ptr
__half* ff_b1 = nullptr;    // inter_b_ptr
__half* ff_w2 = nullptr;    // output_w_ptr
__half* ff_b2 = nullptr;    // output_b_ptr
__half* norm1_w = nullptr;  // attn_nw_ptr
__half* norm1_b = nullptr;  // attn_nb_ptr
__half* norm2_w = nullptr;  // norm_w_ptr
__half* norm2_b = nullptr;  // norm_b_ptr

// Intermediate buffers written by Forward and kept for backprop
__half* inp_norm = nullptr;  // inp_norm_ptr
__half* q_tf = nullptr;      // q_tf_ptr
__half* k_tf = nullptr;      // k_tf_ptr
__half* v_tf = nullptr;      // v_tf_ptr
__half* soft_out = nullptr;  // soft_out_ptr
__half* ctx_buf = nullptr;   // ctx_bufB_ptr
__half* attn_out = nullptr;  // attn_o_inp_ptr
__half* add_res = nullptr;   // add_res_ptr
__half* ff1_inp = nullptr;   // ff1_inp_ptr
__half* gelu_inp = nullptr;  // gelu_inp_ptr
__half* ff2_inp = nullptr;   // ff2_inp_ptr

// ... allocate and load weights ...

// Forward computation
layer.Forward(batch_size,
              input, mask,
              qkv_w, qkv_b,
              attn_w, attn_b,
              norm1_w, norm1_b,
              ff_w1, ff_b1,
              ff_w2, ff_b2,
              norm2_w, norm2_b,
              output, inp_norm,
              q_tf, k_tf, v_tf,
              soft_out, ctx_buf,
              attn_out, add_res,
              ff1_inp, gelu_inp, ff2_inp);
// NOTE(review): if Forward only enqueues GPU work, synchronize the stream it
// uses before reading `output` on the host.
Multi-Layer Transformer:
// Minimal multi-layer wrapper. It owns one shared training workspace and a
// stack of BertTransformerLayer<__half> layers built with identical shapes.
// Sketch only: weight loading and the per-layer intermediate activation
// buffers that Forward also requires are elided.
class TransformerModel {
    // Frees the cudaMalloc'd workspace when its owner is destroyed.
    struct CudaFreeDeleter {
        void operator()(__half* p) const noexcept { cudaFree(p); }
    };

    // Declared before `layers` so that the layers are destroyed first and the
    // workspace last. unique_ptr also makes the model move-only, so the
    // buffer can never be freed twice.
    std::unique_ptr<__half, CudaFreeDeleter> workspace;
    size_t workspace_size = 0;  // bytes
    std::vector<std::shared_ptr<BertTransformerLayer<__half>>> layers;

public:
    // Forward consumes 12 parameter pointers per layer, in this order:
    // QKV w/b, attention-output w/b, attention-norm w/b,
    // FF1 (intermediate) w/b, FF2 (output) w/b, final-norm w/b.
    static constexpr size_t kWeightsPerLayer = 12;

    // Throws std::runtime_error if the workspace cannot be allocated.
    TransformerModel(int num_layers, int batch, int seq_len,
                     int hidden, int heads, int intermediate) {
        // DeepSpeed's get_workspace_size<T> returns a number of T elements,
        // not bytes (compute_workspace on this page multiplies by sizeof for
        // the same reason), so convert before calling cudaMalloc.
        workspace_size =
            static_cast<size_t>(get_workspace_size<__half>(
                batch, seq_len, hidden, intermediate, heads,
                /*training=*/true, /*gelu_checkpoint=*/false)) *
            sizeof(__half);

        __half* raw = nullptr;
        const cudaError_t err = cudaMalloc(&raw, workspace_size);
        if (err != cudaSuccess) {
            throw std::runtime_error(
                std::string("cudaMalloc(workspace) failed: ") +
                cudaGetErrorString(err));
        }
        workspace.reset(raw);
        // NOTE(review): SetWorkSpace appears to keep only the raw pointer; do
        // not run these layers after this model has been destroyed.
        TrainingContext::Instance().SetWorkSpace(workspace.get());

        // Every layer has the same shapes, so tune the GEMM algorithms once.
        const auto gemm_algos =
            tune_layer_algos(batch, seq_len, hidden, intermediate, heads);

        for (int i = 0; i < num_layers; ++i) {
            layers.push_back(std::make_shared<BertTransformerLayer<__half>>(
                i, batch, hidden, heads, intermediate, seq_len,
                /*attn_prob_dropout_ratio=*/0.1f,
                /*hidden_output_dropout_ratio=*/0.1f,
                /*layer_norm_eps=*/1e-5f,
                /*pre_or_postLayerNorm=*/true,
                gemm_algos,
                /*attn_dropout_checkpoint=*/true,
                /*normalize_invertible=*/false,
                /*gelu_checkpoint=*/false,
                /*stochastic_mode=*/false));
        }
    }

    // Runs every layer in sequence. `weights` must hold kWeightsPerLayer
    // pointers per layer, laid out layer by layer. Throws
    // std::invalid_argument if it is too short.
    void forward(__half* input, __half* mask, __half* output,
                 std::vector<__half*>& weights, int batch) {
        const size_t needed = layers.size() * kWeightsPerLayer;
        if (weights.size() < needed) {
            throw std::invalid_argument(
                "TransformerModel::forward: expected " + std::to_string(needed) +
                " weight pointers, got " + std::to_string(weights.size()));
        }

        const __half* layer_input = input;
        for (size_t i = 0; i < layers.size(); ++i) {
            __half* const* w = weights.data() + i * kWeightsPerLayer;
            // NOTE(review): from the second layer on, input_ptr and out_ptr
            // are the same buffer. Whether Forward supports in-place use, and
            // whether backward still finds each layer's input, is not shown
            // here. For training, give every layer its own output buffer.
            layers[i]->Forward(batch, layer_input, mask,
                               w[0], w[1],    // QKV weight / bias
                               w[2], w[3],    // attention output weight / bias
                               w[4], w[5],    // attention LayerNorm weight / bias
                               w[6], w[7],    // FF1 (intermediate) weight / bias
                               w[8], w[9],    // FF2 (output) weight / bias
                               w[10], w[11],  // final LayerNorm weight / bias
                               output
                               /* , ...per-layer intermediate activation buffers... */);
            layer_input = output;
        }
    }
};
With Activation Checkpointing:
// Activation checkpointing: each switch below trades extra work in the
// backward pass for less stored activation memory.
constexpr bool kAttnDropoutCheckpoint = true;  // checkpoint attention-probability dropout instead of storing its output
constexpr bool kNormalizeInvertible = true;    // invertible LayerNorm: the LN input is not kept
constexpr bool kGeluCheckpoint = true;         // checkpoint GELU (recompute its output from the saved input)
constexpr bool kStochasticMode = false;

BertTransformerLayer<__half> memory_efficient_layer(
    layer_id, batch_size, hidden_size, num_heads,
    intermediate_size, seq_length,
    dropout, dropout, layer_norm_eps,
    /*pre_or_postLayerNorm=*/true,  // pre-LayerNorm
    gemm_algos,
    kAttnDropoutCheckpoint,
    kNormalizeInvertible,
    kGeluCheckpoint,
    kStochasticMode);

// Lower memory use, somewhat slower. The rough estimate for this setup is
// ~2-3x activation-memory savings at a ~10-15% speed cost; these figures are
// unverified here, so profile with your own shapes.
Workspace Management:
// Returns the workspace size in BYTES for an fp16 (__half) transformer layer.
// Element-count formula (B=batch, S=seq, H=hidden, I=intermediate, N=heads):
//     4*B*S*H                   base activations
//   + 2*B*S*H                   training extras            (if training)
//   + max(B*S*I, 2*B*N*S*S)     intermediate or attention  (if training)
//   + 2*B*S*I                   GELU checkpoint            (if training && gelu_checkpoint)
// All arithmetic is done in size_t. In 32-bit int, 2*B*N*S*S already
// overflows for e.g. B=32, N=16, S=2048. All dimensions must be non-negative.
// The defaults (training = true, gelu_checkpoint = true) reproduce the
// original unconditional sum.
size_t compute_workspace(int batch, int seq, int hidden,
                         int intermediate, int heads,
                         bool training = true,
                         bool gelu_checkpoint = true) {
    const size_t b = static_cast<size_t>(batch);
    const size_t s = static_cast<size_t>(seq);
    const size_t hid = static_cast<size_t>(hidden);
    const size_t inter = static_cast<size_t>(intermediate);
    const size_t nh = static_cast<size_t>(heads);
    const size_t tokens = b * s;

    // Base activations
    size_t elems = 4 * tokens * hid;
    if (training) {
        // Training extras
        elems += 2 * tokens * hid;
        // Max of intermediate or attention
        elems += std::max(tokens * inter, 2 * tokens * nh * s);
        // Optional GELU checkpoint
        if (gelu_checkpoint) {
            elems += 2 * tokens * inter;
        }
    }
    return elems * sizeof(__half);
}
Related Pages
- Custom CUDA Layers - Underlying kernel implementations
- Normalize Layer - Layer normalization component
- Inference PT Binding - Inference-optimized variant