Implementation:Unslothai Unsloth GEMM Forward Kernel
| Knowledge Sources | |
|---|---|
| Domains | MoE, Triton_Kernels, Linear_Algebra |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
Triton JIT forward kernel that computes Y = X @ W^T for every MoE expert in a single fused launch, with optional fused token permutation and fused top-k routing-weight multiplication.
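To make the contract concrete, here is a NumPy sketch of the result the fused kernel is meant to produce for unpermuted input. It assumes the rows of X are already grouped by expert, with m_sizes giving each expert's contiguous row count. The function name `grouped_gemm_reference` is hypothetical and not part of unsloth. It only spells out the per-expert loop that the kernel folds into one launch, accumulating in float32 to mirror the default acc_dtype.

```python
import numpy as np


def grouped_gemm_reference(x, w, m_sizes):
    """Hypothetical reference: Y[rows of expert e] = X[rows of expert e] @ W[e].T.

    x:       [M_total, K], rows already grouped by expert (expert 0 first)
    w:       [E, N, K], one weight matrix per expert
    m_sizes: [E], rows assigned to each expert; must sum to M_total
    """
    m_total, k = x.shape
    num_experts, n, k_w = w.shape
    assert k == k_w, "X and W must share the K dimension"
    assert len(m_sizes) == num_experts and int(np.sum(m_sizes)) == m_total
    y = np.empty((m_total, n), dtype=np.float32)
    start = 0
    for e, m_e in enumerate(m_sizes):
        end = start + int(m_e)
        # One ordinary GEMM per expert, accumulated in float32.
        y[start:end] = x[start:end].astype(np.float32) @ w[e].T.astype(np.float32)
        start = end
    return y


# Usage: 3 experts with uneven row counts (expert 1 receives no rows).
rng = np.random.default_rng(0)
x = rng.standard_normal((10, 4)).astype(np.float32)
w = rng.standard_normal((3, 5, 4)).astype(np.float32)
y = grouped_gemm_reference(x, w, m_sizes=[6, 0, 4])
print(y.shape)  # (10, 5)
```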
Description
The forward module implements _grouped_gemm_forward_kernel, a Triton JIT kernel that computes every expert's GEMM in one launch instead of looping over experts with a separate GEMM per expert. It uses persistent tile scheduling: a fixed grid of NUM_SMS programs walks the output tiles of all expert groups in turn. It also supports fused token permutation (gather on load, scatter on store), fused multiplication by the top-k routing weights, optional TMA descriptor-based loads and stores, and accumulation in acc_dtype (float32 by default). An autotuned variant, _autotuned_grouped_gemm_forward_kernel, is provided via triton.autotune.
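Persistent scheduling can be illustrated with a plain-Python simulation. It assumes, without confirmation from this page, that program `pid` starts at global tile index `pid` and strides by NUM_SMS. It also assumes that programs walk the experts in order and number the tiles within an expert M-fastest. The helpers `cdiv` and `persistent_schedule` are hypothetical. The exact tile order in forward.py may differ, but any such scheme must preserve the invariant shown: every output tile of every non-empty expert is visited exactly once.

```python
def cdiv(a, b):
    # Ceiling division, like tl.cdiv.
    return (a + b - 1) // b


def persistent_schedule(m_sizes, n, block_m, block_n, num_sms):
    """Hypothetical simulation: map each program id to its (expert, m_tile, n_tile) list."""
    schedule = {}
    for pid in range(num_sms):
        tidx = pid       # first global tile owned by this program
        processed = 0    # tiles belonging to earlier experts
        tiles = []
        for expert, m_size in enumerate(m_sizes):
            if m_size == 0:
                continue  # experts with no rows contribute no tiles
            num_m = cdiv(m_size, block_m)
            num_n = cdiv(n, block_n)
            num_tiles = num_m * num_n
            # Take every tile of this expert whose global index lands on this program.
            while processed <= tidx < processed + num_tiles:
                local = tidx - processed
                tiles.append((expert, local % num_m, local // num_m))
                tidx += num_sms  # stride by the grid size
            processed += num_tiles
        schedule[pid] = tiles
    return schedule


# Usage: expert 0 has 4 tiles, expert 1 is empty, expert 2 has 2 tiles; 3 programs.
sched = persistent_schedule([5, 0, 3], n=8, block_m=4, block_n=4, num_sms=3)
for pid, tiles in sched.items():
    print(pid, tiles)
# 0 [(0, 0, 0), (0, 1, 1)]
# 1 [(0, 1, 0), (2, 0, 0)]
# 2 [(0, 0, 1), (2, 0, 1)]
```

The grid size stays fixed however tokens are distributed across experts, which is the usual motivation for persistent kernels.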
Usage
This kernel is invoked internally by grouped_gemm_forward in the interface module during MoE forward passes. It is not typically called directly.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/kernels/forward.py
- Lines: 1-267
Signature
```python
@triton.jit
def _grouped_gemm_forward_kernel(
    x_ptr, w_ptr, y_ptr, m_sizes_ptr,
    gather_indices_ptr, topk_weights_ptr,
    NUM_EXPERTS: tl.constexpr, NUM_TOKENS,
    TOPK: tl.constexpr, N: tl.constexpr, K: tl.constexpr, NUM_SMS,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    PERMUTE_X: tl.constexpr = False, PERMUTE_Y: tl.constexpr = False,
    FUSE_MUL_PRE: tl.constexpr = False, FUSE_MUL_POST: tl.constexpr = False,
    USE_FAST_ACCUM: tl.constexpr = False,
    USE_TMA_LOAD_W: tl.constexpr = False, USE_TMA_LOAD_X: tl.constexpr = False,
    USE_TMA_STORE: tl.constexpr = False,
    acc_dtype: tl.constexpr = tl.float32, FLATTEN: tl.constexpr = True,
) -> None:
    """Forward pass: Y = X @ W^T per expert with fused permutation."""
```
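The PERMUTE_X, PERMUTE_Y and FUSE_MUL_POST flags can be read as extra steps folded into the GEMM's loads and stores. The NumPy sketch below spells them out under an assumed indexing convention: gather_indices[i] is the flattened slot `token * TOPK + k` of the i-th expert-sorted row, and topk_weights is flattened in the same token-major order. The function name is hypothetical. FUSE_MUL_PRE and the TMA flags are not modeled, and the real kernel may apply these steps differently.

```python
import numpy as np


def fused_forward_reference(x, w, m_sizes, gather_indices, topk_weights, topk,
                            permute_x=False, permute_y=False, fuse_mul_post=False):
    """Hypothetical sketch of the permutation and weight-multiplication fusions.

    Assumed convention: gather_indices[i] = token * topk + k for the i-th row in
    expert-sorted order; topk_weights is flattened the same way.
    """
    gather_indices = np.asarray(gather_indices)
    if permute_x:
        # Gather on load: sorted row i reads token gather_indices[i] // topk.
        x_sorted = x[gather_indices // topk]
    else:
        x_sorted = x  # caller already grouped X by expert
    y_sorted = np.empty((x_sorted.shape[0], w.shape[1]), dtype=np.result_type(x, w))
    start = 0
    for e, m_e in enumerate(m_sizes):
        end = start + int(m_e)
        y_sorted[start:end] = x_sorted[start:end] @ w[e].T
        start = end
    if fuse_mul_post:
        # Scale each output row by the routing weight of its (token, k) slot.
        y_sorted = y_sorted * topk_weights[gather_indices][:, None]
    if permute_y:
        # Scatter on store: write each row back to its (token, k) position.
        y = np.empty_like(y_sorted)
        y[gather_indices] = y_sorted
        return y
    return y_sorted


# Usage: 3 tokens, TOPK = 2, 2 experts.
rng = np.random.default_rng(1)
x = rng.standard_normal((3, 4))                 # [NUM_TOKENS, K]
w = rng.standard_normal((2, 5, 4))              # [E, N, K]
topk_ids = np.array([[0, 1], [1, 0], [0, 1]])   # expert chosen for each (token, k)
flat = topk_ids.ravel()
gather_indices = np.argsort(flat, kind="stable")
m_sizes = np.bincount(flat, minlength=2)
topk_weights = rng.random(flat.size)
y = fused_forward_reference(x, w, m_sizes, gather_indices, topk_weights, topk=2,
                            permute_x=True, permute_y=True, fuse_mul_post=True)
print(y.shape)  # (6, 5)
```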
Import
```python
from unsloth.kernels.moe.grouped_gemm.kernels.forward import (
    _grouped_gemm_forward_kernel,
    _autotuned_grouped_gemm_forward_kernel,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
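| x_ptr | pointer | Yes | Input activations: [M_total, K] rows already grouped by expert, or token-ordered rows gathered into expert order on load when PERMUTE_X is set |
| w_ptr | pointer | Yes | Expert weights [E, N, K], where E = NUM_EXPERTS |
| m_sizes_ptr | pointer | Yes | Rows assigned to each expert [E]; sums to M_total |
| gather_indices_ptr | pointer | No | Token-to-expert mapping indices; read only when PERMUTE_X or PERMUTE_Y is set |
| topk_weights_ptr | pointer | No | Top-k routing weights; read only when FUSE_MUL_PRE or FUSE_MUL_POST is set |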
Outputs
| Name | Type | Description |
|---|---|---|
| y_ptr | pointer | Output activations [M_total, N]; scattered back to token order on store when PERMUTE_Y is set |
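The contract does not say how m_sizes and gather_indices are produced. One common construction, assumed here rather than taken from forward.py, sorts the flattened top-k expert ids with a stable sort and counts the rows per expert. `routing_metadata` is a hypothetical helper, and the int32 output dtype is an assumption.

```python
import numpy as np


def routing_metadata(topk_ids, num_experts):
    """Hypothetical helper: derive m_sizes and gather_indices from router top-k choices.

    topk_ids: [num_tokens, topk] integer expert ids chosen per token.
    Returns (m_sizes [E], gather_indices [num_tokens * topk]).
    """
    flat = np.asarray(topk_ids).ravel()  # slot s = token * topk + k
    # A stable sort keeps tokens in their original order within each expert.
    gather_indices = np.argsort(flat, kind="stable")
    m_sizes = np.bincount(flat, minlength=num_experts)
    return m_sizes.astype(np.int32), gather_indices.astype(np.int32)


topk_ids = np.array([[0, 2], [1, 0], [2, 1]])  # 3 tokens, topk = 2, 3 experts
m_sizes, gather_indices = routing_metadata(topk_ids, num_experts=3)
print(m_sizes)         # [2 2 2]
print(gather_indices)  # [0 3 2 5 1 4]
```

Under this convention, the token that feeds sorted row i is `gather_indices[i] // topk`.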
Usage Examples
Via Interface Module
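```python
import torch
from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm_forward

# 2048 rows already grouped by expert (M_total), K = 4096
X = torch.randn(2048, 4096, dtype=torch.bfloat16, device="cuda")
# 8 experts, each with an [N, K] = [14336, 4096] weight matrix
W = torch.randn(8, 14336, 4096, dtype=torch.bfloat16, device="cuda")
# Rows per expert; must sum to X.shape[0]
m_sizes = torch.tensor([256] * 8, dtype=torch.int32, device="cuda")
Y = grouped_gemm_forward(X, W, topk=2, m_sizes=m_sizes, autotune=True)  # [2048, 14336]
```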