Implementation: Sgl_project_Sglang CPU Mamba FLA
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, CPU Kernels |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements the CPU kernel for the chunk-based gated delta rule used in Flash Linear Attention (FLA) models, providing an efficient chunked attention mechanism for linear recurrent architectures such as Mamba-2.
Description
This file provides the core chunk_gated_delta_rule_kernel_impl function that processes sequences in fixed-size chunks (default chunk_size=64) with inputs including query, key, value tensors, gating values, and beta values. The algorithm operates in multiple phases:
- Optional L2 normalization of the Q and K vectors, fused with the subsequent phases so that the kernel needs only four parallel loops in total.
- Padding and rearranging input data into chunk-aligned buffers (q_pad, k_pad, v_pad).
- Computing decay masks and gated attention within chunks.
- Accumulating recurrent states across chunks with the delta rule update.
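Conceptually, the chunked phases above evaluate the token-level gated delta rule recurrence S_t = exp(g_t) * S_{t-1} + beta_t * k_t * (v_t - exp(g_t) * S_{t-1}^T k_t)^T with output o_t = S_t^T q_t. The scalar reference below is a hypothetical sketch of that recurrence for a single head, not the kernel's chunk-parallel code; the helper name and layout are illustrative:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Scalar reference of the gated delta rule recurrence that the chunked kernel
// evaluates block-wise.  State S is [Dk, Dv]; q/k are [T, Dk], v/out are [T, Dv].
void gated_delta_rule_reference(
    const float* q, const float* k, const float* v,
    const float* g, const float* beta,   // per-token decay (log-space) and gate
    float* S,                            // in/out recurrent state [Dk, Dv]
    float* out,                          // [T, Dv]
    int T, int Dk, int Dv) {
  std::vector<float> kS(Dv);
  for (int t = 0; t < T; ++t) {
    const float decay = std::exp(g[t]);
    // 1) decay the carried state
    for (int i = 0; i < Dk; ++i)
      for (int j = 0; j < Dv; ++j) S[i * Dv + j] *= decay;
    // 2) current prediction S^T k_t
    for (int j = 0; j < Dv; ++j) {
      float acc = 0.f;
      for (int i = 0; i < Dk; ++i) acc += S[i * Dv + j] * k[t * Dk + i];
      kS[j] = acc;
    }
    // 3) delta rule: write the correction beta_t * (v_t - S^T k_t) along k_t
    for (int i = 0; i < Dk; ++i)
      for (int j = 0; j < Dv; ++j)
        S[i * Dv + j] += beta[t] * k[t * Dk + i] * (v[t * Dv + j] - kS[j]);
    // 4) output o_t = S^T q_t
    for (int j = 0; j < Dv; ++j) {
      float acc = 0.f;
      for (int i = 0; i < Dk; ++i) acc += S[i * Dv + j] * q[t * Dk + i];
      out[t * Dv + j] = acc;
    }
  }
}
```

The chunked kernel obtains the same result by materializing the per-chunk decay masks up front and carrying only one state matrix across chunk boundaries, which is what makes the computation parallelizable within a chunk.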
The kernel uses extensive buffer management via the THREAD_BUFFER_ALLOC macro for per-thread scratch space. It supports grouped query attention via the head_group parameter and variable-length sequence processing via cu_seqlens_ptr. SIMD vectorization is used throughout with at::vec::Vectorized and float32 intermediate computation for numerical stability.
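The optional Q/K L2 normalization mentioned above follows the same float32-accumulation pattern. A minimal scalar sketch of that step (the kernel's actual path is vectorized with at::vec::Vectorized; the helper name and eps value here are illustrative):

```cpp
#include <cassert>
#include <cmath>

// Normalize one head_dim-length vector to unit L2 norm, accumulating the
// squared sum in float32 for numerical stability (as the kernel does for
// reduced-precision inputs).  Hypothetical helper, not the kernel's code.
void l2_normalize(const float* in, float* out, int d, float eps = 1e-6f) {
  float sum = 0.f;
  for (int i = 0; i < d; ++i) sum += in[i] * in[i];
  const float inv = 1.f / std::sqrt(sum + eps);
  for (int i = 0; i < d; ++i) out[i] = in[i] * inv;
}
```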
Two public API functions are exposed: fused_sigmoid_gating_delta_rule_update_cpu for the full gated delta rule with sigmoid gating, and fused_gdn_gating_cpu for computing the gating values (decay and beta) from A_log, a, b, and dt_bias parameters.
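As a rough sketch of the per-element math fused_gdn_gating_cpu computes, following the standard gated-delta-net formulation (the exact fused implementation lives in fla.cpp; this scalar version is an assumption for illustration): the decay is g = -exp(A_log) * softplus(a + dt_bias) and the gate is beta = sigmoid(b), with softplus linearized above a threshold for stability.

```cpp
#include <cassert>
#include <cmath>

// softplus(x) = log(1 + exp(beta * x)) / beta, switched to the identity in the
// linear region x * beta > threshold to avoid overflow.
float softplus(float x, float beta, float threshold) {
  if (x * beta > threshold) return x;
  return std::log1p(std::exp(beta * x)) / beta;
}

// Hypothetical scalar version of the gating computation:
//   g[h]    = -exp(A_log[h]) * softplus(a[h] + dt_bias[h])
//   beta[h] = sigmoid(b[h])
void gdn_gating(const float* A_log, const float* a, const float* b,
                const float* dt_bias, float* g, float* beta_out, int n,
                float sp_beta = 1.f, float sp_threshold = 20.f) {
  for (int h = 0; h < n; ++h) {
    g[h] = -std::exp(A_log[h]) *
           softplus(a[h] + dt_bias[h], sp_beta, sp_threshold);
    beta_out[h] = 1.f / (1.f + std::exp(-b[h]));
  }
}
```

The sp_beta and sp_threshold defaults mirror the softplus_beta and softplus_threshold parameters of the public API.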
Usage
Use this kernel when performing CPU inference for models that employ Flash Linear Attention or gated delta rule mechanisms, such as Mamba-2 and other hybrid models combining attention with linear recurrence. It is invoked through the fused_sigmoid_gating_delta_rule_update_cpu function with batched Q, K, V tensors in 4D layout [seq_len, batch_size, num_heads, head_dim].
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cpu/mamba/fla.cpp
- Lines: 1-1341
Signature
// Main public API
at::Tensor fused_sigmoid_gating_delta_rule_update_cpu(
const at::Tensor& A_log,
const at::Tensor& dt_bias,
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
const at::Tensor& a,
const at::Tensor& b,
at::Tensor& initial_state_source,
const at::Tensor& initial_state_indices,
const at::Tensor& cu_seqlens,
bool use_qk_l2norm_in_kernel,
double softplus_beta = 1.0,
double softplus_threshold = 20.0);
std::tuple<at::Tensor, at::Tensor>
fused_gdn_gating_cpu(
const at::Tensor& A_log,
const at::Tensor& a,
const at::Tensor& b,
const at::Tensor& dt_bias);
// Internal kernel
template <typename scalar_t, int64_t chunk_size = 64>
void chunk_gated_delta_rule_kernel_impl(
scalar_t* __restrict__ out,
float* __restrict__ final_state_data,
const scalar_t* __restrict__ q_orig,
const scalar_t* __restrict__ k_orig,
const scalar_t* __restrict__ v_orig,
const float* __restrict__ g_orig,
const scalar_t* __restrict__ b_orig,
const int32_t* __restrict__ cu_seqlens_ptr,
float* __restrict__ buff,
scalar_t* __restrict__ reduced_buff,
scalar_t* __restrict__ thread_buff,
const int32_t* __restrict__ chunk_offsets_ptr,
const int32_t* __restrict__ chunk_indices_ptr,
bool use_qk_l2norm_in_kernel,
/* ... dimensional and stride parameters ... */);
Import
#include "common.h"
#include "gemm.h"
#include "vec.h"
#include "vec_pack.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A_log | at::Tensor [num_v_heads] | Yes | Log-space decay parameters per attention head (float32) |
| dt_bias | at::Tensor [num_v_heads] | Yes | Bias for delta time computation per head |
| q | at::Tensor [seq_len, batch, num_heads, head_dim] | Yes | Query tensor in 4D layout |
| k | at::Tensor [seq_len, batch, num_heads, head_dim] | Yes | Key tensor in 4D layout |
| v | at::Tensor [seq_len, batch, v_num_heads, v_head_dim] | Yes | Value tensor in 4D layout |
| a | at::Tensor [batch, v_num_heads] | Yes | Gating parameter a (sigmoid input) |
| b | at::Tensor [batch, v_num_heads] | Yes | Beta values for gating |
| initial_state_source | at::Tensor [N, v_num_heads, head_dim, v_head_dim] | Yes | Mutable recurrent state buffer (float32) |
| initial_state_indices | at::Tensor [batch] | Yes | Indices into initial_state_source per batch element (int32) |
| cu_seqlens | at::Tensor [batch + 1] | Yes | Cumulative sequence lengths for variable-length batching (int32) |
| use_qk_l2norm_in_kernel | bool | Yes | Whether to apply L2 normalization to Q and K |
| softplus_beta | double | No | Beta parameter for softplus activation (default 1.0) |
| softplus_threshold | double | No | Threshold for softplus linearization (default 20.0) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | at::Tensor [batch, seq_len, v_num_heads, v_head_dim] | Attention output tensor from the gated delta rule computation |
Usage Examples
// Invoke the fused gated delta rule for Mamba-2 style inference
at::Tensor output = fused_sigmoid_gating_delta_rule_update_cpu(
A_log, // [num_v_heads] float32 decay params
dt_bias, // [num_v_heads] delta time bias
q, k, v, // query, key, value tensors
a, b, // gating parameters
initial_state_source, // mutable recurrent state
initial_state_indices, // batch-to-state mapping
cu_seqlens, // cumulative sequence lengths
    /*use_qk_l2norm_in_kernel=*/true,
/*softplus_beta=*/1.0,
/*softplus_threshold=*/20.0);
// Compute gating values only
auto [gating, beta] = fused_gdn_gating_cpu(A_log, a, b, dt_bias);