
Implementation:Unslothai Unsloth GEMM Autotuning Configs

From Leeroopedia


Knowledge Sources
Domains: MoE, Kernel_Optimization, Triton_Kernels
Last Updated: 2026-02-07 08:40 GMT

Overview

Concrete tool for generating and pruning Triton autotuning configuration search spaces for grouped GEMM forward and backward kernels.

Description

The autotuning module builds lists of triton.Config objects from every combination of block sizes (BLOCK_M: 64 or 128; BLOCK_N and BLOCK_K: 64, 128, or 256), warp counts (4 or 8), and pipeline stages (3, 4, or 5). Pruning functions then drop configurations that would exceed shared memory capacity, that combine TMA and permute options incompatibly, or whose block sizes are too large for the number of tokens per expert. The size of the pruned search space determines both how good a configuration autotuning can find and how long compilation takes.
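To make the size of the unpruned search space concrete, the following is a minimal sketch of the combinatorial step, assuming the default candidate values listed on this page. `SketchConfig` and `build_search_space` are hypothetical names used only for illustration; the real module returns triton.Config objects and may also vary TMA flags, which would multiply the count further.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for triton.Config so the sketch runs without Triton."""
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    num_warps: int
    num_stages: int


def build_search_space(
    block_m=(64, 128),
    block_n=(64, 128, 256),
    block_k=(64, 128, 256),
    num_warps=(4, 8),
    num_stages=(3, 4, 5),
):
    """Return one config per combination of the candidate values."""
    return [
        SketchConfig(m, n, k, w, s)
        for m, n, k, w, s in product(block_m, block_n, block_k, num_warps, num_stages)
    ]


search_space = build_search_space()
print(len(search_space))  # 2 * 3 * 3 * 2 * 3 = 108 combinations
```

Every surviving configuration is compiled and benchmarked during autotuning, which is why pruning before that step matters.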

Usage

Import the config generation functions when defining Triton autotuned kernel wrappers. Pass the generated list to triton.autotune through its configs argument, and pass a pruning function as the early_config_prune entry of the prune_configs_by dictionary.

Code Reference

Source Location

Signature

def get_forward_configs(
    BLOCK_M=None, BLOCK_N=None, BLOCK_K=None,
    TMA_LOAD_X=None, TMA_LOAD_W=None, TMA_STORE=False,
    num_warps=None, num_stages=None, num_ctas=1,
) -> List[triton.Config]:
    """Generate forward kernel autotuning configurations."""

def get_dX_kernel_configs(...) -> List[triton.Config]:
    """Generate dX backward kernel autotuning configurations."""

def get_dW_kernel_configs(...) -> List[triton.Config]:
    """Generate dW backward kernel autotuning configurations."""

def prune_kernel_configs_fwd(
    configs: list[triton.Config], args, **kwargs
) -> list[triton.Config]:
    """Prune forward configs by SMEM, TMA, and block size constraints."""

def prune_dX_configs(
    configs: List[triton.Config], args, **kwargs
) -> list[triton.Config]:
    """Prune dX backward configs."""

def prune_kernel_configs_backward_dW(
    configs: list[triton.Config], args, **kwargs
) -> list[triton.Config]:
    """Prune dW backward configs."""

def estimate_smem_reqs(
    num_stages: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int,
    BLOCK_SIZE_K: int, dtype: torch.dtype,
) -> int:
    """Estimate shared memory requirements for a kernel config."""
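The body of estimate_smem_reqs is not shown on this page. As a rough illustration of how a shared memory check of this kind can work, the sketch below assumes that each pipeline stage buffers one BLOCK_M x BLOCK_K input tile and one BLOCK_K x BLOCK_N weight tile. The function name `estimate_smem_bytes`, the formula, the integer `element_size` parameter (the real function takes a torch.dtype), and the budget value are all assumptions, not the library's implementation.

```python
def estimate_smem_bytes(num_stages, block_m, block_n, block_k, element_size):
    """Rough estimate under the stated assumption: two input tiles per stage."""
    x_tile = block_m * block_k  # elements in one BLOCK_M x BLOCK_K input tile
    w_tile = block_k * block_n  # elements in one BLOCK_K x BLOCK_N weight tile
    return num_stages * (x_tile + w_tile) * element_size


# Assumed per-block shared memory budget in bytes; query the device in practice.
SMEM_BUDGET = 227 * 1024

# (num_stages, BLOCK_M, BLOCK_N, BLOCK_K) candidates with 2-byte elements.
candidates = [
    (3, 64, 64, 64),
    (3, 128, 128, 64),
    (5, 128, 256, 256),
]
kept = [
    c for c in candidates
    if estimate_smem_bytes(*c, element_size=2) <= SMEM_BUDGET
]
print(kept)  # the largest candidate is dropped under this budget
```

Under these assumptions the smallest candidate needs 3 * (4096 + 4096) * 2 = 49152 bytes, while the largest needs 5 * (32768 + 65536) * 2 = 983040 bytes and is pruned.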

Import

from unsloth.kernels.moe.grouped_gemm.kernels.autotuning import (
    get_forward_configs,
    get_dX_kernel_configs,
    get_dW_kernel_configs,
    prune_kernel_configs_fwd,
)

I/O Contract

Inputs

Name | Type | Required | Description
BLOCK_M, BLOCK_N, BLOCK_K | list[int] or None | No | Candidate block sizes (defaults: BLOCK_M [64, 128]; BLOCK_N [64, 128, 256]; BLOCK_K [64, 128, 256])
num_warps | list[int] or None | No | Warp count options (default: [4, 8])
num_stages | list[int] or None | No | Pipeline stage options (default: [3, 4, 5])
TMA_LOAD_X, TMA_LOAD_W | bool or None | No | TMA load flags (default: None, auto-detect)

Outputs

Name | Type | Description
configs | List[triton.Config] | List of valid triton.Config objects for autotuning

Usage Examples

Using with Triton Autotune

import triton
from unsloth.kernels.moe.grouped_gemm.kernels.autotuning import (
    get_forward_configs,
    prune_kernel_configs_fwd,
)

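# Create an autotuned kernel wrapper.
# my_kernel is a placeholder for a @triton.jit grouped GEMM kernel defined elsewhere.
# Each key entry must be an argument name of that kernel; autotuning reruns
# whenever the value of any key argument changes.
autotuned_kernel = triton.autotune(
    configs=get_forward_configs(),
    prune_configs_by={"early_config_prune": prune_kernel_configs_fwd},
    key=["NUM_EXPERTS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
)(my_kernel)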
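The pruning callbacks share the `(configs, args, **kwargs)` shape shown in the signatures above. As a hedged sketch of that contract, the example below drops configurations whose block size along M exceeds the average number of tokens per expert. The function name `prune_oversized_blocks`, the `NUM_TOKENS` argument, and the `BLOCK_SIZE_M` kwarg key are assumptions for illustration, and SimpleNamespace objects stand in for triton.Config so the snippet runs without Triton.

```python
from types import SimpleNamespace


def prune_oversized_blocks(configs, named_args, **kwargs):
    """Drop configs whose BLOCK_SIZE_M exceeds the average tokens per expert."""
    num_tokens = named_args["NUM_TOKENS"]    # hypothetical kernel argument
    num_experts = named_args["NUM_EXPERTS"]
    tokens_per_expert = max(1, num_tokens // num_experts)

    kept = [c for c in configs if c.kwargs["BLOCK_SIZE_M"] <= tokens_per_expert]

    # Never hand the autotuner an empty list: fall back to the smallest block.
    if not kept:
        smallest = min(c.kwargs["BLOCK_SIZE_M"] for c in configs)
        kept = [c for c in configs if c.kwargs["BLOCK_SIZE_M"] == smallest]
    return kept


# Stand-ins exposing the same attributes a triton.Config carries.
configs = [
    SimpleNamespace(kwargs={"BLOCK_SIZE_M": m}, num_warps=4, num_stages=3)
    for m in (64, 128)
]
kept = prune_oversized_blocks(configs, {"NUM_TOKENS": 1024, "NUM_EXPERTS": 16})
print([c.kwargs["BLOCK_SIZE_M"] for c in kept])
```

With 1024 tokens spread over 16 experts, the average is 64 tokens per expert, so only the 64-row block survives. The fallback branch is a design choice in this sketch to avoid returning an empty candidate list.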
