
Implementation:Sgl project Sglang CPU Mamba FLA

From Leeroopedia


Knowledge Sources
Domains Machine Learning, CPU Kernels
Last Updated 2026-02-10 00:00 GMT

Overview

Implements the CPU kernel for the chunk-based gated delta rule used in Flash Linear Attention (FLA) models, providing an efficient chunked attention mechanism for linear recurrent architectures such as Mamba-2.

Description

This file provides the core chunk_gated_delta_rule_kernel_impl function that processes sequences in fixed-size chunks (default chunk_size=64) with inputs including query, key, value tensors, gating values, and beta values. The algorithm operates in multiple phases:

  • Optional L2 normalization of the Q and K vectors, fused into the main computation so that the kernel needs only four parallel loops in total.
  • Padding and rearranging input data into chunk-aligned buffers (q_pad, k_pad, v_pad).
  • Computing decay masks and gated attention within chunks.
  • Accumulating recurrent states across chunks with the delta rule update.
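The per-token recurrence that these chunked phases reproduce blockwise can be sketched in plain scalar C++. This is a reference under the usual gated delta rule formulation S_t = a_t·S_{t-1} + b_t·k_t·(v_t − a_t·S_{t-1}ᵀk_t)ᵀ with a_t = exp(g_t); it is a minimal single-head sketch with no chunking, padding, or SIMD, so it illustrates only the math, not the kernel's blocking:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Scalar reference for the gated delta rule recurrence. Single head;
// S is the recurrent state, [Dk x Dv], row-major.
// Per step t:  S = a_t * S;  S += beta_t * k_t (v_t - S^T k_t)^T;
//              o_t = S^T q_t,  with a_t = exp(g_t) the per-step decay.
void gated_delta_rule_ref(
    const float* q, const float* k, const float* v,  // [T x Dk], [T x Dk], [T x Dv]
    const float* g, const float* beta,               // [T] log-decay, [T] update rate
    float* out, float* S,                            // [T x Dv], [Dk x Dv] state
    int T, int Dk, int Dv) {
  std::vector<float> kv(Dv);
  for (int t = 0; t < T; ++t) {
    const float a = std::exp(g[t]);
    // Decay the state, then read the value currently stored under k_t.
    for (int i = 0; i < Dk; ++i)
      for (int j = 0; j < Dv; ++j) S[i * Dv + j] *= a;
    for (int j = 0; j < Dv; ++j) kv[j] = 0.f;
    for (int i = 0; i < Dk; ++i)
      for (int j = 0; j < Dv; ++j) kv[j] += k[t * Dk + i] * S[i * Dv + j];
    // Delta rule: move the value stored under k_t toward v_t at rate beta_t.
    for (int i = 0; i < Dk; ++i)
      for (int j = 0; j < Dv; ++j)
        S[i * Dv + j] += beta[t] * k[t * Dk + i] * (v[t * Dv + j] - kv[j]);
    // Output: query the updated state.
    for (int j = 0; j < Dv; ++j) {
      float o = 0.f;
      for (int i = 0; i < Dk; ++i) o += q[t * Dk + i] * S[i * Dv + j];
      out[t * Dv + j] = o;
    }
  }
}
```

The chunked kernel gets the same result by materializing per-chunk decay masks and carrying `S` across chunk boundaries, which trades the strictly sequential loop above for GEMM-shaped work inside each chunk.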

The kernel uses extensive buffer management via the THREAD_BUFFER_ALLOC macro for per-thread scratch space. It supports grouped query attention via the head_group parameter and variable-length sequence processing via cu_seqlens_ptr. SIMD vectorization is used throughout with at::vec::Vectorized and float32 intermediate computation for numerical stability.
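The float32-intermediate pattern is easiest to see in the optional Q/K normalization. The following is a scalar sketch of that step (the actual kernel vectorizes it with at::vec::Vectorized; the eps value here is an illustrative assumption):

```cpp
#include <cassert>
#include <cmath>

// Sketch of fused Q/K L2 normalization: accumulate the squared norm in
// float32 even for reduced-precision inputs (mirroring the kernel's float32
// intermediates), then scale the row in place.
template <typename scalar_t>
void l2norm_row(scalar_t* x, int dim, float eps = 1e-6f) {
  float sum_sq = 0.f;
  for (int i = 0; i < dim; ++i) {
    const float xi = static_cast<float>(x[i]);
    sum_sq += xi * xi;  // float32 accumulation for stability
  }
  const float inv = 1.f / std::sqrt(sum_sq + eps);
  for (int i = 0; i < dim; ++i)
    x[i] = static_cast<scalar_t>(static_cast<float>(x[i]) * inv);
}
```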

Two public API functions are exposed: fused_sigmoid_gating_delta_rule_update_cpu for the full gated delta rule with sigmoid gating, and fused_gdn_gating_cpu for computing the gating values (decay and beta) from A_log, a, b, and dt_bias parameters.
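The gating math behind fused_gdn_gating_cpu can be sketched as follows, assuming the usual Gated DeltaNet convention g = -exp(A_log) * softplus(a + dt_bias) and beta = sigmoid(b); the exact formula in this file is not shown above, so treat this as an illustration of that convention rather than the verified implementation. The softplus is linearized above the threshold, mirroring the softplus_beta / softplus_threshold parameters of the public API:

```cpp
#include <cassert>
#include <cmath>

// softplus with the standard beta/threshold parameterization:
// (1/beta) * log(1 + exp(beta * x)), falling back to x when beta*x exceeds
// the threshold so large inputs do not overflow the exp.
inline float softplus(float x, float beta = 1.f, float threshold = 20.f) {
  if (x * beta > threshold) return x;
  return std::log1p(std::exp(beta * x)) / beta;
}

// Assumed gating computation: log-decay g and update rate beta per head.
inline void gdn_gating(float A_log, float a, float b, float dt_bias,
                       float& g, float& beta_out) {
  g = -std::exp(A_log) * softplus(a + dt_bias);
  beta_out = 1.f / (1.f + std::exp(-b));  // sigmoid
}
```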

Usage

Use this kernel when performing CPU inference for models that employ Flash Linear Attention or gated delta rule mechanisms, such as Mamba-2 and other hybrid models combining attention with linear recurrence. It is invoked through the fused_sigmoid_gating_delta_rule_update_cpu function with batched Q, K, V tensors in 4D layout [seq_len, batch_size, num_heads, head_dim].

Code Reference

Source Location

Signature

// Main public API
at::Tensor fused_sigmoid_gating_delta_rule_update_cpu(
    const at::Tensor& A_log,
    const at::Tensor& dt_bias,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const at::Tensor& a,
    const at::Tensor& b,
    at::Tensor& initial_state_source,
    const at::Tensor& initial_state_indices,
    const at::Tensor& cu_seqlens,
    bool use_qk_l2norm_in_kernel,
    double softplus_beta = 1.0,
    double softplus_threshold = 20.0);

std::tuple<at::Tensor, at::Tensor>
fused_gdn_gating_cpu(
    const at::Tensor& A_log,
    const at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& dt_bias);

// Internal kernel
template <typename scalar_t, int64_t chunk_size = 64>
void chunk_gated_delta_rule_kernel_impl(
    scalar_t* __restrict__ out,
    float* __restrict__ final_state_data,
    const scalar_t* __restrict__ q_orig,
    const scalar_t* __restrict__ k_orig,
    const scalar_t* __restrict__ v_orig,
    const float* __restrict__ g_orig,
    const scalar_t* __restrict__ b_orig,
    const int32_t* __restrict__ cu_seqlens_ptr,
    float* __restrict__ buff,
    scalar_t* __restrict__ reduced_buff,
    scalar_t* __restrict__ thread_buff,
    const int32_t* __restrict__ chunk_offsets_ptr,
    const int32_t* __restrict__ chunk_indices_ptr,
    bool use_qk_l2norm_in_kernel,
    /* ... dimensional and stride parameters ... */);

Import

#include "common.h"
#include "gemm.h"
#include "vec.h"
#include "vec_pack.h"

I/O Contract

Inputs

Name | Type | Required | Description
A_log | at::Tensor [num_v_heads] | Yes | Log-space decay parameters per attention head (float32)
dt_bias | at::Tensor [num_v_heads] | Yes | Bias for delta time computation per head
q | at::Tensor [seq_len, batch, num_heads, head_dim] | Yes | Query tensor in 4D layout
k | at::Tensor [seq_len, batch, num_heads, head_dim] | Yes | Key tensor in 4D layout
v | at::Tensor [seq_len, batch, v_num_heads, v_head_dim] | Yes | Value tensor in 4D layout
a | at::Tensor [batch, v_num_heads] | Yes | Gating parameter a (sigmoid input)
b | at::Tensor [batch, v_num_heads] | Yes | Beta values for gating
initial_state_source | at::Tensor [N, v_num_heads, head_dim, v_head_dim] | Yes | Mutable recurrent state buffer (float32)
initial_state_indices | at::Tensor [batch] | Yes | Per-batch indices into initial_state_source (int32)
cu_seqlens | at::Tensor [batch + 1] | Yes | Cumulative sequence lengths for variable-length batching (int32)
use_qk_l2norm_in_kernel | bool | Yes | Whether to apply L2 normalization to Q and K
softplus_beta | double | No | Beta parameter for the softplus activation (default 1.0)
softplus_threshold | double | No | Threshold above which softplus is linearized (default 20.0)

Outputs

Name | Type | Description
output | at::Tensor [batch, seq_len, v_num_heads, v_head_dim] | Attention output tensor from the gated delta rule computation

Usage Examples

// Invoke the fused gated delta rule for Mamba-2 style inference
at::Tensor output = fused_sigmoid_gating_delta_rule_update_cpu(
    A_log,                    // [num_v_heads] float32 decay params
    dt_bias,                  // [num_v_heads] delta time bias
    q, k, v,                 // query, key, value tensors
    a, b,                    // gating parameters
    initial_state_source,    // mutable recurrent state
    initial_state_indices,   // batch-to-state mapping
    cu_seqlens,              // cumulative sequence lengths
    /*use_qk_l2norm=*/true,
    /*softplus_beta=*/1.0,
    /*softplus_threshold=*/20.0);

// Compute gating values only
auto [gating, beta] = fused_gdn_gating_cpu(A_log, a, b, dt_bias);
