

Implementation:Unslothai Unsloth GEMM Forward Kernel

From Leeroopedia


Knowledge Sources
Domains: MoE, Triton_Kernels, Linear_Algebra
Last Updated: 2026-02-07 08:40 GMT

Overview

A Triton JIT forward kernel that computes Y = X @ W^T for all MoE experts in a single fused pass, with optional fused token permutation and fused multiplication by the top-k routing weights.
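The following slow plain-Python reference shows what "X @ W^T for all experts" means. It assumes, as in the interface example on this page, that the rows of X are already grouped contiguously by expert and that m_sizes[e] is the number of rows for expert e. grouped_gemm_reference is a hypothetical name, not part of Unsloth, and the sketch models only the math, not the kernel's tiling or fusion.

```python
# Hypothetical reference (not Unsloth code): rows of X are assumed to be grouped
# contiguously by expert, with m_sizes[e] rows belonging to expert e.

def matmul_wt(x_rows, w):
    """Return x_rows @ w^T for x_rows [m, K] and w [N, K] (lists of lists)."""
    return [[sum(a * b for a, b in zip(row, w_n)) for w_n in w] for row in x_rows]


def grouped_gemm_reference(x, w, m_sizes):
    """Y = X_e @ W[e]^T for each contiguous expert group X_e of rows in X."""
    assert sum(m_sizes) == len(x), "m_sizes must sum to the number of rows in X"
    y = []
    start = 0
    for e, m in enumerate(m_sizes):
        # One small GEMM per expert: this loop is what the fused kernel replaces.
        y.extend(matmul_wt(x[start:start + m], w[e]))
        start += m
    return y


x = [[1, 2], [3, 4], [5, 6]]           # M_total = 3, K = 2
w = [[[1, 0], [0, 1]], [[1, 1], [2, 0]]]  # E = 2, N = 2, K = 2
print(grouped_gemm_reference(x, w, [2, 1]))  # [[1, 2], [3, 4], [11, 10]]
```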

Description

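The forward module implements _grouped_gemm_forward_kernel, a Triton JIT kernel that replaces a sequential loop of per-expert GEMMs with a single fused kernel launch. It uses persistent tile scheduling: each of the NUM_SMS persistent programs (one per SM) processes output tiles drawn from all expert groups. Optional features are selected through constexpr flags: fused token permutation (PERMUTE_X gathers rows of X on load, PERMUTE_Y scatters rows of Y on store), fused multiplication by the top-k routing weights (FUSE_MUL_PRE, FUSE_MUL_POST), and TMA descriptor-based loads and stores (USE_TMA_LOAD_X, USE_TMA_LOAD_W, USE_TMA_STORE). Accumulation uses float32 by default (acc_dtype). An autotuned variant, _autotuned_grouped_gemm_forward_kernel, is provided via triton.autotune.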
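Persistent tile scheduling can be sketched as follows, under stated assumptions: output tiles of size BLOCK_SIZE_M x BLOCK_SIZE_N are enumerated expert by expert, and NUM_SMS programs take them round-robin (program pid handles tiles pid, pid + NUM_SMS, and so on). The helper names are hypothetical, and the real kernel presumably derives tile coordinates on the fly inside its loop rather than building a list.

```python
# Hypothetical model of persistent scheduling; not the kernel's actual code.

def cdiv(a, b):
    """Ceiling division, like triton.cdiv."""
    return (a + b - 1) // b


def enumerate_tiles(m_sizes, n, block_m, block_n):
    """List (expert, tile_m, tile_n) for every output tile of every expert group."""
    tiles = []
    for e, m in enumerate(m_sizes):
        for tile_m in range(cdiv(m, block_m)):
            for tile_n in range(cdiv(n, block_n)):
                tiles.append((e, tile_m, tile_n))
    return tiles


def persistent_schedule(m_sizes, n, block_m, block_n, num_sms):
    """Round-robin assignment: program pid gets tiles pid, pid + num_sms, ..."""
    tiles = enumerate_tiles(m_sizes, n, block_m, block_n)
    return {pid: tiles[pid::num_sms] for pid in range(num_sms)}


# Three experts with 3, 0 and 5 rows; N = 4; 2 x 4 tiles; 2 SMs.
schedule = persistent_schedule([3, 0, 5], n=4, block_m=2, block_n=4, num_sms=2)
print(schedule[0])  # [(0, 0, 0), (2, 0, 0), (2, 2, 0)]
print(schedule[1])  # [(0, 1, 0), (2, 1, 0)]
```

Experts that receive no tokens contribute no tiles, so a single launch covers uneven expert loads without one kernel call per expert.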

Usage

This kernel is launched internally by grouped_gemm_forward in the interface module (unsloth.kernels.moe.grouped_gemm.interface) during MoE forward passes; it is not normally called directly.

Code Reference

Source Location

Signature

@triton.jit
def _grouped_gemm_forward_kernel(
    x_ptr, w_ptr, y_ptr, m_sizes_ptr,
    gather_indices_ptr, topk_weights_ptr,
    NUM_EXPERTS: tl.constexpr, NUM_TOKENS,
    TOPK: tl.constexpr, N: tl.constexpr, K: tl.constexpr, NUM_SMS,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    PERMUTE_X: tl.constexpr = False, PERMUTE_Y: tl.constexpr = False,
    FUSE_MUL_PRE: tl.constexpr = False, FUSE_MUL_POST: tl.constexpr = False,
    USE_FAST_ACCUM: tl.constexpr = False,
    USE_TMA_LOAD_W: tl.constexpr = False, USE_TMA_LOAD_X: tl.constexpr = False,
    USE_TMA_STORE: tl.constexpr = False,
    acc_dtype: tl.constexpr = tl.float32, FLATTEN: tl.constexpr = True,
) -> None:
    """Forward pass: Y = X @ W^T per expert with fused permutation."""

Import

from unsloth.kernels.moe.grouped_gemm.kernels.forward import (
    _grouped_gemm_forward_kernel,
    _autotuned_grouped_gemm_forward_kernel,
)
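The default acc_dtype=tl.float32 matters because long dot products lose precision when the running sum is kept in a 16-bit type. The sketch below is a simplified illustration, not a model of the kernel: it uses IEEE float16 (via struct's "e" format) as a stand-in for a low-precision accumulator and rounds after every addition, and it does not reflect tensor-core math or whatever USE_FAST_ACCUM changes.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def dot_with_accumulator(a, b, low_precision):
    """Dot product whose running sum is kept in fp16 or in Python float (double)."""
    acc = 0.0
    for x, y in zip(a, b):
        acc += x * y
        if low_precision:
            acc = to_fp16(acc)  # round the accumulator after every add
    return acc


K = 4096
ones = [1.0] * K
print(dot_with_accumulator(ones, ones, low_precision=True))   # 2048.0
print(dot_with_accumulator(ones, ones, low_precision=False))  # 4096.0
```

Once the fp16 accumulator reaches 2048, adding 1.0 lands exactly halfway between representable values and rounds back to 2048, so the sum stops growing; a float32 accumulator represents every integer in this range exactly.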

I/O Contract

Inputs

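Name                Type     Required  Description
x_ptr               pointer  Yes       Input activations [M_total, K], already grouped by expert, or in token order when PERMUTE_X gathers rows via gather_indices_ptr
w_ptr               pointer  Yes       Expert weights [E, N, K]
m_sizes_ptr         pointer  Yes       Number of rows assigned to each expert [E]
gather_indices_ptr  pointer  No        Indices mapping expert-grouped rows to token positions (used by PERMUTE_X / PERMUTE_Y)
topk_weights_ptr    pointer  No        Top-k routing weights for fused multiplication (FUSE_MUL_PRE / FUSE_MUL_POST)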

Outputs

Name   Type     Description
y_ptr  pointer  Output activations [M_total, N]
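The roles of gather_indices_ptr and topk_weights_ptr can be modeled in plain Python as below. This is a hypothetical sketch that assumes a layout not stated on this page: gather_indices holds flat token * TOPK + k positions sorted by expert, so expert-grouped row i reads token gather_indices[i] // TOPK, and topk_weights is indexed by the same flat position. FUSE_MUL_PRE is not modeled separately; because the routing weight is a per-row scalar, scaling before or after the matmul gives the same result in exact arithmetic.

```python
# Hypothetical model of PERMUTE_X / PERMUTE_Y / fused weight multiplication.
# Assumed layout: gather_indices[i] = token * topk + k, listed in expert order.

def grouped_gemm_permuted(x_tokens, w, m_sizes, gather_indices, topk,
                          permute_x=True, permute_y=False, topk_weights=None):
    total = len(gather_indices)
    assert sum(m_sizes) == total
    y = [None] * total
    row = 0
    for e, m in enumerate(m_sizes):
        for _ in range(m):
            flat = gather_indices[row]
            # PERMUTE_X: gather this row from token-ordered X on load.
            x_row = x_tokens[flat // topk] if permute_x else x_tokens[row]
            out = [sum(a * b for a, b in zip(x_row, w_n)) for w_n in w[e]]
            # Post-GEMM scaling by the routing weight of this (token, k) pair.
            if topk_weights is not None:
                out = [v * topk_weights[flat] for v in out]
            # PERMUTE_Y: scatter the result back to (token, k) order on store.
            y[flat if permute_y else row] = out
            row += 1
    return y


# 2 tokens, topk = 2; token 0 -> experts (1, 0), token 1 -> experts (0, 1).
x_tokens = [[2.0], [3.0]]
w = [[[10.0]], [[100.0]]]   # E = 2, N = 1, K = 1
print(grouped_gemm_permuted(x_tokens, w, [2, 2], [1, 2, 0, 3], topk=2,
                            permute_y=True))  # [[200.0], [20.0], [30.0], [300.0]]
```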

Usage Examples

Via Interface Module

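import torch
from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm_forward

# 2048 rows with hidden size K = 4096, assumed to be already grouped by expert
X = torch.randn(2048, 4096, dtype=torch.bfloat16, device="cuda")
# 8 experts, each with an [N = 14336, K = 4096] weight matrix
W = torch.randn(8, 14336, 4096, dtype=torch.bfloat16, device="cuda")
# Rows per expert; these should sum to X.shape[0] (8 * 256 = 2048)
m_sizes = torch.tensor([256]*8, device="cuda")

# Y: [M_total, N] = [2048, 14336]
Y = grouped_gemm_forward(X, W, topk=2, m_sizes=m_sizes, autotune=True)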
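In the example above, m_sizes is a uniform [256]*8. In a real MoE layer the counts, and the gather indices used for fused permutation, come from the router's top-k expert choices. The helper below is hypothetical (not an Unsloth function) and assumes the flat token * topk + k layout for gather indices; in PyTorch the same two steps would typically use torch.bincount(..., minlength=num_experts) and torch.argsort(..., stable=True).

```python
# Hypothetical helper: derive m_sizes and gather_indices from router top-k ids.

def routing_metadata(topk_ids, num_experts):
    """topk_ids: per-token lists of expert ids, shape [num_tokens, topk]."""
    flat = [e for row in topk_ids for e in row]  # position = token * topk + k
    m_sizes = [0] * num_experts
    for e in flat:
        m_sizes[e] += 1
    # Stable sort of flat positions by expert id (Python's sorted is stable).
    gather_indices = sorted(range(len(flat)), key=lambda i: flat[i])
    return m_sizes, gather_indices


m_sizes, gather_indices = routing_metadata([[1, 0], [0, 1]], num_experts=2)
print(m_sizes)         # [2, 2]
print(gather_indices)  # [1, 2, 0, 3]
```

The counts always sum to num_tokens * topk, which is why the 2048-row X above corresponds to 1024 tokens with topk=2.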
