Implementation:Unslothai Unsloth GEMM Kernel Tuning
| Knowledge Sources | |
|---|---|
| Domains | MoE, Kernel_Optimization |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
Configuration dataclasses, pruning rules, and result tracking utilities for manual Triton kernel benchmarking and tuning of grouped GEMM operations.
Description
The tuning module defines the kernel configuration types used throughout the MoE grouped GEMM subsystem. KernelConfig is the base dataclass: it holds the block sizes, warp count, pipeline stages, and flags shared by all kernels, including use_tma_store. The subclasses KernelConfigForward, KernelConfigBackward_dW, and KernelConfigBackward_dX each add TMA load flags for the operands of their pass. KernelResult records benchmark timings and supports export to a pandas DataFrame, and TritonTuningContext is a context manager that provides safe error handling for kernel compilation failures.
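As a rough illustration of how a configuration grid and a pruning rule can fit together, the sketch below enumerates tile sizes and drops candidates whose estimated shared-memory footprint exceeds a budget. It does not reproduce the module's own rules: `TileConfig`, `estimate_smem_bytes`, `candidate_configs`, the value grid, and the footprint formula are all hypothetical and exist only for this example.

```python
from dataclasses import dataclass
from itertools import product


@dataclass
class TileConfig:
    # Hypothetical stand-in that mirrors a few KernelConfig fields.
    BLOCK_SIZE_M: int = 32
    BLOCK_SIZE_N: int = 32
    BLOCK_SIZE_K: int = 32
    num_warps: int = 4
    num_stages: int = 2


def estimate_smem_bytes(cfg: TileConfig, elem_bytes: int = 2) -> int:
    # Assumption: each pipeline stage buffers one M x K tile and one K x N tile.
    per_stage = (
        cfg.BLOCK_SIZE_M * cfg.BLOCK_SIZE_K + cfg.BLOCK_SIZE_K * cfg.BLOCK_SIZE_N
    ) * elem_bytes
    return per_stage * cfg.num_stages


def candidate_configs(smem_limit_bytes: int) -> list[TileConfig]:
    # Enumerate an illustrative grid, then keep only configs within the budget.
    grid = product([32, 64, 128], [32, 64, 128], [32, 64], [4, 8], [2, 4])
    configs = [TileConfig(m, n, k, w, s) for m, n, k, w, s in grid]
    return [c for c in configs if estimate_smem_bytes(c) <= smem_limit_bytes]
```

Pruning before benchmarking keeps the sweep short, since configurations that cannot fit on the device never reach compilation. A real rule would take its limit from the queried device properties rather than a hand-picked number.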
Usage
Import the kernel config dataclasses when configuring grouped GEMM kernel parameters manually instead of using autotuning.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/kernels/tuning.py
- Lines: 1-277
Signature
@dataclass
class KernelConfig:
    BLOCK_SIZE_M: int = 32
    BLOCK_SIZE_N: int = 32
    BLOCK_SIZE_K: int = 32
    num_warps: int = 4
    num_stages: int = 2
    flatten: bool = True
    permute_x: bool = False
    permute_y: bool = False
    fuse_mul_post: bool = False
    use_tma_store: bool = False

@dataclass
class KernelConfigForward(KernelConfig):
    use_tma_load_w: bool = False
    use_tma_load_x: bool = False

@dataclass
class KernelConfigBackward_dW(KernelConfig):
    use_tma_load_dy: bool = False
    use_tma_load_x: bool = False

@dataclass
class KernelConfigBackward_dX(KernelConfig):
    use_tma_load_dy: bool = False
    use_tma_load_w: bool = False

@dataclass
class KernelResult:
    torch_time: float
    triton_time: float
    speedup: float
    kernel_config: KernelConfig

class TritonTuningContext:
    def __init__(self, kernel_config: KernelConfig): ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback) -> bool: ...

def get_device_properties() -> DeviceProperties:
    """Get GPU hardware properties."""

def get_kernel_configs(...) -> tuple[list, list, list]:
    """Generate all valid forward/dW/dX configurations."""
Import
from unsloth.kernels.moe.grouped_gemm.kernels.tuning import (
    KernelConfigForward,
    KernelConfigBackward_dW,
    KernelConfigBackward_dX,
    KernelResult,
    TritonTuningContext,
)
I/O Contract
Inputs (KernelConfigForward)
| Name | Type | Required | Description |
|---|---|---|---|
| BLOCK_SIZE_M | int | No | M dimension tile size (default: 32) |
| BLOCK_SIZE_N | int | No | N dimension tile size (default: 32) |
| BLOCK_SIZE_K | int | No | K dimension tile size (default: 32) |
| num_warps | int | No | Number of Triton warps (default: 4) |
| num_stages | int | No | Pipeline stages (default: 2) |
| permute_x | bool | No | Fuse input permutation (default: False) |
| permute_y | bool | No | Fuse output permutation (default: False) |
| use_tma_load_w | bool | No | TMA for weight loads (default: False) |
| use_tma_load_x | bool | No | TMA for input loads (default: False) |
Outputs
| Name | Type | Description |
|---|---|---|
| KernelConfig instance | dataclass | Configuration object for kernel dispatch |
| KernelResult | dataclass | Benchmark timing with speedup factor |
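To make the shape of a benchmark record concrete, the following sketch flattens results into one dictionary per configuration and ranks them by speedup. `DemoConfig`, `DemoResult`, `make_result`, and `to_rows` are hypothetical stand-ins, not the module's classes, and computing speedup as torch_time / triton_time is an assumption based on the field names.

```python
from dataclasses import asdict, dataclass


@dataclass
class DemoConfig:
    # Hypothetical stand-in that mirrors a few KernelConfig fields.
    BLOCK_SIZE_M: int = 32
    BLOCK_SIZE_N: int = 32
    BLOCK_SIZE_K: int = 32
    num_warps: int = 4
    num_stages: int = 2


@dataclass
class DemoResult:
    # Same field names as KernelResult in the signature above.
    torch_time: float
    triton_time: float
    speedup: float
    kernel_config: DemoConfig


def make_result(torch_time: float, triton_time: float, cfg: DemoConfig) -> DemoResult:
    # Assumption: speedup is the reference time divided by the Triton time.
    return DemoResult(torch_time, triton_time, torch_time / triton_time, cfg)


def to_rows(results: list[DemoResult]) -> list[dict]:
    # One flat dict per result: config fields first, then the timings.
    rows = []
    for r in results:
        row = asdict(r.kernel_config)
        row.update(
            torch_time=r.torch_time, triton_time=r.triton_time, speedup=r.speedup
        )
        rows.append(row)
    # Fastest configuration first.
    return sorted(rows, key=lambda row: row["speedup"], reverse=True)
```

A list of flat dictionaries like this can be passed directly to the pandas.DataFrame constructor, which is one plausible route to the DataFrame export mentioned in the description.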
Usage Examples
Manual Kernel Configuration
from unsloth.kernels.moe.grouped_gemm.kernels.tuning import (
    KernelConfigForward,
    KernelConfigBackward_dX,
    KernelConfigBackward_dW,
)

# Create forward config
config_fwd = KernelConfigForward(
    BLOCK_SIZE_M=64,
    BLOCK_SIZE_N=128,
    BLOCK_SIZE_K=64,
    num_warps=8,
    num_stages=4,
    permute_y=True,
)

# Create backward configs
config_dx = KernelConfigBackward_dX(
    BLOCK_SIZE_M=64,
    BLOCK_SIZE_N=64,
    BLOCK_SIZE_K=256,
)
config_dw = KernelConfigBackward_dW(
    BLOCK_SIZE_M=64,
    BLOCK_SIZE_N=64,
    BLOCK_SIZE_K=256,
)
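Error-Tolerant Sweep (Sketch)

When sweeping many manual configurations, a single failing kernel should not abort the whole run. The sketch below shows one way a context manager with the same method signatures as TritonTuningContext could record a failure and let the loop continue. `SafeTuningContext`, its `success` and `error` attributes, `sweep`, and the `run_kernel` callback are hypothetical; the actual class may behave differently.

```python
class SafeTuningContext:
    """Hypothetical stand-in; not the module's TritonTuningContext."""

    def __init__(self, kernel_config):
        self.kernel_config = kernel_config
        self.success = True
        self.error = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        if exc_type is None:
            return False  # nothing was raised, nothing to suppress
        if issubclass(exc_type, Exception):
            # Record the failure and suppress it so the sweep can continue.
            self.success = False
            self.error = exc_value
            return True
        return False  # let KeyboardInterrupt and SystemExit propagate


def sweep(configs, run_kernel):
    # Run every config; split them into those that ran and those that failed.
    ok, failed = [], []
    for cfg in configs:
        with SafeTuningContext(cfg) as ctx:
            run_kernel(cfg)
        (ok if ctx.success else failed).append(cfg)
    return ok, failed
```

Returning True from `__exit__` is what suppresses the exception, so the `success` flag is the only signal the caller gets that a configuration was skipped.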