Implementation:Unslothai Unsloth GEMM Backward Kernels
| Knowledge Sources | |
|---|---|
| Domains | MoE, Triton_Kernels, Backpropagation |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
Triton JIT backward kernels computing input gradients (dX) and weight gradients (dW) for grouped GEMM operations in MoE expert layers.
Description
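The backward module implements two Triton JIT kernels. _grouped_gemm_dX_kernel computes dY @ W^T per expert to propagate gradients to the inputs, and _grouped_gemm_dW_kernel computes X^T @ dY per expert to produce the weight gradients. Both formulas treat each expert's weight as the [K, N] matrix of the forward product. In terms of the stored layouts listed under I/O Contract (W and dW as [E, N, K]), the same products read dY @ W[e], of shape [m_e, K], and dY^T @ X, of shape [N, K], where m_e is the number of token rows routed to expert e. Both kernels use persistent tile scheduling across SMs, accumulate in float32, and support fused token permutation and optional TMA descriptors. Each kernel also has an autotuned variant created with triton.autotune over a pruned configuration search space.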
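The per-expert math can be checked against a plain NumPy reference. The sketch below is not the Triton implementation. It assumes the rows of X and dY are already grouped by expert, that the forward pass is Y_e = X_e @ W[e]^T with W stored as [E, N, K] (inferred from the shapes under I/O Contract), and it ignores permutation, TMA, and tiling. The function name `grouped_gemm_backward_reference` is hypothetical.

```python
import numpy as np

def grouped_gemm_backward_reference(X, W, dY, m_sizes):
    """Loop-over-experts reference for the grouped GEMM backward pass.

    X:  [M_total, K] activations, rows grouped by expert
    W:  [E, N, K]    expert weights (assumed forward: Y_e = X_e @ W[e].T)
    dY: [M_total, N] upstream gradients, same row order as X
    m_sizes: [E]     number of rows owned by each expert
    """
    dX = np.zeros_like(X)
    dW = np.zeros_like(W)
    start = 0
    for e, m in enumerate(m_sizes):
        rows = slice(start, start + int(m))
        dX[rows] = dY[rows] @ W[e]      # [m, N] @ [N, K] -> [m, K]
        dW[e] = dY[rows].T @ X[rows]    # [N, m] @ [m, K] -> [N, K]
        start += int(m)
    return dX, dW

rng = np.random.default_rng(0)
E, N, K = 3, 4, 5
m_sizes = np.array([2, 0, 3])           # expert 1 receives no tokens
M_total = int(m_sizes.sum())
X = rng.standard_normal((M_total, K))
W = rng.standard_normal((E, N, K))
dY = rng.standard_normal((M_total, N))

dX, dW = grouped_gemm_backward_reference(X, W, dY, m_sizes)
print(dX.shape, dW.shape)
```

An expert with no tokens contributes an all-zero slice of dW, which is worth keeping in mind when comparing kernel output against a reference like this one.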
Usage
These kernels are invoked internally by the GroupedGemm autograd Function during backward passes. They are not typically called directly by users.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/kernels/backward.py
- Lines: 1-505
Signature
```python
@triton.jit
def _grouped_gemm_dX_kernel(
    dY_ptr, w_ptr, dX_ptr, gather_indices_ptr, m_sizes_ptr,
    NUM_EXPERTS: tl.constexpr, NUM_TOKENS, TOPK: tl.constexpr,
    N: tl.constexpr, K: tl.constexpr, NUM_SMS,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    PERMUTE_X: tl.constexpr = False, PERMUTE_Y: tl.constexpr = False,
    USE_TMA_LOAD_W: tl.constexpr = False, USE_TMA_LOAD_dY: tl.constexpr = False,
    USE_TMA_STORE: tl.constexpr = False, FLATTEN: tl.constexpr = True,
) -> None:
    """Computes input gradients: dX = dY @ W^T per expert."""

@triton.jit
def _grouped_gemm_dW_kernel(
    x_ptr, dY_ptr, dW_ptr, m_sizes_ptr, gather_indices_ptr,
    NUM_TOKENS, TOPK: tl.constexpr, NUM_EXPERTS: tl.constexpr,
    N: tl.constexpr, K: tl.constexpr, NUM_SMS,
    BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    PERMUTE_X: tl.constexpr = False, PERMUTE_Y: tl.constexpr = False,
    USE_TMA_LOAD_dY: tl.constexpr = False, USE_TMA_LOAD_X: tl.constexpr = False,
    USE_TMA_STORE: tl.constexpr = False, FLATTEN: tl.constexpr = True,
    acc_dtype: tl.constexpr = tl.float32,
) -> None:
    """Computes weight gradients: dW = X^T @ dY per expert."""
```
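The NUM_SMS argument reflects the persistent scheduling mentioned in the Description: a fixed number of programs is launched and each one walks over several output tiles. The pure-Python model below illustrates one common way to do this, a strided assignment over the dX output tiles. It is a simplified sketch under that assumption, not a transcription of the kernel's loop, and `persistent_tile_schedule` is a hypothetical helper name.

```python
import math

def persistent_tile_schedule(m_sizes, out_cols, block_rows, block_cols, num_sms):
    """Model which output tiles each persistent program would process.

    Tiles are enumerated expert by expert as (expert, tile_row, tile_col).
    Program `pid` takes tiles pid, pid + num_sms, pid + 2 * num_sms, ...
    """
    tiles = []
    tiles_per_row = math.ceil(out_cols / block_cols)
    for expert, m in enumerate(m_sizes):
        # An expert with m == 0 rows contributes no tiles.
        for tile_row in range(math.ceil(m / block_rows)):
            for tile_col in range(tiles_per_row):
                tiles.append((expert, tile_row, tile_col))
    return [tiles[pid::num_sms] for pid in range(num_sms)]

# Three experts owning 5, 0, and 3 token rows; dX has 8 columns.
schedule = persistent_tile_schedule(
    m_sizes=[5, 0, 3], out_cols=8, block_rows=4, block_cols=4, num_sms=4
)
for pid, work in enumerate(schedule):
    print(pid, work)
```

Because the launch grid is tied to the SM count rather than the tile count, uneven token counts per expert change how many tiles each program visits but not how many programs are launched.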
Import
```python
from unsloth.kernels.moe.grouped_gemm.kernels.backward import (
    _grouped_gemm_dX_kernel,
    _grouped_gemm_dW_kernel,
    _autotuned_grouped_gemm_dX_kernel,
    _autotuned_grouped_gemm_dW_kernel,
)
```
I/O Contract
Inputs (dX kernel)
| Name | Type | Required | Description |
|---|---|---|---|
| dY_ptr | pointer | Yes | Upstream gradients [M_total, N] |
| w_ptr | pointer | Yes | Expert weights [E, N, K] |
| gather_indices_ptr | pointer | Yes | Token-to-expert mapping |
| m_sizes_ptr | pointer | Yes | Token counts per expert [E] |
Inputs (dW kernel)
| Name | Type | Required | Description |
|---|---|---|---|
| x_ptr | pointer | Yes | Input activations [M_total, K] |
| dY_ptr | pointer | Yes | Upstream gradients [M_total, N] |
| m_sizes_ptr | pointer | Yes | Token counts per expert [E] |
| gather_indices_ptr | pointer | Yes | Token-to-expert mapping |
Outputs
| Name | Type | Description |
|---|---|---|
| dX | torch.Tensor | Input gradients [M_total, K] |
| dW | torch.Tensor | Weight gradients [E, N, K] |
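To make the roles of m_sizes_ptr and gather_indices_ptr concrete, the sketch below derives both from a top-k routing decision with NumPy. This is an assumption about how such inputs are commonly built (a per-expert histogram plus a stable sort of the flattened expert assignments). It is not taken from the Unsloth source, and `routing_indices` is a hypothetical helper name.

```python
import numpy as np

def routing_indices(selected_experts, num_experts):
    """Build per-expert token counts and expert-grouped gather indices.

    selected_experts: [num_tokens, topk] expert id chosen for each token slot
    Returns (m_sizes [E], gather_indices [num_tokens * topk]).
    """
    flat = selected_experts.reshape(-1)
    m_sizes = np.bincount(flat, minlength=num_experts)
    # A stable sort keeps token order within each expert's group.
    gather_indices = np.argsort(flat, kind="stable")
    return m_sizes, gather_indices

# 3 tokens, topk = 2, 3 experts
selected_experts = np.array([[0, 2], [1, 2], [0, 1]])
topk = selected_experts.shape[1]
m_sizes, gather_indices = routing_indices(selected_experts, num_experts=3)

# Each gathered slot maps back to its source token by integer division.
token_ids = gather_indices // topk
print(m_sizes, gather_indices, token_ids)
```

Under this construction, m_sizes sums to num_tokens * topk, which corresponds to the M_total dimension in the tables above.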
Usage Examples
Autotuned Variant (via autotune wrapper)
```python
# These kernels are typically invoked via the interface module:
from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm_dX, grouped_gemm_dW

# dY, W, X, gather_indices, and m_sizes are assumed to be existing CUDA tensors
# with the shapes listed under I/O Contract.

# Input gradient computation
dX = grouped_gemm_dX(dY, W, gather_indices, m_sizes, topk=2, autotune=True)

# Weight gradient computation
dW = grouped_gemm_dW(X, dY, m_sizes, gather_indices, topk=2, autotune=True)
```