Implementation:Unslothai Unsloth Grouped GEMM Interface
| Knowledge Sources | |
|---|---|
| Domains | MoE, Triton_Kernels, Linear_Algebra |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
Concrete tool for executing grouped GEMM operations (forward and backward) on MoE expert layers with Triton kernel dispatch, autotuning, and torch.compile compatibility.
Description
The interface module provides the main Python-level API for invoking the Triton grouped GEMM kernels. It contains three kernel wrappers, grouped_gemm_forward, grouped_gemm_dX, and grouped_gemm_dW, covering the forward pass and the gradients with respect to the inputs and the weights; a GroupedGemm autograd Function that ties them together for automatic differentiation; and a high-level grouped_gemm entry point that validates the configuration and dispatches to the kernels. The module supports TMA (Tensor Memory Accelerator) loads and stores on GPUs with compute capability 9.0 or higher (SM90, e.g. Hopper), fused token permutation, and torch.compile tracing.
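To make the operation concrete, the following is a minimal pure-PyTorch sketch of what a grouped GEMM computes: one X @ W^T per expert over contiguous row blocks. It is not the Triton implementation. The helper name grouped_gemm_reference is hypothetical, and the sketch assumes the rows of X are already sorted by expert (no fused permutation).

```python
import torch


def grouped_gemm_reference(X, W, m_sizes):
    """Hypothetical reference: Y[rows of expert e] = X[rows of expert e] @ W[e].T.

    Assumes X is already grouped by expert: the first m_sizes[0] rows belong
    to expert 0, the next m_sizes[1] rows to expert 1, and so on.
    """
    num_experts, N, K = W.shape
    assert X.shape[1] == K
    assert int(m_sizes.sum()) == X.shape[0]
    Y = X.new_empty(X.shape[0], N)
    start = 0
    for e, m in enumerate(m_sizes.tolist()):
        end = start + m
        Y[start:end] = X[start:end] @ W[e].T  # one GEMM per expert
        start = end
    return Y


torch.manual_seed(0)
X = torch.randn(6, 4)            # 6 routed rows, K=4
W = torch.randn(3, 5, 4)         # 3 experts, N=5, K=4
m_sizes = torch.tensor([1, 3, 2])
Y = grouped_gemm_reference(X, W, m_sizes)
```

A loop like this is useful as a correctness baseline on small shapes; the fused kernels exist to avoid launching one GEMM per expert.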
Usage
Import the grouped_gemm function when implementing MoE forward passes that need fused per-expert matrix multiplication with automatic gradient support.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/interface.py
- Lines: 1-1038
Signature
```python
def grouped_gemm(
    X: torch.Tensor,
    W: torch.Tensor,
    m_sizes: torch.Tensor,
    topk: int,
    gather_indices: torch.Tensor = None,
    permute_x: bool = False,
    permute_y: bool = False,
    topk_weights=None,
    fuse_mul_post=False,
    kernel_config_fwd: KernelConfigForward = None,
    kernel_config_bwd_dX: KernelConfigBackward_dX = None,
    kernel_config_bwd_dW: KernelConfigBackward_dW = None,
    autotune: bool = False,
    is_first_gemm: bool = True,
    dX_only: bool = False,
    dW_only: bool = False,
) -> torch.Tensor:
    """
    High-level grouped GEMM with autograd, config validation, and kernel dispatch.
    """


def grouped_gemm_forward(
    X: torch.Tensor, W: torch.Tensor, topk: int,
    m_sizes: torch.Tensor, gather_indices: torch.Tensor = None,
    topk_weights: torch.Tensor = None,
    permute_x: bool = False, permute_y: bool = False,
    fuse_mul_post: bool = False, autotune: bool = False,
    BLOCK_SIZE_M: int = 32, BLOCK_SIZE_N: int = 32, BLOCK_SIZE_K: int = 32,
    num_warps: int = 4, num_stages: int = 2,
    use_tma_load_w: bool = False, use_tma_load_x: bool = False,
    use_tma_store: bool = False, flatten: bool = True, debug: bool = False,
) -> torch.Tensor:
    """Forward pass: X @ W^T per expert."""


class GroupedGemm(torch.autograd.Function):
    """Autograd wrapper for grouped GEMM forward/backward."""
```
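For intuition about what the backward side of the autograd wrapper has to produce: if each expert computes Y_e = X_e @ W_e^T, then the input gradient is dX_e = dY_e @ W_e and the weight gradient is dW_e = dY_e^T @ X_e. The sketch below writes this out in plain PyTorch and checks it against autograd on a tiny problem. The function name grouped_gemm_backward_reference is hypothetical, and the sketch again assumes expert-sorted rows; it says nothing about how the Triton kernels tile or fuse the work.

```python
import torch


def grouped_gemm_backward_reference(dY, X, W, m_sizes):
    """Hypothetical reference gradients of Y_e = X_e @ W_e.T.

    dX_e = dY_e @ W_e      -> shape [m_e, K]
    dW_e = dY_e.T @ X_e    -> shape [N, K]
    """
    dX = torch.zeros_like(X)
    dW = torch.zeros_like(W)
    start = 0
    for e, m in enumerate(m_sizes.tolist()):
        end = start + m
        dX[start:end] = dY[start:end] @ W[e]
        dW[e] = dY[start:end].T @ X[start:end]
        start = end
    return dX, dW


# Compare against autograd; expert 1 deliberately receives no rows.
torch.manual_seed(0)
m_sizes = torch.tensor([2, 0, 3])
X = torch.randn(5, 4, dtype=torch.float64, requires_grad=True)
W = torch.randn(3, 6, 4, dtype=torch.float64, requires_grad=True)
bounds = [0] + torch.cumsum(m_sizes, 0).tolist()
Y = torch.cat([X[bounds[e]:bounds[e + 1]] @ W[e].T for e in range(3)])
dY = torch.randn_like(Y)
Y.backward(dY)

dX_ref, dW_ref = grouped_gemm_backward_reference(dY, X.detach(), W.detach(), m_sizes)
```

An expert with zero assigned rows contributes an all-zero weight gradient, which is a useful edge case to include when testing a fused backward kernel.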
Import
```python
from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X | torch.Tensor | Yes | Input activations: [num_tokens, K] when permute_x is True, otherwise [total_tokens, K] with total_tokens = num_tokens * topk |
| W | torch.Tensor | Yes | Expert weights [num_experts, N, K] or flattened |
| m_sizes | torch.Tensor | Yes | Token counts per expert [num_experts] |
| topk | int | Yes | Number of experts per token |
| gather_indices | torch.Tensor | No | Indices that order the routed tokens by expert, used by the fused permutation paths (default: None) |
| permute_x | bool | No | Fuse input gather into kernel (default: False) |
| permute_y | bool | No | Fuse output scatter into kernel (default: False) |
| topk_weights | torch.Tensor | No | Routing weights multiplied into the output when fuse_mul_post is True (default: None) |
| autotune | bool | No | Use Triton autotuning (default: False) |
| kernel_config_fwd | KernelConfigForward | No | Manual forward kernel config |
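The m_sizes and gather_indices inputs are typically derived from the router's top-k expert choices. The following is a minimal sketch of one common way to build them, assuming gather_indices holds positions in the flattened (token, slot) list sorted by expert id. The helper name build_routing_metadata, the int32 dtype, and this exact index convention are assumptions for illustration; check them against the routing code you actually pair with the kernels.

```python
import torch


def build_routing_metadata(topk_ids, num_experts):
    """Hypothetical helper: derive m_sizes and gather_indices from router output.

    topk_ids: [num_tokens, topk] integer expert ids chosen per token.
    Returns:
      m_sizes:        [num_experts] number of routed rows per expert
      gather_indices: [num_tokens * topk] flattened (token, slot) positions,
                      ordered so that all rows of expert 0 come first.
    """
    flat_ids = topk_ids.reshape(-1)
    m_sizes = torch.bincount(flat_ids, minlength=num_experts).to(torch.int32)
    # Stable sort keeps tokens in their original order within each expert.
    _, order = torch.sort(flat_ids, stable=True)
    return m_sizes, order.to(torch.int32)


topk = 2
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])  # 3 tokens, topk=2
m_sizes, gather_indices = build_routing_metadata(topk_ids, num_experts=4)

# Under this convention, flattened position i belongs to token i // topk.
token_ids = gather_indices // topk
```

With this layout, m_sizes always sums to num_tokens * topk, and slicing gather_indices by the cumulative sum of m_sizes yields the rows handled by each expert.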
Outputs
| Name | Type | Description |
|---|---|---|
| Y | torch.Tensor | Output activations [total_tokens, N]; in expert-sorted order by default, or permuted when permute_y fuses the output scatter into the kernel |
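When the output permutation is not fused, the caller has to restore token order and combine the topk expert outputs per token. The sketch below shows one way to do that in plain PyTorch. It assumes gather_indices gives, for each expert-sorted row, its position in the flattened (token, slot) list, and that topk_weights has shape [num_tokens, topk]; the function name unpermute_and_combine is hypothetical.

```python
import torch


def unpermute_and_combine(Y_sorted, gather_indices, topk_weights, topk):
    """Hypothetical helper: undo expert sorting, then weight and sum over topk.

    Y_sorted:       [num_tokens * topk, N] rows in expert-sorted order
    gather_indices: [num_tokens * topk] flattened (token, slot) index per sorted row
    topk_weights:   [num_tokens, topk] routing weights
    """
    total_tokens, N = Y_sorted.shape
    Y_token_order = torch.empty_like(Y_sorted)
    Y_token_order[gather_indices.long()] = Y_sorted  # scatter rows back
    Y_token_order = Y_token_order.view(total_tokens // topk, topk, N)
    return (Y_token_order * topk_weights.unsqueeze(-1)).sum(dim=1)


# Rows labelled 0..5 in token order, then shuffled into expert-sorted order.
gather_indices = torch.tensor([0, 4, 2, 5, 1, 3], dtype=torch.int32)
Y_flat = torch.arange(6.0).unsqueeze(1).repeat(1, 2)  # [6, 2]
Y_sorted = Y_flat[gather_indices.long()]
topk_weights = torch.tensor([[0.5, 0.5], [1.0, 0.0], [0.25, 0.75]])

out = unpermute_and_combine(Y_sorted, gather_indices, topk_weights, topk=2)
```

Fusing this step into the kernel removes the extra scatter and the intermediate [total_tokens, N] buffer in token order.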
Usage Examples
Basic Grouped GEMM
```python
import torch

from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm
from unsloth.kernels.moe.grouped_gemm.kernels.tuning import KernelConfigForward

# Expert weights: 8 experts, output_dim=14336, input_dim=4096
W = torch.randn(8, 14336, 4096, dtype=torch.bfloat16, device="cuda")
# 1024 routed rows (512 tokens x topk=2), already grouped by expert since permute_x=False
X = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")
m_sizes = torch.tensor([128] * 8, device="cuda")  # 128 rows per expert, sums to 1024

# Manual forward tile sizes instead of autotuning
config_fwd = KernelConfigForward(BLOCK_SIZE_M=64, BLOCK_SIZE_N=128, BLOCK_SIZE_K=64)
Y = grouped_gemm(X, W, m_sizes, topk=2, kernel_config_fwd=config_fwd)  # Y: [1024, 14336]
```