Jump to content

Connect SuperML | Leeroopedia MCP: Equip your AI agents with best practices, code verification, and debugging knowledge. Powered by Leeroo — building Organizational Superintelligence. Contact us at founders@leeroo.com.

Implementation:Deepspeedai DeepSpeed Transformer CUDA

From Leeroopedia
Revision as of 14:47, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Deepspeedai_DeepSpeed_Transformer_CUDA.md)
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)


Knowledge Sources
Domains Transformer, Training, CUDA_Kernels, Neural_Networks
Last Updated 2026-02-09 00:00 GMT

Overview

Complete BERT-style transformer layer implementation in CUDA with fused operations, supporting both training and inference with configurable layer normalization modes.

Description

The BertTransformerLayer class implements a full transformer encoder layer with highly optimized CUDA kernels for training. It orchestrates multi-head self-attention, feed-forward networks, layer normalization, dropout, and residual connections through a sequence of fused kernel launches. The implementation supports both pre-layer normalization (normalize before attention/FFN) and post-layer normalization (normalize after residual addition). Memory management is handled through a centralized TrainingContext that provides the workspace allocation. Activation checkpointing is supported to reduce the memory footprint by storing only selected intermediate values needed for the backward pass. The class is composed of specialized sub-modules (FeedForward, Normalize_Layer, Softmax, Gelu, Dropout, StridedBatchGemm), coordinated through a Forward method that implements the complete layer computation graph while maintaining correctness for gradient computation.

Usage

Use this class when implementing high-performance transformer training in CUDA. It provides significantly better performance than PyTorch native operations through kernel fusion and workspace reuse, particularly beneficial for large models where memory bandwidth and kernel launch overhead dominate.

Code Reference

Source Location

Signature

// One BERT-style transformer encoder layer. The whole forward computation of
// the layer is driven by a single call to Forward().
// T is the element type behind every weight/activation pointer; the usage
// examples on this page instantiate it with __half.
template <typename T>
class BertTransformerLayer {
public:
    // Configures one layer.
    //   layer_id                     index of this layer (examples pass 0, or the loop index)
    //   batch_size, seq_length       batch and sequence dimensions of the activations
    //   hidden_size, num_heads       model width and number of attention heads
    //   intermediate_size            width of the feed-forward intermediate projection
    //   attn_prob_dropout_ratio      dropout ratio for attention probabilities
    //   hidden_output_dropout_ratio  dropout ratio for hidden outputs
    //   layer_norm_eps               epsilon used by layer normalization
    //   pre_or_postLayerNorm         selects the layer-norm mode; the examples pass
    //                                true to request pre-layer normalization
    //   gemm_algos                   one {int, int, int} algorithm triple per GEMM
    //   attn_dropout_checkpoint,
    //   normalize_invertible,
    //   gelu_checkpoint              activation-checkpointing switches
    //   stochastic_mode              not described on this page — TODO document
    BertTransformerLayer(unsigned layer_id, unsigned batch_size,
                        unsigned hidden_size, unsigned num_heads,
                        unsigned intermediate_size, unsigned seq_length,
                        float attn_prob_dropout_ratio,
                        float hidden_output_dropout_ratio,
                        float layer_norm_eps,
                        bool pre_or_postLayerNorm,
                        const std::vector<std::array<int, 3>>& gemm_algos,
                        bool attn_dropout_checkpoint,
                        bool normalize_invertible,
                        bool gelu_checkpoint,
                        bool stochastic_mode);

    // Runs the layer's forward pass for `bsz` sequences.
    // Read-only inputs (const T*):
    //   input_ptr                    input activations [batch x seq x hidden]
    //   input_mask_ptr               attention mask [batch x 1 x seq x seq]
    //   attn_qkvw_ptr/attn_qkvb_ptr  QKV projection weights / biases
    //   attn_ow_ptr/attn_ob_ptr      attention output projection weights / biases
    //   attn_nw_ptr/attn_nb_ptr      layer-norm weights / biases (the example passes
    //                                its norm1_* buffers here)
    //   inter_w_ptr/inter_b_ptr      feed-forward intermediate weights / biases
    //   output_w_ptr/output_b_ptr    feed-forward output weights / biases
    //   norm_w_ptr/norm_b_ptr        final layer-norm weights / biases
    // Outputs (T*):
    //   out_ptr                      layer output activations [batch x seq x hidden]
    //   inp_norm_ptr ... ff2_inp_ptr intermediate activations kept for backprop
    void Forward(unsigned bsz, const T* input_ptr,
                const T* input_mask_ptr,
                const T* attn_qkvw_ptr, const T* attn_qkvb_ptr,
                const T* attn_ow_ptr, const T* attn_ob_ptr,
                const T* attn_nw_ptr, const T* attn_nb_ptr,
                const T* inter_w_ptr, const T* inter_b_ptr,
                const T* output_w_ptr, const T* output_b_ptr,
                const T* norm_w_ptr, const T* norm_b_ptr,
                T* out_ptr, T* inp_norm_ptr,
                T* q_tf_ptr, T* k_tf_ptr, T* v_tf_ptr,
                T* soft_out_ptr, T* ctx_bufB_ptr,
                T* attn_o_inp_ptr, T* add_res_ptr,
                T* ff1_inp_ptr, T* gelu_inp_ptr, T* ff2_inp_ptr);
};

Import

#include "ds_transformer_cuda.h"

I/O Contract

Parameter Type Description
input_ptr const T* Input activations [batch×seq×hidden]
input_mask_ptr const T* Attention mask [batch×1×seq×seq]
attn_qkvw_ptr const T* QKV projection weights [hidden×3×hidden]
attn_qkvb_ptr const T* QKV projection biases [3×hidden]
inter_w_ptr const T* FF intermediate weights [hidden×inter]
output_w_ptr const T* FF output weights [inter×hidden]
norm_w_ptr const T* Final layer norm weights [hidden]
Output Type Description
out_ptr T* Layer output activations [batch×seq×hidden]
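Intermediate buffers (inp_norm_ptr, q_tf_ptr, k_tf_ptr, v_tf_ptr, soft_out_ptr, ctx_bufB_ptr, attn_o_inp_ptr, add_res_ptr, ff1_inp_ptr, gelu_inp_ptr, ff2_inp_ptr) T* Intermediate activations produced by the forward pass and kept for backprop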

Usage Examples

Basic Transformer Layer Setup:

#include "ds_transformer_cuda.h"

// Example layer geometry and hyper-parameters.
unsigned layer_id{0};
unsigned batch_size{32};
unsigned seq_length{512};
unsigned hidden_size{768};
unsigned num_heads{12};
unsigned intermediate_size{3072};
float dropout{0.1f};
float layer_norm_eps{1e-5f};
bool pre_layer_norm{true};

// GEMM algorithm IDs from tuning: five identical {99, 99, 99} triples, one
// per GEMM, in this order:
//   [0] QKV projection, [1] FF1, [2] FF2,
//   [3] Attention scores, [4] Attention context
std::vector<std::array<int, 3>> gemm_algos(5, std::array<int, 3>{99, 99, 99});

// Create layer
// The same `dropout` value is used for both the attention-probability and the
// hidden-output dropout ratios. The four trailing booleans are, in constructor
// order: attn_dropout_checkpoint = true, normalize_invertible = false,
// gelu_checkpoint = false, stochastic_mode = false.
BertTransformerLayer<__half> layer(
    layer_id, batch_size, hidden_size, num_heads,
    intermediate_size, seq_length,
    dropout, dropout, layer_norm_eps, pre_layer_norm,
    gemm_algos, true, false, false, false);

Forward Pass Execution:

// Allocate buffers
// NOTE: every pointer below is only declared here; each one must be given
// valid storage in the "allocate and load weights" step before Forward() runs.
__half *input, *mask, *output;
__half *qkv_w, *qkv_b, *attn_w, *attn_b;
__half *ff_w1, *ff_b1, *ff_w2, *ff_b2;
__half *norm1_w, *norm1_b, *norm2_w, *norm2_b;

// Intermediate buffers
// Filled by Forward() with the intermediate activations kept for backprop.
__half *inp_norm, *q_tf, *k_tf, *v_tf;
__half *soft_out, *ctx_buf, *attn_out;
__half *add_res, *ff1_inp, *gelu_inp, *ff2_inp;

// ... allocate and load weights ...

// Forward computation
// Arguments are positional; the trailing comments name the Forward() parameter
// each line binds to.
layer.Forward(batch_size,
             input, mask,              // input_ptr, input_mask_ptr
             qkv_w, qkv_b,             // attn_qkvw_ptr, attn_qkvb_ptr
             attn_w, attn_b,           // attn_ow_ptr, attn_ob_ptr
             norm1_w, norm1_b,         // attn_nw_ptr, attn_nb_ptr
             ff_w1, ff_b1,             // inter_w_ptr, inter_b_ptr
             ff_w2, ff_b2,             // output_w_ptr, output_b_ptr
             norm2_w, norm2_b,         // norm_w_ptr, norm_b_ptr
             output, inp_norm,         // out_ptr, inp_norm_ptr
             q_tf, k_tf, v_tf,         // q_tf_ptr, k_tf_ptr, v_tf_ptr
             soft_out, ctx_buf,        // soft_out_ptr, ctx_bufB_ptr
             attn_out, add_res,        // attn_o_inp_ptr, add_res_ptr
             ff1_inp, gelu_inp, ff2_inp);  // ff1_inp_ptr, gelu_inp_ptr, ff2_inp_ptr

Multi-Layer Transformer:

class TransformerModel {
    std::vector<std::shared_ptr<BertTransformerLayer<__half>>> layers;
    __half* workspace;
    size_t workspace_size;

public:
    TransformerModel(int num_layers, int batch, int seq_len,
                    int hidden, int heads, int intermediate) {
        // Compute workspace size
        workspace_size = get_workspace_size<__half>(
            batch, seq_len, hidden, intermediate, heads, true, false);
        cudaMalloc(&workspace, workspace_size);
        TrainingContext::Instance().SetWorkSpace(workspace);

        // Create layers
        for (int i = 0; i < num_layers; i++) {
            auto gemm_algos = tune_layer_algos(batch, seq_len, hidden,
                                              intermediate, heads);
            layers.push_back(std::make_shared<BertTransformerLayer<__half>>(
                i, batch, hidden, heads, intermediate, seq_len,
                0.1, 0.1, 1e-5, true, gemm_algos,
                true, false, false, false));
        }
    }

    void forward(__half* input, __half* mask, __half* output,
                std::vector<__half*>& weights, int batch) {
        __half* layer_input = input;
        __half* layer_output = output;

        for (int i = 0; i < layers.size(); i++) {
            layers[i]->Forward(batch, layer_input, mask,
                             weights[i * 6 + 0], weights[i * 6 + 1],  // QKV
                             weights[i * 6 + 2], weights[i * 6 + 3],  // Attn out
                             // ... other weights ...
                             layer_output, /* intermediate buffers */);
            layer_input = layer_output;
        }
    }
};

With Activation Checkpointing:

// Enable checkpointing for memory efficiency
// Same construction as the basic example, but with the three checkpointing
// switches (attn_dropout_checkpoint, normalize_invertible, gelu_checkpoint)
// turned on.
BertTransformerLayer<__half> memory_efficient_layer(
    layer_id, batch_size, hidden_size, num_heads,
    intermediate_size, seq_length,
    dropout, dropout, layer_norm_eps,
    true,  // pre_layer_norm
    gemm_algos,
    true,  // attn_dropout_checkpoint - save dropout mask
    true,  // normalize_invertible - recompute activations
    true,  // gelu_checkpoint - save GELU input
    false  // stochastic_mode
);

// Reduced memory usage but slightly slower
// Saves: ~2-3× activation memory at cost of ~10-15% speed
// NOTE(review): these figures are indicative only and are not backed by a
// measurement on this page — benchmark on the target model before relying on them.

Workspace Management:

// Estimates the workspace size, in bytes, for one layer configuration stored
// as __half.
//
// The element count is the sum of four terms (base activations, training
// extras, the larger of the intermediate/attention buffers, and the GELU
// checkpoint buffer), multiplied by sizeof(__half).
//
// All arithmetic is carried out in size_t and the result is returned as
// size_t: the previous version multiplied in `int` (overflow for large
// shapes) and narrowed the byte count to `unsigned` (wraps at 4 GiB).
// Callers that store the result in an `unsigned` still compile.
//
// The dimensions are expected to be non-negative.
size_t compute_workspace(int batch, int seq, int hidden,
                         int intermediate, int heads) {
    // Widen once so every product below is evaluated in size_t.
    const size_t b = static_cast<size_t>(batch);
    const size_t s = static_cast<size_t>(seq);
    const size_t h = static_cast<size_t>(hidden);
    const size_t inter = static_cast<size_t>(intermediate);
    const size_t nh = static_cast<size_t>(heads);

    const size_t tokens = b * s;

    // Base activations
    const size_t base = 4 * tokens * h;

    // Training extras
    const size_t training_extra = 2 * tokens * h;

    // Max of intermediate or attention
    const size_t compute = std::max(tokens * inter, 2 * b * nh * s * s);

    // Optional GELU checkpoint
    // (always included here, as in the original estimate)
    const size_t gelu = 2 * tokens * inter;

    return (base + training_extra + compute + gelu) * sizeof(__half);
}

Related Pages

Page Connections

Double-click a node to navigate. Hold to expand connections.
Principle
Implementation
Heuristic
Environment