
Implementation:Unslothai Unsloth GEMM Backward Kernels

From Leeroopedia


Knowledge Sources
Domains MoE, Triton_Kernels, Backpropagation
Last Updated 2026-02-07 08:40 GMT

Overview

Triton JIT kernels that implement the backward pass of grouped GEMM operations in Mixture-of-Experts (MoE) expert layers, computing input gradients (dX) and weight gradients (dW).

Description

The backward module implements two Triton JIT kernels. _grouped_gemm_dX_kernel computes dX = dY @ W per expert to propagate gradients to the inputs, and _grouped_gemm_dW_kernel computes dW = dY^T @ X per expert to produce the weight gradients. These orientations follow from the shapes in the I/O contract: with W stored as [E, N, K], the forward product is X @ W^T, so dX is [M_total, K] and each expert's dW is [N, K]. Both kernels use persistent tile scheduling across SMs, use float32 accumulators, and support fused token permutation and optional TMA (Tensor Memory Accelerator) descriptors. Each kernel also has an autotuned variant created via triton.autotune with a pruned configuration search space.
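To make the idea of persistent tile scheduling concrete, the sketch below enumerates the dX output tiles of every expert and distributes them over a fixed number of programs in a strided loop. It is plain Python rather than Triton, and the helper names (`cdiv`, `dx_tiles`, `persistent_schedule`) as well as the tile ordering are assumptions made for illustration; the page does not state the order the real kernel uses.

```python
def cdiv(a, b):
    """Ceiling division, used to count tiles along one dimension."""
    return (a + b - 1) // b


def dx_tiles(m_sizes, K, block_m, block_k):
    """List the (expert, m_tile, k_tile) output tiles of dX in one flat order.

    dX has one [m, K] block per expert, where m is that expert's token count,
    so experts with zero tokens contribute no tiles.
    """
    tiles = []
    for expert, m in enumerate(m_sizes):
        for tile_m in range(cdiv(m, block_m)):
            for tile_k in range(cdiv(K, block_k)):
                tiles.append((expert, tile_m, tile_k))
    return tiles


def persistent_schedule(tiles, num_sms):
    """Program p processes tiles p, p + num_sms, p + 2 * num_sms, ..."""
    return [tiles[p::num_sms] for p in range(num_sms)]


# Hypothetical sizes: 3 experts with 5, 0 and 3 tokens, K = 8, 4x4 tiles, 4 programs.
m_sizes = [5, 0, 3]
tiles = dx_tiles(m_sizes, K=8, block_m=4, block_k=4)
schedule = persistent_schedule(tiles, num_sms=4)

for program_id, work in enumerate(schedule):
    print(program_id, work)
```

The point of the pattern is that the launch grid is sized by the number of SMs rather than by the number of tiles, so each program loops over several tiles and uneven per-expert token counts are spread across the device.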

Usage

These kernels are invoked internally by the GroupedGemm autograd Function during backward passes. They are not typically called directly by users.

Code Reference

Source Location

Signature

@triton.jit
def _grouped_gemm_dX_kernel(
    dY_ptr, w_ptr, dX_ptr, gather_indices_ptr, m_sizes_ptr,
    NUM_EXPERTS: tl.constexpr, NUM_TOKENS, TOPK: tl.constexpr,
    N: tl.constexpr, K: tl.constexpr, NUM_SMS,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    PERMUTE_X: tl.constexpr = False, PERMUTE_Y: tl.constexpr = False,
    USE_TMA_LOAD_W: tl.constexpr = False, USE_TMA_LOAD_dY: tl.constexpr = False,
    USE_TMA_STORE: tl.constexpr = False, FLATTEN: tl.constexpr = True,
) -> None:
    """Computes input gradients: dX = dY @ W per expert (W is [N, K])."""

@triton.jit
def _grouped_gemm_dW_kernel(
    x_ptr, dY_ptr, dW_ptr, m_sizes_ptr, gather_indices_ptr,
    NUM_TOKENS, TOPK: tl.constexpr, NUM_EXPERTS: tl.constexpr,
    N: tl.constexpr, K: tl.constexpr, NUM_SMS,
    BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    PERMUTE_X: tl.constexpr = False, PERMUTE_Y: tl.constexpr = False,
    USE_TMA_LOAD_dY: tl.constexpr = False, USE_TMA_LOAD_X: tl.constexpr = False,
    USE_TMA_STORE: tl.constexpr = False, FLATTEN: tl.constexpr = True,
    acc_dtype: tl.constexpr = tl.float32,
) -> None:
    """Computes weight gradients: dW = dY^T @ X per expert (dW is [N, K])."""

Import

from unsloth.kernels.moe.grouped_gemm.kernels.backward import (
    _grouped_gemm_dX_kernel,
    _grouped_gemm_dW_kernel,
    _autotuned_grouped_gemm_dX_kernel,
    _autotuned_grouped_gemm_dW_kernel,
)

I/O Contract

Inputs (dX kernel)

Name | Type | Required | Description
--- | --- | --- | ---
dY_ptr | pointer | Yes | Upstream gradients [M_total, N]
w_ptr | pointer | Yes | Expert weights [E, N, K]
gather_indices_ptr | pointer | Yes | Token-to-expert mapping
m_sizes_ptr | pointer | Yes | Token counts per expert [E]

Inputs (dW kernel)

Name | Type | Required | Description
--- | --- | --- | ---
x_ptr | pointer | Yes | Input activations [M_total, K]
dY_ptr | pointer | Yes | Upstream gradients [M_total, N]
m_sizes_ptr | pointer | Yes | Token counts per expert [E]
gather_indices_ptr | pointer | Yes | Token-to-expert mapping
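The two routing inputs can be illustrated with a small plain-Python sketch. It assumes, as a hypothesis not confirmed by this page, that gather_indices is the permutation that sorts the flattened (token, top-k slot) assignments by expert, and that m_sizes counts the rows each expert receives. The function name `route` and the `offsets` list are illustrative only.

```python
def route(expert_ids, num_experts):
    """Build routing metadata from a flat list of chosen experts.

    expert_ids has length num_tokens * topk; entry i is the expert chosen
    for token i // topk in top-k slot i % topk.
    """
    # Stable sort of slot indices by expert: rows for expert 0 come first, etc.
    gather_indices = sorted(range(len(expert_ids)), key=lambda i: expert_ids[i])

    # Number of rows assigned to each expert.
    m_sizes = [0] * num_experts
    for expert in expert_ids:
        m_sizes[expert] += 1

    # Row offsets: expert e owns rows offsets[e] .. offsets[e + 1] - 1.
    offsets = [0]
    for m in m_sizes:
        offsets.append(offsets[-1] + m)
    return gather_indices, m_sizes, offsets


topk = 2
expert_ids = [0, 2, 1, 0, 2, 2]  # 3 tokens, 2 experts each
gather_indices, m_sizes, offsets = route(expert_ids, num_experts=3)

# Original token behind each expert-grouped row.
token_ids = [i // topk for i in gather_indices]
print(gather_indices, m_sizes, offsets, token_ids)
```

Under this assumption, M_total equals num_tokens * topk, and the per-expert row ranges in `offsets` are what a grouped GEMM walks when it processes one expert after another.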

Outputs

Name | Type | Description
--- | --- | ---
dX | torch.Tensor | Input gradients [M_total, K]
dW | torch.Tensor | Weight gradients [E, N, K]
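The shapes above can be checked against a small reference computation. The following is a minimal plain-Python sketch, not the Triton implementation: it assumes rows are already grouped by expert (no permutation), that the forward pass is Y = X @ W^T per expert, and it uses hypothetical helper names (`matmul`, `transpose`, `grouped_gemm_backward_reference`).

```python
def matmul(a, b):
    """Matrix product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]


def grouped_gemm_backward_reference(X, W, dY, m_sizes):
    """Per-expert backward pass for Y = X @ W^T.

    X:  [M_total, K] rows grouped by expert
    W:  [E][N][K] expert weights
    dY: [M_total, N] upstream gradients
    Returns dX as [M_total, K] and dW as [E][N][K].
    """
    dX, dW, start = [], [], 0
    for expert, m in enumerate(m_sizes):
        x_e = X[start:start + m]
        dy_e = dY[start:start + m]
        # [m, N] @ [N, K] -> [m, K]
        dX.extend(matmul(dy_e, W[expert]))
        if m:
            # [N, m] @ [m, K] -> [N, K]
            dW.append(matmul(transpose(dy_e), x_e))
        else:
            # An expert with no tokens receives a zero gradient.
            dW.append([[0] * len(row) for row in W[expert]])
        start += m
    return dX, dW


# Tiny example: E = 3 experts, N = 3, K = 2, token counts 2, 0 and 1.
m_sizes = [2, 0, 1]
X = [[1, 2], [3, 4], [5, 6]]
dY = [[1, 0, 0], [0, 1, 1], [1, 1, 1]]
W = [
    [[1, 0], [0, 1], [1, 1]],
    [[2, 2], [2, 2], [2, 2]],
    [[1, -1], [0, 2], [3, 0]],
]
dX, dW = grouped_gemm_backward_reference(X, W, dY, m_sizes)
print(dX)
print(dW)
```

A loop like this is mainly useful as a correctness oracle for small inputs; the kernels described on this page fuse the per-expert products into a single launch instead.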

Usage Examples

Autotuned Variant (via autotune wrapper)

# These kernels are typically invoked via the interface module:
from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm_dX, grouped_gemm_dW

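# Input gradient computation: dY is [M_total, N], W is [E, N, K], dX is [M_total, K]
dX = grouped_gemm_dX(dY, W, gather_indices, m_sizes, topk=2, autotune=True)

# Weight gradient computation: X is [M_total, K], dY is [M_total, N], dW is [E, N, K]
dW = grouped_gemm_dW(X, dY, m_sizes, gather_indices, topk=2, autotune=True)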
