Implementation:Sgl project Sglang SM100 MLA Tile Scheduler
| Knowledge Sources | |
|---|---|
| Domains | GPU Kernel, Attention |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Tile scheduling strategies for distributing MLA (Multi-head Latent Attention) work across thread blocks on NVIDIA SM100 GPUs.
Description
sm100_mla_tile_scheduler.hpp provides two tile scheduling strategies within the cutlass::fmha::kernel namespace for distributing MLA attention computation across GPU streaming multiprocessors:
Sm100MlaIndividualTileScheduler assigns one tile per thread block in a simple 1:1 mapping. The grid dimensions are set to (cluster_shape_x, batch_size, split_kv). Each thread block processes exactly one tile and then exits (the operator++ sets valid_ = false). Block coordinates are derived directly from blockIdx.
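The 1:1 mapping described above can be modeled on the host without CUDA or CUTLASS. The following is a minimal sketch, not the actual implementation: `Dim3`, `individual_grid`, `individual_block_coord`, and `individual_num_blocks` are hypothetical names, and the axis assignment (x to m_block, y to batch, z to KV split) is an assumption inferred from the grid layout (cluster_shape_x, batch_size, split_kv).

```cpp
#include <cassert>
#include <tuple>

// Hypothetical host-side stand-in for CUDA's dim3 / blockIdx.
struct Dim3 { int x, y, z; };

// Grid used by the individual scheduler: (cluster_shape_x, batch_size, split_kv).
inline Dim3 individual_grid(int cluster_shape_x, int batch_size, int split_kv) {
  assert(cluster_shape_x > 0 && batch_size > 0 && split_kv > 0);
  return Dim3{cluster_shape_x, batch_size, split_kv};
}

// Coordinate (m_block, 0, batch_idx, kv_split_idx) read straight off the
// block index. The axis order is an assumption of this sketch.
inline std::tuple<int, int, int, int> individual_block_coord(Dim3 block_idx) {
  return std::make_tuple(block_idx.x, 0, block_idx.y, block_idx.z);
}

// One tile per block, so the number of launched blocks equals the tile count.
inline int individual_num_blocks(Dim3 grid) {
  return grid.x * grid.y * grid.z;
}
```

Because every block owns exactly one tile, the launch size grows linearly with batch_size and split_kv, regardless of how many SMs the device has.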
Sm100MlaPersistentTileScheduler launches a persistent kernel in which each thread block loops over multiple tiles. The grid size is capped at the number of available SMs (hw_info.sm_count). Each thread block advances to its next tile with block_idx += gridDim.x, which is a static grid-stride assignment; the interface shown here carries no atomic counter for dynamic work distribution. The linear block index is decomposed with FastDivmod operations into (m_block, batch, kv_split) coordinates. The scheduler can query the device SM count via KernelHardwareInfo::query_device_multiprocessor_count to size the grid.
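The stride-based advance (block_idx += gridDim.x) and the index decomposition described above can be modeled on the host. This is a sketch under stated assumptions rather than the kernel's code: `PersistentModel` and `TileCoord` are hypothetical names, plain `/` and `%` stand in for FastDivmod, and the decomposition order (m_block first, then batch, then kv_split) is assumed from the order of the divmod members in Params.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side model; plain / and % stand in for cutlass::FastDivmod.
struct TileCoord { int m_block, batch, kv_split; };

struct PersistentModel {
  int num_m_blocks, batch_size, split_kv;

  int num_tiles() const { return num_m_blocks * batch_size * split_kv; }

  // Peel off m_block first, then batch, then kv_split (assumed order).
  TileCoord decode(int linear_idx) const {
    assert(linear_idx >= 0 && linear_idx < num_tiles());
    TileCoord c;
    c.m_block = linear_idx % num_m_blocks;
    linear_idx /= num_m_blocks;
    c.batch = linear_idx % batch_size;
    linear_idx /= batch_size;
    c.kv_split = linear_idx % split_kv;
    return c;
  }

  // Tiles visited by block `block` in a grid of `grid_x` blocks: start at the
  // block's own index and stride by the grid size until the tiles run out.
  std::vector<TileCoord> tiles_for_block(int block, int grid_x) const {
    assert(grid_x > 0 && block >= 0 && block < grid_x);
    std::vector<TileCoord> out;
    for (int idx = block; idx < num_tiles(); idx += grid_x) {
      out.push_back(decode(idx));
    }
    return out;
  }
};
```

With 12 tiles and a grid of 5 blocks, for example, blocks 0 and 1 each visit three tiles and blocks 2 through 4 each visit two, so every tile is covered exactly once and the assignment is fixed at launch.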
Usage
Use the Individual scheduler for simple workloads with uniform sequence lengths. Use the Persistent scheduler for better load balancing with variable-length sequences in batched decode, where some batches may require significantly more tiles than others.
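One way to compare the two options for a given problem is to estimate how many thread blocks each would launch. The helper below is a hypothetical sketch (`LaunchEstimate` and `estimate_launch` are not part of the source file); it assumes the persistent grid is the smaller of the tile count and the SM count, and that the tile count is cluster_shape_x * batch_size * split_kv.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical summary of launch sizes for the same problem.
struct LaunchEstimate {
  int individual_blocks;               // one block per tile
  int persistent_blocks;               // capped at the SM count
  int max_tiles_per_persistent_block;  // ceil(tiles / persistent_blocks)
};

inline LaunchEstimate estimate_launch(int cluster_shape_x, int batch_size,
                                      int split_kv, int sm_count) {
  assert(cluster_shape_x > 0 && batch_size > 0 && split_kv > 0 && sm_count > 0);
  int tiles = cluster_shape_x * batch_size * split_kv;
  int persistent = std::min(tiles, sm_count);
  int per_block = (tiles + persistent - 1) / persistent;  // ceiling division
  return LaunchEstimate{tiles, persistent, per_block};
}
```

When the tile count is at or below the SM count, both schedulers launch the same number of blocks and each block handles a single tile, so the difference only shows up for larger problems.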
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp
- Lines: 1-160
Signature
namespace cutlass::fmha::kernel {

struct Sm100MlaIndividualTileScheduler {
  struct Params { dim3 grid; };

  CUTLASS_DEVICE Sm100MlaIndividualTileScheduler(Params const&);

  template<class ProblemShape, class ClusterShape>
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, int const& split_kv);

  static dim3 get_grid_shape(Params const& params);
  CUTLASS_DEVICE bool is_valid();
  CUTLASS_DEVICE auto get_block_coord();
  CUTLASS_DEVICE Sm100MlaIndividualTileScheduler& operator++();
};

struct Sm100MlaPersistentTileScheduler {
  struct Params {
    int num_blocks;
    FastDivmod divmod_m_block;
    FastDivmod divmod_b;
    FastDivmod divmod_split_kv;
    KernelHardwareInfo hw_info;
  };

  CUTLASS_DEVICE Sm100MlaPersistentTileScheduler(Params const& params);

  template<class ProblemShape, class ClusterShape>
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, int const& split_kv);

  static dim3 get_grid_shape(Params const& params);
  CUTLASS_DEVICE bool is_valid();
  CUTLASS_DEVICE auto get_block_coord();
  CUTLASS_DEVICE Sm100MlaPersistentTileScheduler& operator++();
};

}  // namespace cutlass::fmha::kernel
Import
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| problem_shape | ProblemShape | Yes | Tuple describing the problem dimensions including batch size |
| hw_info | KernelHardwareInfo | Yes | Hardware info including SM count and device ID |
| cluster_shape | ClusterShape | Yes | Thread block cluster dimensions |
| split_kv | int | Yes | Number of splits along the KV sequence dimension (split-KV parallelism) |
Outputs
| Name | Type | Description |
|---|---|---|
| grid | dim3 | CUDA grid dimensions for kernel launch |
| block_coord | tuple | (m_block, 0, batch_idx, kv_split_idx) coordinate for current thread block's work tile |
Usage Examples
// Assumes problem_shape, hw_info, cluster_shape, and split_kv are already set up by the caller.
using namespace cutlass::fmha::kernel;

// Individual scheduler: one tile per block
using IndividualScheduler = Sm100MlaIndividualTileScheduler;
auto individual_params = IndividualScheduler::to_underlying_arguments(
    problem_shape, hw_info, cluster_shape, split_kv);
dim3 individual_grid = IndividualScheduler::get_grid_shape(individual_params);

// Persistent scheduler: blocks loop over tiles
using PersistentScheduler = Sm100MlaPersistentTileScheduler;
auto persistent_params = PersistentScheduler::to_underlying_arguments(
    problem_shape, hw_info, cluster_shape, split_kv);
dim3 persistent_grid = PersistentScheduler::get_grid_shape(persistent_params);
// Grid is capped at SM count for persistent execution