Implementation:Unslothai Unsloth Grouped GEMM Interface

From Leeroopedia


Knowledge Sources
Domains MoE, Triton_Kernels, Linear_Algebra
Last Updated 2026-02-07 08:40 GMT

Overview

Concrete tool for executing grouped GEMM operations (forward and backward) on MoE expert layers with Triton kernel dispatch, autotuning, and torch.compile compatibility.

Description

The interface module provides the main Python-level API for invoking the Triton grouped GEMM kernels. It includes the functions grouped_gemm_forward, grouped_gemm_dX, and grouped_gemm_dW for the forward pass and for the gradients with respect to the input (dX) and the weights (dW), a GroupedGemm autograd Function for automatic differentiation, and a high-level grouped_gemm entry point that handles configuration validation and kernel dispatch. The module supports TMA (Tensor Memory Accelerator) on GPUs with SM 9.0 or newer, fused token permutation, and torch.compile tracing.
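To make the operation concrete, the following is a minimal loop-based reference of what a grouped GEMM computes, assuming X is already sorted so that each expert's rows are contiguous and W has shape [num_experts, N, K]. The helper name grouped_gemm_reference is hypothetical and is not part of the Unsloth module; it is only a plain PyTorch sketch of "X @ W^T per expert" that can be useful for checking kernel output on small inputs.

```python
import torch

def grouped_gemm_reference(X, W, m_sizes):
    """Hypothetical reference: Y[rows of expert e] = X[rows of expert e] @ W[e].T"""
    num_experts, N, K = W.shape
    assert X.shape[1] == K, "inner dimensions must match"
    assert m_sizes.numel() == num_experts, "one row count per expert"
    assert int(m_sizes.sum()) == X.shape[0], "m_sizes must account for every row"

    Y = X.new_empty(X.shape[0], N)
    start = 0
    for e, m in enumerate(m_sizes.tolist()):
        end = start + m
        # Rows [start, end) belong to expert e (assumes expert-sorted input).
        Y[start:end] = X[start:end] @ W[e].T
        start = end
    return Y

# Tiny CPU example: 3 experts, K=4, N=5, 6 rows split as 1/3/2.
torch.manual_seed(0)
X = torch.randn(6, 4)
W = torch.randn(3, 5, 4)
m_sizes = torch.tensor([1, 3, 2])
Y = grouped_gemm_reference(X, W, m_sizes)
```

A fused kernel replaces the Python loop with a single launch, but the per-expert arithmetic it has to reproduce is the same as in this sketch.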

Usage

Import the grouped_gemm function when implementing MoE forward passes that need fused per-expert matrix multiplication with automatic gradient support.

Code Reference

Source Location

Signature

def grouped_gemm(
    X: torch.Tensor,
    W: torch.Tensor,
    m_sizes: torch.Tensor,
    topk: int,
    gather_indices: torch.Tensor = None,
    permute_x: bool = False,
    permute_y: bool = False,
    topk_weights=None,
    fuse_mul_post=False,
    kernel_config_fwd: KernelConfigForward = None,
    kernel_config_bwd_dX: KernelConfigBackward_dX = None,
    kernel_config_bwd_dW: KernelConfigBackward_dW = None,
    autotune: bool = False,
    is_first_gemm: bool = True,
    dX_only: bool = False,
    dW_only: bool = False,
) -> torch.Tensor:
    """
    High-level grouped GEMM with autograd, config validation, and kernel dispatch.
    """

def grouped_gemm_forward(
    X: torch.Tensor, W: torch.Tensor, topk: int,
    m_sizes: torch.Tensor, gather_indices: torch.Tensor = None,
    topk_weights: torch.Tensor = None,
    permute_x: bool = False, permute_y: bool = False,
    fuse_mul_post: bool = False, autotune: bool = False,
    BLOCK_SIZE_M: int = 32, BLOCK_SIZE_N: int = 32, BLOCK_SIZE_K: int = 32,
    num_warps: int = 4, num_stages: int = 2,
    use_tma_load_w: bool = False, use_tma_load_x: bool = False,
    use_tma_store: bool = False, flatten: bool = True, debug: bool = False,
) -> torch.Tensor:
    """Forward pass: X @ W^T per expert."""

class GroupedGemm(torch.autograd.Function):
    """Autograd wrapper for grouped GEMM forward/backward."""

Import

from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm

I/O Contract

Inputs

Name | Type | Required | Description
X | torch.Tensor | Yes | Input activations [num_tokens, K] or [total_tokens, K]
W | torch.Tensor | Yes | Expert weights [num_experts, N, K] or flattened
m_sizes | torch.Tensor | Yes | Token counts per expert [num_experts]
topk | int | Yes | Number of experts per token
gather_indices | torch.Tensor | No | Token-to-expert mapping indices
permute_x | bool | No | Fuse input gather into kernel (default: False)
permute_y | bool | No | Fuse output scatter into kernel (default: False)
topk_weights | torch.Tensor | No | Routing weights for fused multiplication
autotune | bool | No | Use Triton autotuning (default: False)
kernel_config_fwd | KernelConfigForward | No | Manual forward kernel config

Outputs

Name | Type | Description
Y | torch.Tensor | Output activations [total_tokens, N] or permuted
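The m_sizes and gather_indices inputs are typically derived from the router's top-k expert assignments. The sketch below shows one way to build them in plain PyTorch. It assumes that gather_indices is the stable sort order of the flattened expert ids and that int32 is an acceptable dtype; both are assumptions that should be checked against the module source. The helper name routing_metadata is hypothetical.

```python
import torch

def routing_metadata(topk_ids, num_experts):
    """Hypothetical helper: build per-expert counts and a gather order.

    topk_ids: [num_tokens, topk] integer tensor of selected expert ids.
    """
    flat = topk_ids.reshape(-1)  # [num_tokens * topk]
    # Number of (token, expert) pairs routed to each expert.
    m_sizes = torch.bincount(flat, minlength=num_experts).to(torch.int32)
    # Stable sort groups pairs by expert while preserving token order.
    gather_indices = torch.argsort(flat, stable=True).to(torch.int32)
    return m_sizes, gather_indices

# 3 tokens, topk=2, 4 experts (expert 3 receives nothing).
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])
m_sizes, gather_indices = routing_metadata(topk_ids, num_experts=4)

# Each flattened index maps back to its source token via floor division by topk.
token_ids = torch.div(gather_indices, topk_ids.shape[1], rounding_mode="floor")
```

With this layout, m_sizes sums to num_tokens * topk, which matches the [total_tokens, K] form of X listed above.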

Usage Examples

Basic Grouped GEMM

from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm
from unsloth.kernels.moe.grouped_gemm.kernels.tuning import KernelConfigForward
import torch

# Expert weights: 8 experts, output_dim=14336, input_dim=4096
W = torch.randn(8, 14336, 4096, dtype=torch.bfloat16, device="cuda")
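# 1024 rows = 512 tokens * topk (2); with permute_x left at False,
# rows are expected to be grouped by expert already
X = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")
m_sizes = torch.tensor([128]*8, device="cuda")  # 128 rows per expert, summing to 1024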

config_fwd = KernelConfigForward(BLOCK_SIZE_M=64, BLOCK_SIZE_N=128, BLOCK_SIZE_K=64)
Y = grouped_gemm(X, W, m_sizes, topk=2, kernel_config_fwd=config_fwd)
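Shape mismatches between X, W, and m_sizes are easy to introduce in calls like the one above, so a small pre-flight check can help. The following sketch is a hypothetical helper, not part of the Unsloth API. It assumes that X has num_tokens rows when permute_x is True and total_tokens = num_tokens * topk rows otherwise, which is an interpretation of the two shapes listed in the I/O contract.

```python
import torch

def check_grouped_gemm_shapes(X, W, m_sizes, topk, permute_x=False):
    """Hypothetical pre-flight check; returns a list of problem descriptions."""
    num_experts, N, K = W.shape
    total_tokens = int(m_sizes.sum())
    problems = []
    if X.shape[1] != K:
        problems.append(f"X has K={X.shape[1]}, W expects K={K}")
    if m_sizes.numel() != num_experts:
        problems.append(f"m_sizes has {m_sizes.numel()} entries for {num_experts} experts")
    if total_tokens % topk != 0:
        problems.append(f"sum(m_sizes)={total_tokens} is not a multiple of topk={topk}")
    # Assumption: fused input permutation takes the unexpanded token matrix.
    expected_rows = total_tokens // topk if permute_x else total_tokens
    if X.shape[0] != expected_rows:
        problems.append(f"X has {X.shape[0]} rows, expected {expected_rows}")
    return problems

# Scaled-down version of the example above, runnable on CPU.
X_small = torch.empty(16, 8)
W_small = torch.empty(4, 12, 8)
m_small = torch.tensor([4, 4, 4, 4])
ok = check_grouped_gemm_shapes(X_small, W_small, m_small, topk=2)
bad = check_grouped_gemm_shapes(X_small, W_small, m_small, topk=2, permute_x=True)
```

Here ok is empty, while bad reports a row-count mismatch because a 16-row X is passed where 8 unexpanded token rows would be expected under the stated assumption.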
