Implementation:Unslothai Unsloth MoE Autotune Cache
| Knowledge Sources | |
|---|---|
| Domains | MoE, Kernel_Optimization |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
A concrete tool that caches and retrieves autotuned Triton kernel configurations for MoE layers, so that expensive autotuning is not repeated across sessions.
Description
The autotune_cache module provides a persistent caching system for Triton grouped GEMM kernel configurations. It generates MD5 cache keys from the model parameters (num_experts, hidden_dim, intermediate_dim, top_k, dtype) together with the device capability, stores optimized kernel configs as JSON under ~/.cache/unsloth/moe_autotune/, and maintains in-memory caches for runtime access. When autotuning fails, or is disabled via UNSLOTH_MOE_DISABLE_AUTOTUNE=1, it falls back to heuristic or default configurations.
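The key derivation described above can be illustrated with a minimal sketch. The helper names (`make_cache_key`, `cache_path`), the JSON payload layout, and the `<key>.json` file naming are assumptions made for illustration; the actual module may serialize its parameters differently.

```python
import hashlib
import json
from pathlib import Path

# Directory named in the description; the per-key file layout below is an assumption.
CACHE_DIR = Path.home() / ".cache" / "unsloth" / "moe_autotune"


def make_cache_key(num_experts, hidden_dim, intermediate_dim, top_k, dtype, device_capability):
    """Hash the parameters that determine which kernel config applies (hypothetical helper)."""
    payload = {
        "num_experts": num_experts,
        "hidden_dim": hidden_dim,
        "intermediate_dim": intermediate_dim,
        "top_k": top_k,
        "dtype": str(dtype),
        "device_capability": list(device_capability),
    }
    # sort_keys makes the serialization, and therefore the digest, deterministic.
    encoded = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.md5(encoded).hexdigest()


def cache_path(cache_key, cache_dir=CACHE_DIR):
    """Map a key to a JSON file inside the cache directory (assumed naming scheme)."""
    return Path(cache_dir) / f"{cache_key}.json"


key = make_cache_key(8, 4096, 14336, 2, "torch.bfloat16", (9, 0))
```

Including the device capability in the key means a config tuned on one GPU generation is not silently reused on another.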
Usage
Import the get_or_autotune_moe_kernels function when initializing MoE model layers to obtain optimized kernel configurations for forward and backward passes without re-running autotuning.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/autotune_cache.py
- Lines: 1-500
Signature
def get_or_autotune_moe_kernels(
    num_experts: int,
    hidden_dim: int,
    intermediate_dim: int,
    top_k: int,
    dtype: torch.dtype,
    force_autotune: bool = False,
    seq_len: int = 8192,
) -> Tuple[Any, Any, Any]:
    """
    Returns (config_fwd, config_bwd_dx, config_bwd_dw) kernel configs.
    Checks in-memory cache, disk cache, runs autotuning, or falls back to defaults.
    """

def clear_cache() -> None:
    """Clears in-memory kernel config caches."""

def load_cached_config(cache_key: str) -> Optional[Dict[str, Any]]:
    """Loads config from disk cache file."""

def save_cached_config(
    cache_key: str,
    config_fwd: Any,
    config_bwd_dx: Any,
    config_bwd_dw: Any,
    metadata: Dict[str, Any] = None,
) -> None:
    """Saves kernel configs to disk as JSON."""
Import
from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| num_experts | int | Yes | Number of experts in the MoE layer |
| hidden_dim | int | Yes | Hidden dimension size |
| intermediate_dim | int | Yes | Intermediate (FFN) dimension size |
| top_k | int | Yes | Number of experts per token |
| dtype | torch.dtype | Yes | Data type (e.g., torch.bfloat16) |
| force_autotune | bool | No | Skip caches and re-run autotuning (default: False) |
| seq_len | int | No | Sequence length for dummy tensors (default: 8192) |
Outputs
| Name | Type | Description |
|---|---|---|
| config_fwd | KernelConfigForward | Forward pass kernel configuration |
| config_bwd_dx | KernelConfigBackward_dX | Input gradient kernel configuration |
| config_bwd_dw | KernelConfigBackward_dW | Weight gradient kernel configuration |
Usage Examples
Get Cached or Autotuned Configs
import torch
from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels

# Get kernel configs (cached or autotuned)
config_fwd, config_bwd_dx, config_bwd_dw = get_or_autotune_moe_kernels(
    num_experts=8,
    hidden_dim=4096,
    intermediate_dim=14336,
    top_k=2,
    dtype=torch.bfloat16,
)
Disable Autotuning via Environment
# Use heuristic configs instead of autotuning
export UNSLOTH_MOE_DISABLE_AUTOTUNE=1
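The same switch can be set from Python before the MoE kernels are used. The sketch below also shows how such a flag might be checked; `autotune_disabled` is a hypothetical helper, and only the value "1" is documented, so treating every other value as "enabled" is an assumption.

```python
import os


def autotune_disabled(env=None):
    """Return True when the documented opt-out value "1" is set (hypothetical check)."""
    env = os.environ if env is None else env
    return env.get("UNSLOTH_MOE_DISABLE_AUTOTUNE", "0") == "1"


# Set the variable early, before the kernel configs are requested.
os.environ["UNSLOTH_MOE_DISABLE_AUTOTUNE"] = "1"
```

Setting the variable at the top of the entry script is the safer choice, since the source does not say whether the module reads it at import time or at call time.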