Implementation:Unslothai Unsloth GEMM Autotuning Configs
| Knowledge Sources | |
|---|---|
| Domains | MoE, Kernel_Optimization, Triton_Kernels |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
Functions that generate and prune the Triton autotuning search spaces for the grouped GEMM forward and backward (dX, dW) kernels used by Unsloth's MoE layers.
Description
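The autotuning module generates one triton.Config for each combination of the following parameters:

- block sizes: BLOCK_M is 64 or 128; BLOCK_N and BLOCK_K are 64, 128, or 256
- warp counts: 4 or 8
- pipeline stages: 3, 4, or 5

Pruning functions then remove configurations that cannot run, or are wasteful, for the actual launch arguments. They drop configurations in three cases:

- the estimated shared memory exceeds the device capacity
- TMA is combined with a permutation setting it cannot support
- the block sizes are excessive relative to the number of tokens per expert

The size of the surviving set sets the trade-off between autotuning quality and compilation time. Without a performance model, Triton compiles and benchmarks every configuration that survives pruning.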
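The shared memory check can be illustrated with the usual approximation for a software-pipelined tiled GEMM. Every pipeline stage buffers one BLOCK_M × BLOCK_K tile of X and one BLOCK_K × BLOCK_N tile of W. This is an assumed formula, not necessarily what estimate_smem_reqs computes. The sketch takes the element size in bytes rather than a torch.dtype so it runs without PyTorch. The helper names and the SMEM limit are hypothetical.

```python
BF16_BYTES = 2  # element size of bfloat16
# Assumed per-block limit: 227 KiB, the opt-in maximum on H100. Query the real device instead.
SMEM_LIMIT_BYTES = 232_448


def estimate_smem_bytes(num_stages: int, block_m: int, block_n: int,
                        block_k: int, elem_bytes: int) -> int:
    """Assumed estimate: each stage buffers an X tile (block_m x block_k)
    and a W tile (block_k x block_n). Epilogue buffers and barriers are ignored."""
    per_stage_elems = block_m * block_k + block_k * block_n
    return num_stages * per_stage_elems * elem_bytes


def fits_in_smem(num_stages: int, block_m: int, block_n: int, block_k: int,
                 elem_bytes: int = BF16_BYTES, limit: int = SMEM_LIMIT_BYTES) -> bool:
    """Hypothetical pruning predicate: keep a config only if its estimate fits."""
    return estimate_smem_bytes(num_stages, block_m, block_n, block_k, elem_bytes) <= limit


print(estimate_smem_bytes(3, 128, 128, 64, BF16_BYTES))  # 98304 bytes (96 KiB)
print(fits_in_smem(3, 128, 128, 64))    # True
print(fits_in_smem(5, 128, 256, 256))   # False: 983040 bytes (960 KiB)
```

Under this estimate, the five-stage 128×256×256 corner of the default grid needs about 960 KiB. That is several times the assumed limit, so configurations like it are pruned before any compilation is attempted.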
Usage
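Import the config generators when wrapping a Triton kernel with triton.autotune. Pass the generated list as `configs`, and pass the matching pruning function as the `early_config_prune` entry of the `prune_configs_by` dict. Triton calls that function with the candidate configs and the kernel's named arguments before benchmarking, and autotunes only the configs it returns.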
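The callback contract can be shown with a self-contained sketch. The pruning callback receives the candidate configs and a dict of the kernel's named arguments, and returns the subset to keep. A small dataclass stands in for triton.Config so the example runs without a GPU. The argument names NUM_TOKENS and TOPK, the config keys USE_TMA_LOAD_X, USE_TMA_STORE, and BLOCK_SIZE_M, and the exact rules are hypothetical illustrations of the TMA/permute and tokens-per-expert checks. They are not Unsloth's code.

```python
from dataclasses import dataclass


@dataclass
class FakeConfig:
    """Stand-in for triton.Config holding only what this sketch reads."""
    kwargs: dict
    num_warps: int = 4
    num_stages: int = 3


def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def prune_sketch(configs, named_args, **kwargs):
    """Hypothetical early_config_prune callback (not Unsloth's rules).

    Assumed named_args keys: NUM_TOKENS, TOPK, NUM_EXPERTS, PERMUTE_X, PERMUTE_Y.
    Assumed config keys: BLOCK_SIZE_M, USE_TMA_LOAD_X, USE_TMA_STORE.
    """
    tokens_per_expert = (
        named_args["NUM_TOKENS"] * named_args["TOPK"] // named_args["NUM_EXPERTS"]
    )
    # Cap BLOCK_SIZE_M near the average tokens per expert; floor at the smallest default (64).
    max_block_m = max(64, next_power_of_2(tokens_per_expert))
    kept = []
    for cfg in configs:
        kw = cfg.kwargs
        # Assumption: permuted (gathered/scattered) rows are not one rectangular
        # tile, so TMA is not used for an operand that is being permuted.
        if named_args["PERMUTE_X"] and kw.get("USE_TMA_LOAD_X", False):
            continue
        if named_args["PERMUTE_Y"] and kw.get("USE_TMA_STORE", False):
            continue
        if kw["BLOCK_SIZE_M"] > max_block_m:
            continue
        kept.append(cfg)
    # Autotuning needs at least one config; fall back to the input list (a simplification).
    return kept or configs


configs = [
    FakeConfig({"BLOCK_SIZE_M": m, "USE_TMA_LOAD_X": tma})
    for m in (64, 128)
    for tma in (False, True)
]
args = {"NUM_TOKENS": 256, "TOPK": 2, "NUM_EXPERTS": 16,
        "PERMUTE_X": True, "PERMUTE_Y": False}
survivors = prune_sketch(configs, args)
print([(c.kwargs["BLOCK_SIZE_M"], c.kwargs["USE_TMA_LOAD_X"]) for c in survivors])
# [(64, False)]
```

Here each expert sees about 32 tokens, so 128-row tiles are dropped. Because X is permuted, the TMA-load variants are dropped as well.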
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
- Lines: 1-441
Signature
```python
def get_forward_configs(
    BLOCK_M=None, BLOCK_N=None, BLOCK_K=None,
    TMA_LOAD_X=None, TMA_LOAD_W=None, TMA_STORE=False,
    num_warps=None, num_stages=None, num_ctas=1,
) -> List[triton.Config]:
    """Generate forward kernel autotuning configurations."""

def get_dX_kernel_configs(...) -> List[triton.Config]:
    """Generate dX backward kernel autotuning configurations."""

def get_dW_kernel_configs(...) -> List[triton.Config]:
    """Generate dW backward kernel autotuning configurations."""

def prune_kernel_configs_fwd(
    configs: list[triton.Config], args, **kwargs
) -> list[triton.Config]:
    """Prune forward configs by SMEM, TMA, and block size constraints."""

def prune_dX_configs(
    configs: List[triton.Config], args, **kwargs
) -> list[triton.Config]:
    """Prune dX backward configs."""

def prune_kernel_configs_backward_dW(
    configs: list[triton.Config], args, **kwargs
) -> list[triton.Config]:
    """Prune dW backward configs."""

def estimate_smem_reqs(
    num_stages: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int,
    BLOCK_SIZE_K: int, dtype: torch.dtype,
) -> int:
    """Estimate shared memory requirements for a kernel config."""
```
Import
```python
from unsloth.kernels.moe.grouped_gemm.kernels.autotuning import (
    get_forward_configs,
    get_dX_kernel_configs,
    get_dW_kernel_configs,
    prune_kernel_configs_fwd,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
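| BLOCK_M, BLOCK_N, BLOCK_K | list[int] or None | No | Candidate tile sizes (defaults: [64, 128], [64, 128, 256], and [64, 128, 256] respectively) |
| num_warps | list[int] or None | No | Candidate warp counts (default: [4, 8]) |
| num_stages | list[int] or None | No | Candidate pipeline stage counts (default: [3, 4, 5]) |
| TMA_LOAD_X, TMA_LOAD_W | bool or None | No | TMA load flags for X and W (default: None, auto-detect) |
| TMA_STORE | bool | No | TMA store flag for the output (default: False) |
| num_ctas | int | No | CTAs per cluster (default: 1) |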
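To show how quickly the defaults above multiply, the sketch below enumerates the same grid as plain dictionaries. It is a hypothetical stand-in for get_forward_configs, not Unsloth's code. It returns dicts instead of triton.Config objects and leaves out the TMA flags and num_ctas. As in the table, passing None for a parameter keeps its default list.

```python
from itertools import product

# Documented defaults for get_forward_configs (see the Inputs table).
DEFAULTS = {
    "BLOCK_M": [64, 128],
    "BLOCK_N": [64, 128, 256],
    "BLOCK_K": [64, 128, 256],
    "num_warps": [4, 8],
    "num_stages": [3, 4, 5],
}


def enumerate_search_space(**overrides):
    """Hypothetical sketch: one dict per combination of the tunable values.

    Any key of DEFAULTS may be overridden with a narrower list; None keeps the default.
    TMA flags and num_ctas are deliberately left out.
    """
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise ValueError(f"unknown parameters: {sorted(unknown)}")
    names = list(DEFAULTS)
    value_lists = []
    for name in names:
        chosen = overrides.get(name)
        value_lists.append(DEFAULTS[name] if chosen is None else chosen)
    # Cartesian product: every combination of one value from each list.
    return [dict(zip(names, combo)) for combo in product(*value_lists)]


print(len(enumerate_search_space()))  # 2 * 3 * 3 * 2 * 3 = 108
print(len(enumerate_search_space(BLOCK_M=[128], num_warps=[8])))  # 1 * 3 * 3 * 1 * 3 = 27
```

With the defaults, the grid has 108 entries before TMA flags are considered. Narrowing any single list shrinks the number of configs to compile and benchmark by the same factor.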
Outputs
| Name | Type | Description |
|---|---|---|
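| configs | List[triton.Config] | Candidate triton.Config objects for autotuning; launch-dependent filtering (SMEM, TMA/permute, tokens per expert) is left to the matching prune function |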
Usage Examples
Using with Triton Autotune
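```python
import triton

from unsloth.kernels.moe.grouped_gemm.kernels.autotuning import (
    get_forward_configs,
    prune_kernel_configs_fwd,
)

# my_kernel is a placeholder for your @triton.jit grouped GEMM forward kernel.
# Every name in `key` must be an argument of that kernel; Triton re-runs
# autotuning whenever the values of these arguments change.
autotuned_kernel = triton.autotune(
    configs=get_forward_configs(),
    prune_configs_by={"early_config_prune": prune_kernel_configs_fwd},
    key=["NUM_EXPERTS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
)(my_kernel)
```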