Implementation:Sgl project Sglang Common Extension Registration

From Leeroopedia


Knowledge Sources
Domains: Kernel Infrastructure, PyTorch Extension
Last Updated: 2026-02-10 00:00 GMT

Overview

Central PyTorch C++ extension registration file that exposes all CUDA kernel operations as torch library functions for the NVIDIA GPU backend.

Description

common_extension.cc is the primary extension registration file for the sgl-kernel NVIDIA CUDA backend. It uses the TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) macro to register all kernel operations with PyTorch's operator dispatch system.

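The file declares operator schemas with m.def() and binds CUDA implementations with m.impl() under the torch::kCUDA dispatch key. A few allreduce helpers (get_graph_buffer_ipc_meta, register_graph_buffers) instead pass a function pointer directly to m.def(), which infers the schema from the C++ signature and registers the function without a specific dispatch key. Operations are organized into the following categories: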

  • Allreduce/Communication: get_graph_buffer_ipc_meta, register_graph_buffers, init_custom_ar, all_reduce, mscclpp_init_context, mscclpp_allreduce
  • Attention: merge_state, merge_state_v2, cutlass_mla_decode, cutlass_mla_get_workspace_size
  • Elementwise: rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, silu_and_mul, gelu_tanh_and_mul, gelu_and_mul
  • Rotary Embeddings: apply_rope_pos_ids_cos_sin_cache, rotary_embedding
  • Quantization: downcast_fp8, sgl_per_token_group_quant_8bit, sgl_per_tensor_quant_fp8, scaled_fp4_quant
  • GEMM: awq_dequantize, int8_scaled_mm, fp8_scaled_mm, fp8_blockwise_scaled_mm, cutlass_scaled_fp4_mm
  • MoE: fast_topk, fast_topk_transform_fused
  • Sampling: various sampling operations
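The registration pattern described above can be pictured with a deliberately simplified model. The sketch below is not PyTorch's dispatcher: `ToyLibrary`, `define`, `impl`, and `call` are hypothetical names, and dispatch keys are plain strings such as "CUDA". It only illustrates the idea that a schema is registered once per operator, while kernels are attached per dispatch key and looked up at call time.

```python
class ToyLibrary:
    """Hypothetical stand-in for one TORCH_LIBRARY_FRAGMENT namespace."""

    def __init__(self, namespace):
        self.namespace = namespace
        self.schemas = {}  # op name -> schema string (the m.def role)
        self.kernels = {}  # (op name, dispatch key) -> callable (the m.impl role)

    def define(self, schema):
        # The operator name is everything before the first "(".
        name = schema.split("(", 1)[0].strip()
        if name in self.schemas:
            raise RuntimeError(f"{self.namespace}::{name} is already defined")
        self.schemas[name] = schema

    def impl(self, name, dispatch_key, fn):
        # A kernel can only be bound to an operator whose schema exists.
        if name not in self.schemas:
            raise RuntimeError(f"no schema for {self.namespace}::{name}")
        self.kernels[(name, dispatch_key)] = fn

    def call(self, name, dispatch_key, *args):
        # Dispatch: pick the kernel registered for this (op, key) pair.
        kernel = self.kernels.get((name, dispatch_key))
        if kernel is None:
            raise NotImplementedError(
                f"{self.namespace}::{name} has no kernel for {dispatch_key}")
        return kernel(*args)


lib = ToyLibrary("sgl_kernel")
lib.define("silu_and_mul(Tensor! out, Tensor input) -> ()")
lib.impl("silu_and_mul", "CUDA", lambda out, inp: "ran CUDA kernel")

print(lib.call("silu_and_mul", "CUDA", None, None))  # ran CUDA kernel
try:
    lib.call("silu_and_mul", "CPU", None, None)
except NotImplementedError as err:
    print(err)  # sgl_kernel::silu_and_mul has no kernel for CPU
```

Real PyTorch behaves analogously for a CUDA-only operator: calling it with CPU tensors fails because no kernel is registered for the CPU dispatch key.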

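The m.def/m.impl pattern supports torch.compile compatibility by separating the operator schema definition from its dispatch-key-specific implementation: the schema records the argument types and which tensors are written in place (Tensor!), so tracing tools can reason about an operator without running its CUDA kernel.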
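To show what the schema alone reveals, the hypothetical helper below reads a schema string and reports which arguments are marked as mutated. It is a simplified parser written for this page, not PyTorch's schema parser: it assumes flat, comma-separated arguments with no default values or nested types.

```python
def mutated_args(schema):
    """Return argument names marked as mutated (Tensor! or Tensor(a!)).

    Simplified parser: assumes 'name(type arg, ...) -> ret' with flat,
    space-separated types and no default values.
    """
    params = schema.split("(", 1)[1].rsplit(") ->", 1)[0]
    mutated = []
    for param in params.split(","):
        param = param.strip()
        if not param:
            continue
        arg_type, arg_name = param.rsplit(" ", 1)
        # Both spellings mark a tensor argument that the kernel writes in place.
        is_inplace = arg_type == "Tensor!" or (
            arg_type.startswith("Tensor(") and arg_type.endswith("!)"))
        if is_inplace:
            mutated.append(arg_name)
    return mutated


print(mutated_args("rmsnorm(Tensor! output, Tensor input, Tensor weight, "
                   "float eps, bool enable_pdl) -> ()"))  # ['output']
print(mutated_args("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, "
                   "Tensor! v_merged, Tensor! s_merged) -> ()"))  # ['v_merged', 's_merged']
```

For rmsnorm, this is enough to know that `output` is written while `input` and `weight` are only read, which is the kind of information a compiler pass can use without executing the kernel.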

Usage

This file is compiled as part of the sgl-kernel extension build. It is loaded automatically when the sgl_kernel Python package is imported. Developers adding new CUDA kernels must register them here to make them accessible from Python via torch.ops.sgl_kernel.

Code Reference

Source Location

Signature

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // allreduce
  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  m.def("register_graph_buffers", &register_graph_buffers);
  m.def("init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
        "int rank, bool full_nvlink) -> int");
  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  m.def("all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
        "int reg_buffer_sz_bytes) -> ()");
  m.impl("all_reduce", torch::kCUDA, &all_reduce);

  // attention
  m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, "
        "Tensor! v_merged, Tensor! s_merged) -> ()");
  m.impl("merge_state", torch::kCUDA, &merge_state);

  // elementwise
  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, "
        "float eps, bool enable_pdl) -> ()");
  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
  m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
  // ... (614 lines total)
}
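Each schema string has to agree with the C++ prototype the bound function is declared with. The hypothetical helper below encodes one accepted C++ spelling per schema type (`int` as `int64_t`, `float` as `double`, `Tensor` as `const at::Tensor&`, `Tensor!` as `at::Tensor&`, `int[]` as `const std::vector<int64_t>&`). PyTorch also accepts other spellings (for example, taking a tensor by value), so this table is an assumption for illustration, not the convention used in sgl_kernel_ops.h.

```python
# One accepted C++ parameter type per schema type (assumed choice).
PARAM_TYPES = {
    "Tensor": "const at::Tensor&",
    "Tensor!": "at::Tensor&",
    "int": "int64_t",
    "float": "double",
    "bool": "bool",
    "int[]": "const std::vector<int64_t>&",
}
RETURN_TYPES = {"()": "void", "int": "int64_t", "Tensor": "at::Tensor"}


def cpp_prototype(schema):
    """Build a C++ declaration matching a flat operator schema (simplified)."""
    name, rest = schema.split("(", 1)
    params, ret = rest.rsplit(") ->", 1)
    cpp_params = []
    for param in params.split(","):
        param = param.strip()
        if not param:
            continue
        schema_type, arg_name = param.rsplit(" ", 1)
        cpp_params.append(f"{PARAM_TYPES[schema_type]} {arg_name}")
    return f"{RETURN_TYPES[ret.strip()]} {name.strip()}({', '.join(cpp_params)});"


print(cpp_prototype("all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
                    "int reg_buffer_sz_bytes) -> ()"))
# void all_reduce(int64_t fa, const at::Tensor& inp, at::Tensor& out, int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
```

A mismatch such as declaring a schema `float` argument as C++ `float` is a common cause of registration errors, since the schema's `float` corresponds to a 64-bit `double`.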

Import

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>
#include "sgl_kernel_ops.h"

I/O Contract

Inputs

Name | Type | Required | Description
sgl_kernel_ops.h | Header | Yes | Declares all kernel function prototypes referenced in registration
torch/library.h | Header | Yes | Provides the TORCH_LIBRARY_FRAGMENT macro for operator registration

Outputs

Name | Type | Description
sgl_kernel library fragment | PyTorch dispatch table | All registered operations become callable via torch.ops.sgl_kernel.*

Usage Examples

Registering a New CUDA Kernel

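// sgl_kernel_ops.h: declare the kernel (schema float maps to C++ double)
void my_new_op(const at::Tensor& input, at::Tensor& output, double scale);

// common_extension.cc, inside TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ... }:
// Define the operator schema (Tensor! marks output as written in place)
m.def("my_new_op(Tensor input, Tensor! output, float scale) -> ()");
// Bind the CUDA implementation
m.impl("my_new_op", torch::kCUDA, &my_new_op);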

Calling from Python

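import torch
import sgl_kernel  # importing the package loads the compiled extension

# Illustrative inputs: 4 tokens, hidden size 128, fp16 on the GPU
x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
weight = torch.ones(128, device="cuda", dtype=torch.float16)
output = torch.empty_like(x)
gate_up = torch.randn(4, 256, device="cuda", dtype=torch.float16)  # last dim is 2 * 128
out = torch.empty(4, 128, device="cuda", dtype=torch.float16)
# Operations are accessible via torch.ops; results are written in place
torch.ops.sgl_kernel.rmsnorm(output, x, weight, eps=1e-5, enable_pdl=False)
torch.ops.sgl_kernel.silu_and_mul(out, gate_up)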
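When testing a newly registered kernel, it helps to compare its output against a plain reference. The pure-Python functions below are hypothetical references written for this page, assuming the standard definitions: RMSNorm computes x / sqrt(mean(x^2) + eps) * weight per row, and silu_and_mul applies SiLU to the first half of the last dimension and multiplies by the second half. Confirm both conventions against the kernel sources before relying on them.

```python
import math


def rmsnorm_ref(row, weight, eps=1e-5):
    """RMSNorm of one row: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(row, weight)]


def silu_and_mul_ref(row):
    """silu(x[:d]) * x[d:] for one row of length 2 * d (assumed convention)."""
    d = len(row) // 2
    gate, up = row[:d], row[d:]
    # sigmoid(g) written as 0.5 * (1 + tanh(g / 2)), which avoids overflow in exp
    return [g * 0.5 * (1.0 + math.tanh(0.5 * g)) * u for g, u in zip(gate, up)]


print(rmsnorm_ref([3.0, 4.0], [1.0, 1.0], eps=0.0))
print(silu_and_mul_ref([0.0, 1.0, 5.0, 2.0]))
```

To check a GPU result, one row can be copied back with `output[0].cpu().tolist()` and compared element-wise using a tolerance suited to fp16.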
