Implementation:Sgl project Sglang Common Extension Registration
| Knowledge Sources | |
|---|---|
| Domains | Kernel Infrastructure, PyTorch Extension |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Central PyTorch C++ extension registration file that exposes all CUDA kernel operations as torch library functions for the NVIDIA GPU backend.
Description
common_extension.cc is the primary extension registration file for the sgl-kernel NVIDIA CUDA backend. It uses the TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) macro to register all kernel operations with PyTorch's operator dispatch system.
The file defines function signatures using m.def() and binds CUDA implementations using m.impl() with the torch::kCUDA dispatch key. Operations are organized into the following categories:
- Allreduce/Communication: get_graph_buffer_ipc_meta, register_graph_buffers, init_custom_ar, all_reduce, mscclpp_init_context, mscclpp_allreduce
- Attention: merge_state, merge_state_v2, cutlass_mla_decode, cutlass_mla_get_workspace_size
- Elementwise: rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, silu_and_mul, gelu_tanh_and_mul, gelu_and_mul
- Rotary Embeddings: apply_rope_pos_ids_cos_sin_cache, rotary_embedding
- Quantization: downcast_fp8, sgl_per_token_group_quant_8bit, sgl_per_tensor_quant_fp8, scaled_fp4_quant
- GEMM: awq_dequantize, int8_scaled_mm, fp8_scaled_mm, fp8_blockwise_scaled_mm, cutlass_scaled_fp4_mm
- MoE: fast_topk, fast_topk_transform_fused
- Sampling: various sampling operations
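The m.def/m.impl pattern separates each operator's schema definition from its dispatch-key-specific implementation. Each schema is declared explicitly, including Tensor! annotations on arguments that are written in place. As a result, torch.compile can treat these kernels as opaque operators with known inputs and side effects instead of tracing into the CUDA code.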
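The schema/implementation split can be made concrete with a toy, pure-Python model. ToyLibrary, its methods, and the scale_ op are hypothetical. They only mimic the roles of m.def() and m.impl(), and this is not how PyTorch's dispatcher is implemented. The sketch shows that a schema exists on its own and that kernels are attached to it per dispatch key. A call for a key with no kernel therefore fails at dispatch time, not at definition time.

```python
from typing import Any, Callable, Dict


class ToyLibrary:
    """Hypothetical stand-in for an operator namespace such as sgl_kernel."""

    def __init__(self, namespace: str) -> None:
        self.namespace = namespace
        self.schemas: Dict[str, str] = {}
        self.kernels: Dict[str, Dict[str, Callable[..., Any]]] = {}

    def define(self, schema: str) -> None:
        # Plays the role of m.def(): record the schema, bind no kernel yet.
        name = schema.split("(", 1)[0].strip()
        if name in self.schemas:
            raise ValueError(f"{self.namespace}::{name} is already defined")
        self.schemas[name] = schema
        self.kernels[name] = {}

    def impl(self, name: str, dispatch_key: str, fn: Callable[..., Any]) -> None:
        # Plays the role of m.impl(): attach a kernel for one dispatch key.
        if name not in self.schemas:
            raise KeyError(f"{self.namespace}::{name} has no schema")
        self.kernels[name][dispatch_key] = fn

    def call(self, name: str, dispatch_key: str, *args: Any) -> Any:
        # Route the call to the kernel registered for the given key.
        kernels = self.kernels[name]
        if dispatch_key not in kernels:
            raise NotImplementedError(
                f"{self.namespace}::{name} has no kernel for {dispatch_key}"
            )
        return kernels[dispatch_key](*args)


def scale_cpu(x: list, s: float) -> None:
    # In-place kernel: mirrors a "Tensor! x" argument in the schema.
    for i in range(len(x)):
        x[i] *= s


lib = ToyLibrary("sgl_kernel")
lib.define("scale_(Tensor! x, float s) -> ()")  # schema exists on its own
lib.impl("scale_", "CPU", scale_cpu)            # kernel bound for CPU only
buf = [1.0, 2.0]
lib.call("scale_", "CPU", buf, 3.0)             # buf becomes [3.0, 6.0]
```

Calling lib.call("scale_", "CUDA", ...) in this toy raises NotImplementedError because only a CPU kernel was bound. In common_extension.cc the situation is reversed: every m.impl binds torch::kCUDA.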
Usage
This file is compiled as part of the sgl-kernel extension build. It is loaded automatically when the sgl_kernel Python package is imported. Developers adding new CUDA kernels must register them here to make them accessible from Python via torch.ops.sgl_kernel.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/common_extension.cc
- Lines: 1-614
Signature
```cpp
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // allreduce
  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  m.def("register_graph_buffers", &register_graph_buffers);
  m.def("init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
        "int rank, bool full_nvlink) -> int");
  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  m.def("all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
        "int reg_buffer_sz_bytes) -> ()");
  m.impl("all_reduce", torch::kCUDA, &all_reduce);
  // attention
  m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, "
        "Tensor! v_merged, Tensor! s_merged) -> ()");
  m.impl("merge_state", torch::kCUDA, &merge_state);
  // elementwise
  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, "
        "float eps, bool enable_pdl) -> ()");
  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
  m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
  // ... (614 lines total)
}
```
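In the schema strings above, Tensor! marks an argument that the operator writes in place. This is why most of these ops return () and take their outputs as arguments. The sketch below reads the simple schemas shown here and lists the mutated arguments. parse_schema and mutated_args are hypothetical helpers, not PyTorch's own schema parser. They handle only flat "Type name" arguments, with no default values and no nested types.

```python
from typing import List, Tuple


def parse_schema(schema: str) -> Tuple[str, List[Tuple[str, str]], str]:
    """Split 'name(Type arg, ...) -> ret' into (name, [(type, arg)], ret)."""
    head, ret = schema.split("->", 1)
    name, arg_text = head.split("(", 1)
    arg_text = arg_text.rsplit(")", 1)[0]  # drop the closing parenthesis
    args = []
    for raw in arg_text.split(","):
        raw = raw.strip()
        if raw:
            arg_type, arg_name = raw.rsplit(" ", 1)
            args.append((arg_type.strip(), arg_name))
    return name.strip(), args, ret.strip()


def mutated_args(schema: str) -> List[str]:
    """Names of arguments annotated 'Tensor!', i.e. written in place."""
    _, args, _ = parse_schema(schema)
    return [arg_name for arg_type, arg_name in args if arg_type.endswith("!")]


print(mutated_args("rmsnorm(Tensor! output, Tensor input, Tensor weight, "
                   "float eps, bool enable_pdl) -> ()"))  # ['output']
```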
Import
```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>

#include "sgl_kernel_ops.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| sgl_kernel_ops.h | Header | Yes | Declares all kernel function prototypes referenced in registration |
| torch/library.h | Header | Yes | Provides TORCH_LIBRARY_FRAGMENT macro for operator registration |
Outputs
| Name | Type | Description |
|---|---|---|
| sgl_kernel library fragment | PyTorch Dispatch Table | All registered operations become callable via torch.ops.sgl_kernel.* |
Usage Examples
Registering a New CUDA Kernel
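```cpp
// Inside the TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ... } block.
// my_new_op is a hypothetical kernel declared in sgl_kernel_ops.h.
// Define the operator schema ("Tensor!" marks an argument written in place)
m.def("my_new_op(Tensor input, Tensor! output, float scale) -> ()");
// Bind the CUDA implementation
m.impl("my_new_op", torch::kCUDA, &my_new_op);
```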
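The C++ function passed to m.impl() must have parameter types that agree with the m.def() schema. Several types are spelled differently in the two places: a schema float is a C++ double, and a schema int is an int64_t. The hypothetical helper below prints a matching prototype for a simple schema. It uses one common convention: const at::Tensor& for Tensor, at::Tensor& for Tensor!, and const std::vector<int64_t>& for int[]. Other spellings are also accepted, so treat its output as a starting point for the declaration in sgl_kernel_ops.h rather than a rule.

```python
# Hypothetical type table for the flat schemas used in this file.
# One common convention; PyTorch also accepts other spellings.
SCHEMA_TO_CPP = {
    "Tensor": "const at::Tensor&",
    "Tensor!": "at::Tensor&",                # mutable (in-place) argument
    "float": "double",                       # schema float is a C++ double
    "int": "int64_t",                        # schema int is a C++ int64_t
    "bool": "bool",
    "int[]": "const std::vector<int64_t>&",
}
RETURN_TO_CPP = {"()": "void", "int": "int64_t", "Tensor": "at::Tensor"}


def cpp_prototype(schema: str) -> str:
    """Suggest a C++ declaration for a flat 'name(Type arg, ...) -> ret' schema."""
    head, ret = schema.split("->", 1)
    name, arg_text = head.split("(", 1)
    params = []
    for raw in arg_text.rsplit(")", 1)[0].split(","):
        raw = raw.strip()
        if raw:
            arg_type, arg_name = raw.rsplit(" ", 1)
            params.append(f"{SCHEMA_TO_CPP[arg_type.strip()]} {arg_name}")
    return f"{RETURN_TO_CPP[ret.strip()]} {name.strip()}({', '.join(params)});"


print(cpp_prototype("my_new_op(Tensor input, Tensor! output, float scale) -> ()"))
# void my_new_op(const at::Tensor& input, at::Tensor& output, double scale);
```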
Calling from Python
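```python
import torch
import sgl_kernel  # importing the package loads the extension and registers the ops
# Example inputs (shapes and dtypes chosen for illustration)
x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
weight = torch.ones(128, dtype=torch.float16, device="cuda")
output = torch.empty_like(x)
# Operations are accessible via torch.ops; results land in the Tensor! arguments
torch.ops.sgl_kernel.rmsnorm(output, x, weight, eps=1e-5, enable_pdl=False)
out = torch.empty(4, 64, dtype=torch.float16, device="cuda")
torch.ops.sgl_kernel.silu_and_mul(out, x)  # out's last dim is half of x's
```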
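To check kernel output by hand, it helps to have the reference math for these ops. The sketch below gives pure-Python, single-row versions of the two ops called above. rmsnorm_ref and silu_and_mul_ref are hypothetical names. silu_and_mul_ref assumes the common convention that the input's last dimension is split in half, with SiLU applied to the first half and the result multiplied by the second half. The GPU kernels typically operate on fp16/bf16 tensors, so compare against their output with a tolerance rather than exact equality.

```python
import math
from typing import List


def rmsnorm_ref(x: List[float], weight: List[float], eps: float = 1e-5) -> List[float]:
    # y_i = x_i / sqrt(mean(x^2) + eps) * w_i
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def silu(v: float) -> float:
    # SiLU(v) = v * sigmoid(v)
    return v / (1.0 + math.exp(-v))


def silu_and_mul_ref(x: List[float]) -> List[float]:
    # Input width 2*d, output width d: silu(x[:d]) * x[d:]
    d = len(x) // 2
    return [silu(a) * b for a, b in zip(x[:d], x[d:])]
```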