Implementation:Unslothai Unsloth MoE Ops
| Knowledge Sources | |
|---|---|
| Domains | MoE, Token_Routing |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
Utility functions implementing the fundamental Mixture-of-Experts (MoE) operations: token permutation, routing index computation, top-k expert selection, and a reference torch-native grouped GEMM.
Description
The moe_ops module provides the building blocks used by all MoE reference implementations. permute gathers token rows into expert order; unpermute scatters them back to their original positions using index_copy_. calculate_topk selects the top-k experts per token from the gating output, applying a sigmoid or softmax activation and optionally renormalizing the selected weights. get_routing_indices uses torch.histc and torch.argsort to compute per-expert token counts and gather indices. torch_grouped_gemm is a pure PyTorch reference implementation of grouped matrix multiplication that iterates over experts and runs one matmul per expert.
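The routing and permutation steps described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the module's code: it uses `torch.bincount` in place of `torch.histc` so that it runs on integer CPU tensors, and the tensor values are made up for the example.

```python
import torch

num_tokens, hidden_dim, num_experts, topk = 4, 3, 3, 2

# Made-up router output: expert ids per token, shape [num_tokens, topk]
selected = torch.tensor([[2, 0], [1, 2], [0, 1], [2, 1]])
flat = selected.view(-1)  # [num_tokens * topk], one entry per token-expert slot

# Per-expert token counts (bincount stands in for torch.histc here)
counts = torch.bincount(flat, minlength=num_experts)

# A stable argsort groups the flattened slots by expert id
gather_indices = torch.argsort(flat, stable=True)

X = torch.arange(num_tokens * hidden_dim, dtype=torch.float32).view(num_tokens, hidden_dim)

# Gather into expert order: slot s belongs to token s // topk
X_perm = X[gather_indices // topk]  # [num_tokens * topk, hidden_dim]

# Scatter back to slot order with index_copy_
X_back = torch.empty_like(X_perm)
X_back.index_copy_(0, gather_indices, X_perm)
```

After the round trip, `X_back` holds each token row repeated `topk` times, one row per token-expert slot.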
Usage
Import these functions when implementing custom MoE layers, or when you need reference implementations to validate the correctness of Triton kernels.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
- Lines: 1-151
Signature
def permute(
    X: torch.Tensor, gather_indices: torch.Tensor, topk: int,
) -> torch.Tensor:
    """Reorders tokens from token order to expert order."""

def unpermute(
    X: torch.Tensor, gather_indices: torch.Tensor,
) -> torch.Tensor:
    """Inverse of permute: returns tokens to original order."""

def calculate_topk(
    gating_output: torch.Tensor, top_k: int,
    use_sigmoid: bool, renormalize: bool,
    pre_act: bool = True, post_act: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes top-k expert selection and routing weights."""

@torch.no_grad()
def get_routing_indices(
    selected_experts, num_experts,
    return_scatter_indices: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Computes token counts per expert and gather indices."""

def torch_grouped_gemm(
    X, W, m_sizes, transpose=True,
) -> torch.Tensor:
    """Reference grouped GEMM: iterates through experts for individual matmuls."""
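To make the top-k selection concrete, the following sketch shows one way to apply an activation, select the top-k experts, and renormalize the selected weights. `topk_sketch` is a hypothetical helper written for this page, not the module's `calculate_topk`; it only models the case where the activation runs before top-k and does not model the `pre_act` / `post_act` switch.

```python
import torch
import torch.nn.functional as F

def topk_sketch(gating_output, top_k, use_sigmoid, renormalize):
    """Hypothetical helper: activation first, then top-k, then optional renormalization."""
    logits = gating_output.float()  # compute scores in float32
    scores = torch.sigmoid(logits) if use_sigmoid else F.softmax(logits, dim=-1)
    weights, ids = torch.topk(scores, k=top_k, dim=-1)
    if renormalize:
        # Make the selected weights of each token sum to 1
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights.to(gating_output.dtype), ids

# Made-up gating logits: 2 tokens, 4 experts
logits = torch.tensor([[2.0, 0.5, 1.0, -1.0], [0.0, 3.0, 0.1, 0.2]])
weights, ids = topk_sketch(logits, top_k=2, use_sigmoid=False, renormalize=True)
```

Both returned tensors have shape `[num_tokens, top_k]`: `weights` holds the routing weights and `ids` the selected expert indices.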
Import
from unsloth.kernels.moe.grouped_gemm.reference.moe_ops import (
    permute, unpermute, calculate_topk,
    get_routing_indices, torch_grouped_gemm,
)
I/O Contract
Inputs (permute)
| Name | Type | Required | Description |
|---|---|---|---|
| X | torch.Tensor | Yes | Input embeddings [num_tokens, hidden_dim] |
| gather_indices | torch.Tensor | Yes | Indices that sort the flattened token-expert assignments into expert order, as returned by get_routing_indices [num_tokens * topk] |
| topk | int | Yes | Number of experts per token |
Inputs (get_routing_indices)
| Name | Type | Required | Description |
|---|---|---|---|
| selected_experts | torch.Tensor | Yes | Expert assignments [num_tokens * topk] |
| num_experts | int | Yes | Total number of experts |
| return_scatter_indices | bool | No | Also return inverse permutation (default: False) |
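The relationship between the gather indices and the optional inverse permutation can be illustrated as follows. This sketch assumes the scatter indices are the argsort of the gather indices, which is one standard way to invert a permutation; the source does not state how the module computes them.

```python
import torch

# Made-up flattened expert assignments for 6 token-expert slots
selected = torch.tensor([1, 0, 2, 0, 1, 2])

# Gather indices: a stable sort of the slots by expert id
gather_indices = torch.argsort(selected, stable=True)

# Assumption: the inverse permutation is the argsort of the gather indices
scatter_indices = torch.argsort(gather_indices)

sorted_experts = selected[gather_indices]   # expert order
restored = sorted_experts[scatter_indices]  # back to the original slot order
```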
Outputs
| Name | Type | Description |
|---|---|---|
| permute returns | torch.Tensor | Tokens reordered by expert [total_tokens, hidden_dim], where total_tokens = num_tokens * topk |
| get_routing_indices returns | tuple | (token_counts_by_expert, gather_indices[, scatter_indices]) |
| torch_grouped_gemm returns | torch.Tensor | Output Y of shape [M, N]: the per-expert matmul results, stacked in expert order |
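The grouped GEMM output can be understood through a small loop over experts. `grouped_gemm_sketch` below is a hypothetical re-implementation for illustration only. It assumes the rows of `X` are already in expert order, that `W` has shape `[E, N, K]`, and that each expert's weight matrix is transposed before the matmul.

```python
import torch

def grouped_gemm_sketch(X, W, m_sizes):
    """Hypothetical sketch: one matmul per expert over contiguous row groups."""
    M, N = X.shape[0], W.shape[1]
    Y = torch.zeros(M, N, dtype=X.dtype, device=X.device)
    start = 0
    for e, size in enumerate(m_sizes.tolist()):
        if size > 0:
            end = start + size
            # [size, K] @ [K, N] -> [size, N]
            Y[start:end] = X[start:end] @ W[e].T
            start = end
    return Y

torch.manual_seed(0)
X = torch.randn(6, 4)              # M = 6 rows in expert order, K = 4
W = torch.randn(3, 5, 4)           # E = 3 experts, N = 5, K = 4
m_sizes = torch.tensor([2, 0, 4])  # rows per expert; expert 1 receives none
Y = grouped_gemm_sketch(X, W, m_sizes)
```

Experts with no assigned rows are skipped, so their weights never contribute to `Y`.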
Usage Examples
Token Permutation and Grouped GEMM
import torch

from unsloth.kernels.moe.grouped_gemm.reference.moe_ops import (
    permute, unpermute, get_routing_indices, torch_grouped_gemm,
)

num_tokens, hidden_dim, num_experts, topk = 1024, 4096, 8, 2
X = torch.randn(num_tokens, hidden_dim, device="cuda")
# One weight matrix per expert: [num_experts, out_dim, hidden_dim]
W = torch.randn(num_experts, 14336, hidden_dim, device="cuda")

# Compute routing (random expert assignments stand in for a real router)
selected = torch.randint(0, num_experts, (num_tokens * topk,), device="cuda")
counts, indices = get_routing_indices(selected, num_experts)

# Permute tokens to expert order: [num_tokens * topk, hidden_dim]
X_perm = permute(X, indices, topk)

# Grouped GEMM: [num_tokens * topk, 14336]
Y = torch_grouped_gemm(X_perm, W, counts)

# Unpermute back to token order (still one row per token-expert pair)
Y_out = unpermute(Y, indices)
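The example above stops with one output row per token-expert pair. The source does not show how these rows are merged per token; a common approach, sketched here as an assumption, is a weighted sum over the `topk` axis using the routing weights. All values below are made up, and the variable names are illustrative.

```python
import torch

num_tokens, topk, out_dim = 3, 2, 4

# Stand-ins for the unpermuted expert outputs and the routing weights
Y_out = torch.arange(num_tokens * topk * out_dim, dtype=torch.float32).view(
    num_tokens * topk, out_dim
)
weights = torch.tensor([[0.75, 0.25], [0.5, 0.5], [1.0, 0.0]])

# Rows are grouped per token, so a view exposes the topk axis
per_token = Y_out.view(num_tokens, topk, out_dim)

# Weighted sum over the selected experts: [num_tokens, out_dim]
combined = (per_token * weights.unsqueeze(-1)).sum(dim=1)
```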