
Implementation:Sgl project Sglang CPU Mamba Conv1D

From Leeroopedia


Knowledge Sources
Domains: State Space Models, CPU Compute
Last Updated: 2026-02-10 00:00 GMT

Overview

Implements CPU-optimized 1D causal convolution kernels for Mamba state-space models, including convolution state management and fused convolution with SiLU activation.

Description

conv.cpp provides the CPU implementation of causal 1D convolution used in Mamba state-space model architectures. Mamba models use causal 1D convolutions (typically with width=4) as a core building block, and this implementation fuses the convolution, bias addition, and SiLU activation into a single kernel to minimize memory traffic.
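The fused computation can be sketched as a scalar reference in plain float (function name and tensor layouts here are illustrative assumptions, not conv.cpp's actual API):

```cpp
#include <cmath>
#include <vector>

// Scalar reference for the fused kernel: per-channel causal convolution of
// width K, then bias add, then SiLU (x * sigmoid(x)). Assumed layouts:
// x is [seqlen, dim] row-major, w is [K, dim], state is [K-1, dim] and
// holds the tokens preceding x (all zeros when there is no initial state).
std::vector<float> causal_conv1d_silu_ref(
    const std::vector<float>& x, const std::vector<float>& w,
    const std::vector<float>& bias, const std::vector<float>& state,
    int seqlen, int dim, int K) {
  std::vector<float> y(static_cast<size_t>(seqlen) * dim);
  for (int t = 0; t < seqlen; ++t) {
    for (int d = 0; d < dim; ++d) {
      float acc = bias[d];
      for (int k = 0; k < K; ++k) {
        int src = t + k - (K - 1);  // causal: taps look up to K-1 steps back
        float v = src >= 0 ? x[src * dim + d]
                           : state[(src + K - 1) * dim + d];
        acc += v * w[k * dim + d];
      }
      y[t * dim + d] = acc / (1.0f + std::exp(-acc));  // SiLU epilogue
    }
  }
  return y;
}
```

Fusing the three steps means each output element is produced in one pass over the taps, with no intermediate convolution buffer written to memory.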

The file contains two main components:

1. Convolution State Management (update_conv_state)

The update_conv_state function manages the sliding window convolution state:

  • Shifts existing state entries left by seqlen positions
  • Inserts new input tokens into the rightmost positions
  • Handles the has_initial_states flag to distinguish between fresh initialization and ongoing updates
  • Uses copy_stub for SIMD-vectorized data movement (with zero-padding for uninitialized entries)
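The shift-and-append behavior can be sketched in scalar form (a hedged reference only; the real path moves data via the SIMD-vectorized copy_stub and its exact layout may differ):

```cpp
// Illustrative scalar version of the sliding-window update: shift the
// [width-1, dim] state left by seqlen rows, then copy the newest input
// tokens into the tail. Rows of a fresh state (has_initial_states=false)
// are zero-filled instead of being read.
void update_conv_state_ref(float* conv_states, const float* input,
                           int width, int dim, int seqlen,
                           bool has_initial_states) {
  const int rows = width - 1;
  for (int r = 0; r < rows; ++r) {
    int src = r + seqlen;  // row that slides into position r
    float* dst_row = conv_states + r * dim;
    if (src < rows) {
      // Old state row survives the shift (zero if state was uninitialized).
      for (int d = 0; d < dim; ++d)
        dst_row[d] = has_initial_states ? conv_states[src * dim + d] : 0.0f;
    } else {
      // Row is filled from the new input tokens.
      const float* in_row = input + (src - rows) * dim;
      for (int d = 0; d < dim; ++d) dst_row[d] = in_row[d];
    }
  }
}
```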

2. Fused Convolution Micro-kernel (tinygemm_kernel)

The tinygemm_kernel struct template expresses the 1D convolution as a small GEMM:

  • Input A: shape [M, BLOCK_N]
  • Weight B: shape [BLOCK_N, K] in VNNI-packed format [K/2, BLOCK_N, 2]
  • Output C: shape [M, BLOCK_N]
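The VNNI weight layout can be illustrated with a plain-float repack (a sketch of the packing scheme only; the real weights are BF16 and conv.cpp consumes the packed form rather than producing it):

```cpp
// Repack a [BLOCK_N, K] weight into VNNI order [K/2, BLOCK_N, 2]: the two
// taps of each adjacent pair sit contiguously per channel, matching the
// pairwise BF16 dot-product instructions. K is assumed to be even.
void pack_weight_vnni(const float* B, float* Bp, int BLOCK_N, int K) {
  for (int k2 = 0; k2 < K / 2; ++k2)
    for (int n = 0; n < BLOCK_N; ++n)
      for (int p = 0; p < 2; ++p)
        Bp[(k2 * BLOCK_N + n) * 2 + p] = B[n * K + 2 * k2 + p];
}
```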

The AVX-512 BFloat16 specialization uses:

  • __m512bh vector types for AVX-512 BF16 computation
  • MM512_PACK_A macro for packing activation pairs into VNNI format
  • MM512_LOAD_A macro for handling the boundary condition where initial tokens must use conv_states instead of input
  • set_conv_states lambda for loading convolution state at negative time indices
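The boundary condition those macros and the lambda express can be sketched as row selection (names, layout, and the non-first-token fallthrough are all assumptions for illustration):

```cpp
#include <cstdint>

// For output row m and tap k, the activation at a negative time index must
// come from conv_states rather than the input A. Negative index -1 maps to
// the newest state row, -2 to the one before it, and so on.
inline const float* select_a_row(const float* A, const float* conv_states,
                                 int m, int k, int K, int64_t lda,
                                 int BLOCK_N, bool is_first_token) {
  int64_t t = m + k - (K - 1);  // time index this tap reads
  if (t >= 0) return A + t * lda;  // normal case: read the input
  if (is_first_token)              // first chunk: read the saved state
    return conv_states + (t + K - 1) * BLOCK_N;
  return A + t * lda;  // assumed: later chunks keep their history in A
}
```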

Template parameters control optional behavior:

  • K: Convolution kernel width (typically 4)
  • BLOCK_N: Channel block size
  • has_bias: Whether to add bias after convolution
  • has_silu: Whether to apply SiLU activation after convolution

Usage

This kernel is invoked during Mamba model inference on CPU. It handles both the first-token case (where convolution states may need initialization) and subsequent tokens (where states are shifted and updated).

Code Reference

Source Location

Signature

// SIMD-vectorized copy (with zero-fill for nullptr)
template <typename scalar_t>
inline void copy_stub(
    scalar_t* __restrict__ y,
    const scalar_t* __restrict__ x,
    int64_t size);

// Sliding window convolution state update
template <typename scalar_t>
void inline update_conv_state(
    scalar_t* __restrict__ conv_states,
    const scalar_t* __restrict__ input,
    int64_t width,
    int64_t dim,
    int64_t seqlen,
    bool has_initial_states);

// Fused convolution micro-kernel (generic template)
template <typename scalar_t, int K, int BLOCK_N, bool has_bias, bool has_silu>
struct tinygemm_kernel {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const scalar_t* __restrict__ B,
      scalar_t* __restrict__ C,
      const scalar_t* __restrict__ bias,
      const scalar_t* __restrict__ conv_states,
      bool has_initial_state,
      int64_t M,
      int64_t lda,
      bool is_first_token);
};

// AVX-512 BFloat16 specialization
template <int K, int BLOCK_N, bool has_bias, bool has_silu>
struct tinygemm_kernel<at::BFloat16, K, BLOCK_N, has_bias, has_silu> {
  static inline void apply(
      const at::BFloat16* __restrict__ A,
      const at::BFloat16* __restrict__ B,
      at::BFloat16* __restrict__ C,
      const at::BFloat16* __restrict__ bias,
      const at::BFloat16* __restrict__ conv_states,
      bool has_initial_state,
      int64_t M,
      int64_t lda,
      bool is_first_token);
};

Import

#include "common.h"
#include "gemm.h"
#include "vec.h"

I/O Contract

Inputs

Name Type Required Description
A scalar_t* Yes Input activation tensor, shape [M, BLOCK_N] (or leading dimension lda)
B scalar_t* Yes Convolution weight in VNNI-packed format [K/2, BLOCK_N, 2]
conv_states scalar_t* Yes Convolution state buffer, shape [width-1, dim]
bias scalar_t* Conditional Bias vector of length BLOCK_N (required when has_bias=true)
input scalar_t* Yes (for update) New input tokens for state update
width int64_t Yes Convolution kernel width (typically 4)
dim int64_t Yes Channel dimension size
seqlen int64_t Yes Number of new input tokens
has_initial_states bool Yes Whether conv_states contains valid initial data
M int64_t Yes Number of output positions (sequence length)
lda int64_t Yes Leading dimension of input/output tensors
is_first_token bool Yes Whether this is the first token (uses conv_states for padding)

Outputs

Name Type Description
C scalar_t* Convolution output, shape [M, BLOCK_N], with optional bias and SiLU applied
conv_states scalar_t* Updated convolution state buffer (shifted and filled with new tokens)

Usage Examples

Update Convolution State

// Shift convolution state and insert new tokens
update_conv_state<at::BFloat16>(
    conv_state_ptr,       // conv_states: [width-1, dim]
    new_input_ptr,        // input: new tokens
    /*width=*/4,          // convolution kernel width
    /*dim=*/2048,         // channel dimension
    /*seqlen=*/1,         // number of new tokens
    /*has_initial_states=*/true);

Fused Convolution + SiLU

// Apply 1D causal convolution with bias and SiLU activation
tinygemm_kernel<at::BFloat16, /*K=*/4, /*BLOCK_N=*/32,
                /*has_bias=*/true, /*has_silu=*/true>::apply(
    input_ptr,           // A: input [M, BLOCK_N]
    packed_weight_ptr,   // B: VNNI weight [K/2, BLOCK_N, 2]
    output_ptr,          // C: output [M, BLOCK_N]
    bias_ptr,            // bias: [BLOCK_N]
    conv_state_ptr,      // conv_states: for first-token padding
    /*has_initial_state=*/true,
    /*M=*/seq_len,
    /*lda=*/hidden_dim,
    /*is_first_token=*/true);
