Implementation:Sgl project Sglang Kernel Ops Header

From Leeroopedia
Revision as of 16:40, 16 February 2026 by Admin (auto-imported from implementations/Sgl_project_Sglang_Kernel_Ops_Header.md)


Knowledge Sources
Domains: GPU Kernels, C++ Headers, LLM Inference
Last Updated: 2026-02-10 00:00 GMT

Overview

Master C++ header declaring all function signatures for the sgl_kernel CUDA/HIP operations across every subsystem of the SGLang kernel library.

Description

sgl_kernel_ops.h is the single authoritative declaration of all kernel operation interfaces in the SGLang kernel library. It serves as the contract between the C++ kernel implementations and the PyTorch extension registration layer.

The header is organized into clearly delineated sections by subsystem:

  • Allreduce: Custom allreduce operations with conditional compilation for ROCm (using HIP-based custom AR and quick allreduce) vs CUDA (using custom AR and MSCCL++). The ROCm path includes init_custom_ar, all_reduce_reg, all_reduce_unreg, dispose, and quick reduce operations. The CUDA path includes init_custom_ar, all_reduce, and MSCCL++ context initialization.
  • Attention: Merge-state functions for multi-head attention (merge_state, merge_state_v2), CUTLASS MLA decode (cutlass_mla_decode), and workspace size calculation (cutlass_mla_get_workspace_size).
  • Elementwise: RMSNorm variants (standard and Gemma), fused add-RMSNorm, activation functions (silu_and_mul, gelu_tanh_and_mul, gelu_and_mul), rotary embedding (apply_rope_pos_ids_cos_sin_cache, rotary_embedding), FP8 downcast, CE-free GPU copy, MLA concatenation, and fast top-k operations.
  • GEMM: AWQ dequantization, CUTLASS FP4 scaled MM, INT8/FP8 scaled MM, FP8 blockwise scaled MM, batched FP8 MM, DeepSeek V3 router/fused GEMM, GPTQ Marlin GEMM, per-token quantization routines, and QServe W4A8 GEMM.
  • MoE: Block alignment (moe_align_block_size), top-k softmax/sigmoid gating, fused gate operations (including Kimi K2), FP8 blockwise grouped MM, MoE input preparation, row shuffling, fused QK norm+RoPE, CUTLASS FP4 group MM, and scaled FP4 expert quantization.
  • Speculative Decoding: Tree-based speculative sampling, greedy verification, tree mask reconstruction, and segment packbits.
  • KV Cache I/O: Comprehensive set of KV cache transfer functions for per-layer and all-layer operations across various memory layouts (page-first, page-head, layer-first), including MLA variants and direct transfer.
  • Memory: Weak reference tensors and KV cache store operations.
  • Sampling: FlashInfer-derived sampling kernels including min-p, top-k renorm, top-p renorm, top-k-top-p combined, and logit masking.
  • Flash Attention: FA2 sparse forward and variable-length sparse forward, vertical-slash index conversion.
  • Grammar: Token bitmask application for constrained decoding.
  • GGUF: Dequantization and matrix multiplication for GGUF quantized formats.
  • Mamba: Causal conv1d forward and update for state-space models.
  • Expert Specialization: ES-specific FP8 grouped MM and SM100 MXFP8 operations.
  • Hadamard Transform: Fast Hadamard transform variants (standard, 12N, 20N, 28N, 40N).
  • FlashMLA: MLA decoding metadata and forward KV-cache functions with FP8 support.
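As an illustration of what one of the listed operations computes, the following is a minimal CPU sketch of merging two partial attention states, the job the Attention bullet attributes to merge_state. It assumes that each state carries a partial output vector plus the natural-log log-sum-exp of the scores that produced it; the real kernels operate on batched GPU tensors and may use a different logarithm base or layout. The names PartialState and merge_state_reference are hypothetical and do not appear in the header.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for one (token, head) partial attention result.
struct PartialState {
  std::vector<float> v;  // partial attention output
  float s;               // log-sum-exp of the scores behind v (natural log assumed)
};

// Merge two partial states into the state that attention over the union
// of both key sets would have produced.
PartialState merge_state_reference(const PartialState& a, const PartialState& b) {
  assert(a.v.size() == b.v.size());
  // Subtract the larger log-sum-exp before exponentiating for stability.
  const float m = std::max(a.s, b.s);
  const float wa = std::exp(a.s - m);
  const float wb = std::exp(b.s - m);
  const float denom = wa + wb;

  PartialState out;
  out.v.resize(a.v.size());
  for (std::size_t i = 0; i < a.v.size(); ++i) {
    out.v[i] = (wa * a.v[i] + wb * b.v[i]) / denom;
  }
  out.s = m + std::log(denom);  // log(exp(a.s) + exp(b.s))
  return out;
}
```

Under these assumptions, two states with equal log-sum-exp values are averaged, and a state with a much larger log-sum-exp dominates the merged output.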

The header uses conditional compilation (#ifdef USE_ROCM) to select between CUDA and ROCm implementations, and defines utility macros (REGISTER_EXTENSION, TORCH_LIBRARY_EXPAND, CONCAT, STRINGIFY) for PyTorch extension registration.
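The conditional-compilation and macro pattern described above can be sketched as follows. This is an illustration only: the macro bodies are the common two-level idiom for token pasting and stringification, not the header's actual definitions, and every name prefixed with SKETCH_, along with backend_name and the extension name common_ops, is a hypothetical stand-in.

```cpp
#include <cassert>
#include <cstring>
#include <string>

// Two-level expansion: the outer macro expands its arguments first,
// then the inner macro pastes or stringifies the result.
#define SKETCH_CONCAT_IMPL(a, b) a##b
#define SKETCH_CONCAT(a, b) SKETCH_CONCAT_IMPL(a, b)
#define SKETCH_STRINGIFY_IMPL(x) #x
#define SKETCH_STRINGIFY(x) SKETCH_STRINGIFY_IMPL(x)

// Hypothetical extension name; a build system would normally inject it.
#define SKETCH_EXTENSION_NAME common_ops

// Backend selection in the style of #ifdef USE_ROCM.
#ifdef USE_ROCM
inline const char* backend_name() { return "rocm"; }
#else
inline const char* backend_name() { return "cuda"; }
#endif

// Token pasting builds the identifier init_common_ops from the macro value.
inline int SKETCH_CONCAT(init_, SKETCH_EXTENSION_NAME)() { return 0; }

// Stringification turns the same macro value into a string literal.
inline const char* extension_name() {
  return SKETCH_STRINGIFY(SKETCH_EXTENSION_NAME);
}
```

The indirection through the _IMPL macros matters: a single-level `#x` would stringify the macro's name rather than its value.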

Usage

This header is included by every C++ source file that implements a kernel operation and by the PyTorch extension binding code. When adding a new kernel operation, declare it here so that the implementation and the binding layer share one signature.

Code Reference

Source Location

Signature

// Allreduce (CUDA path)
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);

// Attention
void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& q_pe,
    torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table,
    torch::Tensor const& workspace, double sm_scale, int64_t num_kv_splits = 1);

// Elementwise
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void silu_and_mul(at::Tensor& out, at::Tensor& input);
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
    at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, bool enable_pdl, ...);

// GEMM
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
    const torch::Tensor& scales_a, const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
    torch::Tensor& b_q_weight, torch::Tensor& b_scales, ...);

// MoE
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, ...);
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output,
    bool renormalize, double moe_softcapping, const c10::optional<torch::Tensor>& correction_bias);

// KV Cache I/O
void transfer_kv_per_layer(const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v,
    const at::Tensor src_indices, const at::Tensor dst_indices, int64_t item_size,
    int64_t block_quota, int64_t num_warps_per_block);

// Sampling (FlashInfer)
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor output, ...);
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor output, ...);

Import

#include "sgl_kernel_ops.h"

I/O Contract

Inputs

Name               Type                                                         Required  Description
Various tensors    at::Tensor / torch::Tensor                                   Yes       Input tensors vary by function; they are generally device (CUDA or ROCm) tensors
Scalar parameters  int64_t, double, bool                                        Varies    Configuration values such as epsilon, scale factors, and flags
Optional tensors   std::optional<torch::Tensor> / c10::optional<torch::Tensor>  No        Optional inputs such as bias or correction bias

Outputs

Name              Type              Description
Return tensors    torch::Tensor     Some functions return new tensors (e.g., GEMM results)
In-place updates  void              Many functions modify output tensors passed by reference
Handles           fptr_t (int64_t)  Opaque handles for allreduce contexts
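The opaque-handle convention in the last row can be illustrated with a small sketch. It assumes, as the table states, that fptr_t is an alias for int64_t, and it further assumes the handle is a context pointer stored as an integer, which is a common way to pass native objects through an integer-only binding layer. SketchContext and the sketch_ functions are hypothetical; the real allreduce context and its lifetime rules are not shown in this page.

```cpp
#include <cassert>
#include <cstdint>

using fptr_t = int64_t;  // alias as listed in the Outputs table

static_assert(sizeof(fptr_t) >= sizeof(void*), "handle must be able to hold a pointer");

// Hypothetical stand-in for a backend context object.
struct SketchContext {
  int64_t rank;
  bool full_nvlink;
};

// Create a context and hand it out as an opaque integer handle.
fptr_t sketch_init(int64_t rank, bool full_nvlink) {
  auto* ctx = new SketchContext{rank, full_nvlink};
  return reinterpret_cast<fptr_t>(ctx);
}

// Recover the context from the handle to use it.
int64_t sketch_rank(fptr_t handle) {
  return reinterpret_cast<SketchContext*>(handle)->rank;
}

// Release the context; the handle must not be used afterwards.
void sketch_dispose(fptr_t handle) {
  delete reinterpret_cast<SketchContext*>(handle);
}
```

In this pattern the caller owns the handle and must call the dispose function exactly once, mirroring the init_custom_ar / dispose pairing in the Signature section.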

Usage Examples

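// Example: calling rmsnorm from C++.
// `input` and `weight` are assumed to be existing GPU tensors; the caller preallocates the output.
at::Tensor output = torch::empty_like(input);
rmsnorm(output, input, weight, 1e-6, /*enable_pdl=*/true);

// Example: FP8 scaled matrix multiplication.
// `mat_a`, `mat_b`, `scales_a`, and `scales_b` are assumed to be prepared by the caller.
auto result = fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, torch::kBFloat16, /*bias=*/c10::nullopt);

// Example: MoE top-k softmax gating.
// `topk_weights` and `topk_indices` are preallocated outputs; `gating_output` holds the router logits.
topk_softmax(topk_weights, topk_indices, gating_output, /*renormalize=*/true, /*moe_softcapping=*/0.0, /*correction_bias=*/c10::nullopt);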
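When checking a call such as the rmsnorm example, a CPU reference can be useful for comparison. The sketch below implements the standard RMSNorm formula for a single row, y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i]. It is an assumption that the rmsnorm kernel follows this textbook definition; rmsnorm_reference is a hypothetical helper, not part of the header.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference RMSNorm over one row of hidden values.
// Accumulates in double to limit rounding error in the mean of squares.
std::vector<float> rmsnorm_reference(const std::vector<float>& x,
                                     const std::vector<float>& w,
                                     double eps) {
  assert(x.size() == w.size() && !x.empty());
  double sum_sq = 0.0;
  for (float v : x) {
    sum_sq += static_cast<double>(v) * v;
  }
  const double inv_rms = 1.0 / std::sqrt(sum_sq / static_cast<double>(x.size()) + eps);

  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = static_cast<float>(x[i] * inv_rms * w[i]);
  }
  return y;
}
```

A GPU result copied back to the host can be compared against this reference row by row with a tolerance suited to the tensor dtype.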
