

Implementation:Unslothai Unsloth GEMM Kernel Tuning

From Leeroopedia
Revision as of 17:02, 16 February 2026 by Admin (auto-imported from implementations/Unslothai_Unsloth_GEMM_Kernel_Tuning.md)


Knowledge Sources
Domains: MoE, Kernel_Optimization
Last Updated: 2026-02-07 08:40 GMT

Overview

Configuration dataclasses, pruning rules, and result tracking utilities for manual Triton kernel benchmarking and tuning of grouped GEMM operations.
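As an illustration of what a pruning rule can look like, the sketch below discards tile configurations whose estimated shared-memory footprint exceeds a budget. The cost model (one input tile and one weight tile buffered per pipeline stage), the 2-byte element size, and the names TileConfig, estimate_smem_bytes, and prune_configs are assumptions for illustration only, not the module's actual rules.

```python
from dataclasses import dataclass


@dataclass
class TileConfig:
    # Hypothetical stand-in for the tile-related fields of KernelConfig.
    BLOCK_SIZE_M: int = 32
    BLOCK_SIZE_N: int = 32
    BLOCK_SIZE_K: int = 32
    num_stages: int = 2


def estimate_smem_bytes(cfg: TileConfig, elem_bytes: int = 2) -> int:
    # Assumed cost model: every pipeline stage buffers one BLOCK_SIZE_M x BLOCK_SIZE_K
    # tile of the input and one BLOCK_SIZE_K x BLOCK_SIZE_N tile of the weight.
    per_stage = (cfg.BLOCK_SIZE_M * cfg.BLOCK_SIZE_K
                 + cfg.BLOCK_SIZE_K * cfg.BLOCK_SIZE_N) * elem_bytes
    return per_stage * cfg.num_stages


def prune_configs(configs, smem_limit_bytes: int, elem_bytes: int = 2):
    # Hypothetical pruning rule: drop configs whose estimate exceeds the budget.
    return [c for c in configs
            if estimate_smem_bytes(c, elem_bytes) <= smem_limit_bytes]


candidates = [TileConfig(64, 128, 64, 4), TileConfig(128, 256, 128, 4)]
kept = prune_configs(candidates, smem_limit_bytes=100 * 1024)
```

With a 100 KiB budget the first candidate (an estimated 96 KiB) survives and the second (384 KiB) is dropped. A real rule would take its limit from the target GPU rather than a hard-coded constant.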

Description

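The tuning module defines the kernel configuration types used throughout the MoE grouped GEMM subsystem. KernelConfig is the base dataclass: it holds the tile sizes (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K), the warp count, the number of pipeline stages, and the flatten, permute_x, permute_y, fuse_mul_post, and use_tma_store flags. The subclasses KernelConfigForward, KernelConfigBackward_dW, and KernelConfigBackward_dX each add flags that enable TMA (Tensor Memory Accelerator) loads for the two operands that particular kernel reads. KernelResult records benchmark timings and supports export to a pandas DataFrame, and TritonTuningContext is a context manager that provides safe error handling for kernel compilation failures during tuning.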
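The error handling described for TritonTuningContext relies on the standard context-manager protocol: when `__exit__` returns a truthy value, Python suppresses the exception raised inside the `with` block. The sketch below is a hypothetical stand-in (TuningContextSketch and run_sweep are illustrative names, not the module's API) showing how that lets a sweep record a failed configuration and continue; what the real class records on failure is not documented here.

```python
class TuningContextSketch:
    """Hypothetical stand-in for TritonTuningContext: one trial per `with` block,
    recording (instead of propagating) ordinary exceptions such as a failed compile."""

    def __init__(self, kernel_config):
        self.kernel_config = kernel_config
        self.success = True
        self.error = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        if exc_type is None:
            return False  # trial finished normally; nothing to suppress
        if not issubclass(exc_type, Exception):
            return False  # let KeyboardInterrupt / SystemExit propagate
        self.success = False
        self.error = exc_value
        return True  # a truthy return value suppresses the exception


def run_sweep(configs, trial):
    # Run `trial(cfg)` for every config; failures are recorded, not raised.
    outcomes = []
    for cfg in configs:
        with TuningContextSketch(cfg) as ctx:
            trial(cfg)
        outcomes.append((cfg, ctx.success))
    return outcomes
```

Only `Exception` subclasses are swallowed, so a user can still interrupt a long sweep with Ctrl-C.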

Usage

Import the kernel config dataclasses when configuring grouped GEMM kernel parameters manually instead of using autotuning.

Code Reference

Source Location

Signature

from dataclasses import dataclass

@dataclass
class KernelConfig:
    BLOCK_SIZE_M: int = 32
    BLOCK_SIZE_N: int = 32
    BLOCK_SIZE_K: int = 32
    num_warps: int = 4
    num_stages: int = 2
    flatten: bool = True
    permute_x: bool = False
    permute_y: bool = False
    fuse_mul_post: bool = False
    use_tma_store: bool = False

@dataclass
class KernelConfigForward(KernelConfig):
    use_tma_load_w: bool = False
    use_tma_load_x: bool = False

@dataclass
class KernelConfigBackward_dW(KernelConfig):
    use_tma_load_dy: bool = False
    use_tma_load_x: bool = False

@dataclass
class KernelConfigBackward_dX(KernelConfig):
    use_tma_load_dy: bool = False
    use_tma_load_w: bool = False

@dataclass
class KernelResult:
    torch_time: float
    triton_time: float
    speedup: float
    kernel_config: KernelConfig

class TritonTuningContext:
    def __init__(self, kernel_config: KernelConfig): ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback) -> bool: ...

def get_device_properties() -> DeviceProperties:
    """Get GPU hardware properties."""

def get_kernel_configs(...) -> tuple[list, list, list]:
    """Generate all valid forward/dW/dX configurations."""

Import

from unsloth.kernels.moe.grouped_gemm.kernels.tuning import (
    KernelConfigForward,
    KernelConfigBackward_dW,
    KernelConfigBackward_dX,
    KernelResult,
    TritonTuningContext,
)

I/O Contract

Inputs (KernelConfigForward)

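Name            Type  Required  Description
BLOCK_SIZE_M    int   No        M-dimension tile size (default: 32)
BLOCK_SIZE_N    int   No        N-dimension tile size (default: 32)
BLOCK_SIZE_K    int   No        K-dimension tile size (default: 32)
num_warps       int   No        Number of warps per Triton program (default: 4)
num_stages      int   No        Number of software-pipelining stages (default: 2)
flatten         bool  No        Inherited from KernelConfig (default: True)
permute_x       bool  No        Fuse the input permutation into the kernel (default: False)
permute_y       bool  No        Fuse the output permutation into the kernel (default: False)
fuse_mul_post   bool  No        Inherited from KernelConfig (default: False)
use_tma_store   bool  No        Use TMA for output stores; inherited from KernelConfig (default: False)
use_tma_load_w  bool  No        Use TMA for weight loads (default: False)
use_tma_load_x  bool  No        Use TMA for input loads (default: False)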
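Every input above is optional, and the signature shows no validation, so an invalid tile size such as 48 would only surface once the kernel is compiled. A minimal sketch of an early check, assuming the tile sizes are used as `tl.arange` extents (Triton requires those ranges to be powers of two); CheckedForwardConfig and is_power_of_two are hypothetical names, not part of the module.

```python
from dataclasses import dataclass


def is_power_of_two(x: int) -> bool:
    # A positive power of two has exactly one bit set.
    return x > 0 and (x & (x - 1)) == 0


@dataclass
class CheckedForwardConfig:
    # Hypothetical mirror of the KernelConfigForward inputs listed above.
    BLOCK_SIZE_M: int = 32
    BLOCK_SIZE_N: int = 32
    BLOCK_SIZE_K: int = 32
    num_warps: int = 4
    num_stages: int = 2
    permute_x: bool = False
    permute_y: bool = False
    use_tma_load_w: bool = False
    use_tma_load_x: bool = False

    def __post_init__(self):
        # Assumption: tile sizes feed tl.arange, which needs power-of-two ranges.
        for name in ("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"):
            value = getattr(self, name)
            if not is_power_of_two(value):
                raise ValueError(f"{name}={value} is not a positive power of two")
```

Constructing `CheckedForwardConfig(BLOCK_SIZE_K=48)` raises `ValueError` immediately instead of failing later during compilation.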

Outputs

Name                   Type       Description
KernelConfig instance  dataclass  Configuration object for kernel dispatch
KernelResult           dataclass  Benchmark timings with the speedup factor
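KernelResult stores the two timings next to a speedup factor and the config that produced them. A hedged sketch of how such a record can be built and flattened for tabulation, assuming speedup = torch_time / triton_time (values above 1 mean the Triton kernel is faster); ResultSketch, ConfigStub, from_times, and to_row are hypothetical names. The sketch stops at plain dicts so it runs with the standard library alone.

```python
from dataclasses import asdict, dataclass


@dataclass
class ConfigStub:
    # Hypothetical stand-in for a KernelConfig.
    BLOCK_SIZE_M: int = 32
    num_warps: int = 4


@dataclass
class ResultSketch:
    # Mirrors KernelResult's fields.
    torch_time: float
    triton_time: float
    speedup: float
    kernel_config: ConfigStub

    @classmethod
    def from_times(cls, torch_time: float, triton_time: float, kernel_config):
        # Assumed definition: reference time divided by Triton time.
        return cls(torch_time, triton_time, torch_time / triton_time, kernel_config)

    def to_row(self) -> dict:
        # Flatten the config fields next to the timings: one dict per result.
        row = asdict(self.kernel_config)
        row.update(torch_time=self.torch_time,
                   triton_time=self.triton_time,
                   speedup=self.speedup)
        return row
```

Passing a list of such rows to `pandas.DataFrame` yields one column per config field and timing, which makes sorting by speedup a one-liner.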

Usage Examples

Manual Kernel Configuration

from unsloth.kernels.moe.grouped_gemm.kernels.tuning import (
    KernelConfigForward,
    KernelConfigBackward_dX,
    KernelConfigBackward_dW,
)

# Create forward config
config_fwd = KernelConfigForward(
    BLOCK_SIZE_M=64,
    BLOCK_SIZE_N=128,
    BLOCK_SIZE_K=64,
    num_warps=8,
    num_stages=4,
    permute_y=True,
)

# Create backward configs
config_dx = KernelConfigBackward_dX(
    BLOCK_SIZE_M=64,
    BLOCK_SIZE_N=64,
    BLOCK_SIZE_K=256,
)
config_dw = KernelConfigBackward_dW(
    BLOCK_SIZE_M=64,
    BLOCK_SIZE_N=64,
    BLOCK_SIZE_K=256,
)
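Because the config classes are declared with @dataclass, a manual sweep can start from one hand-picked configuration and derive variants with `dataclasses.replace`, which copies every field not named explicitly. The sketch uses a hypothetical stand-in class so it runs without unsloth installed; with the real package the same call should apply to a config such as config_dx.

```python
from dataclasses import dataclass, replace


@dataclass
class ConfigStub:
    # Hypothetical stand-in with a subset of KernelConfigBackward_dX's fields.
    BLOCK_SIZE_M: int = 32
    BLOCK_SIZE_N: int = 32
    BLOCK_SIZE_K: int = 32
    num_warps: int = 4
    num_stages: int = 2


base = ConfigStub(BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=256)

# Vary one knob around the base; every other field is copied unchanged.
variants = [replace(base, num_stages=s) for s in (2, 3, 4)]
```

`replace` builds each variant through `__init__`, so any `__post_init__` validation runs again, and the base object is never mutated.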
