Implementation:Unslothai Unsloth MoE Ops

From Leeroopedia


Knowledge Sources
Domains MoE, Token_Routing
Last Updated 2026-02-07 08:40 GMT

Overview

Utility functions implementing the fundamental Mixture-of-Experts (MoE) operations: token permutation, routing index computation, top-k expert selection, and a reference grouped GEMM written in native PyTorch.

Description

The moe_ops module provides the building blocks shared by the MoE reference implementations. permute gathers token rows into expert order, so that all rows routed to the same expert are contiguous; unpermute scatters them back to their original positions using index_copy_. calculate_topk selects the top-k experts per token from the gating output, applying a sigmoid or softmax activation and optionally renormalizing the selected weights. get_routing_indices uses torch.histc to count tokens per expert and torch.argsort to derive the gather indices. torch_grouped_gemm is a pure PyTorch reference for grouped matrix multiplication that loops over the experts and performs one matmul per expert.
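The routing and permutation steps described above can be sketched in a few lines of plain PyTorch. This is a minimal illustration, not the module's code: the `*_sketch` names are hypothetical, `torch.bincount` stands in for `torch.histc` so the example runs on integer CPU tensors, and it assumes routing slots are laid out token-major, so that slot `i` belongs to token `i // topk`.

```python
import torch

def routing_indices_sketch(selected_experts: torch.Tensor, num_experts: int):
    """Hypothetical stand-in for get_routing_indices (not the library code)."""
    flat = selected_experts.reshape(-1)
    # Assumption: bincount replaces torch.histc for integer CPU inputs.
    counts = torch.bincount(flat, minlength=num_experts)
    # A stable sort keeps slots of the same expert in their original order.
    gather_indices = torch.argsort(flat, stable=True)
    return counts, gather_indices

def permute_sketch(X: torch.Tensor, gather_indices: torch.Tensor, topk: int):
    # Assumption: slot i belongs to token i // topk (token-major layout).
    return X[gather_indices // topk]

def unpermute_sketch(X_perm: torch.Tensor, gather_indices: torch.Tensor):
    out = torch.empty_like(X_perm)
    # Row j of X_perm is written back to slot gather_indices[j].
    out.index_copy_(0, gather_indices, X_perm)
    return out

# 3 tokens, top-2 routing over 3 experts.
selected = torch.tensor([[2, 0], [1, 2], [0, 1]])
# Row t of X is filled with the value t, so rows are easy to track.
X = torch.arange(3, dtype=torch.float32).unsqueeze(1).repeat(1, 4)

counts, gather_indices = routing_indices_sketch(selected, num_experts=3)
X_perm = permute_sketch(X, gather_indices, topk=2)
X_back = unpermute_sketch(X_perm, gather_indices)

print(counts.tolist())          # tokens per expert
print(gather_indices.tolist())  # slots sorted by expert id
print(X_perm[:, 0].tolist())    # source token of each permuted row
```

After the round trip, `X_back` holds one row per (token, slot) pair in the original slot order, i.e. each token row repeated `topk` times. Combining the `topk` rows of a token into a single output is a separate step that this sketch does not cover.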

Usage

Import these functions when implementing custom MoE layers, or when you need reference implementations to validate the correctness of Triton kernels.

Code Reference

Source Location

Signature

def permute(
    X: torch.Tensor, gather_indices: torch.Tensor, topk: int,
) -> torch.Tensor:
    """Reorders tokens from token order to expert order."""

def unpermute(
    X: torch.Tensor, gather_indices: torch.Tensor,
) -> torch.Tensor:
    """Inverse of permute: returns tokens to original order."""

def calculate_topk(
    gating_output: torch.Tensor, top_k: int,
    use_sigmoid: bool, renormalize: bool,
    pre_act: bool = True, post_act: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes top-k expert selection and routing weights."""

@torch.no_grad()
def get_routing_indices(
    selected_experts, num_experts,
    return_scatter_indices: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Computes token counts per expert and gather indices."""

def torch_grouped_gemm(
    X, W, m_sizes, transpose=True,
) -> torch.Tensor:
    """Reference grouped GEMM: iterates through experts for individual matmuls."""

Import

from unsloth.kernels.moe.grouped_gemm.reference.moe_ops import (
    permute, unpermute, calculate_topk,
    get_routing_indices, torch_grouped_gemm,
)
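To make the `use_sigmoid` and `renormalize` flags of `calculate_topk` concrete, the following is a rough sketch of one common way top-k routing is computed. It is not the module's code: `topk_sketch` is a hypothetical name, the return order (weights, then expert ids) is an assumption, the activation is applied before the top-k selection, and the `pre_act` / `post_act` options are not modeled at all.

```python
import torch

def topk_sketch(gating_output: torch.Tensor, top_k: int,
                use_sigmoid: bool, renormalize: bool):
    """Hypothetical sketch: activation first, then top-k (not the library code)."""
    logits = gating_output.float()
    # Sigmoid scores each expert independently; softmax normalizes across experts.
    if use_sigmoid:
        scores = torch.sigmoid(logits)
    else:
        scores = torch.softmax(logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)
    if renormalize:
        # Rescale so the selected weights of each token sum to 1.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

# 2 tokens, 4 experts.
logits = torch.tensor([[2.0, 0.0, 1.0, -1.0],
                       [0.0, 3.0, 0.0, 1.0]])
weights, ids = topk_sketch(logits, top_k=2, use_sigmoid=False, renormalize=True)
sig_weights, sig_ids = topk_sketch(logits, top_k=2, use_sigmoid=True, renormalize=False)
```

Flattening `ids` with `ids.reshape(-1)` gives a tensor of length `num_tokens * top_k`, which is the shape that `get_routing_indices` expects for `selected_experts`.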

I/O Contract

Inputs (permute)

Name | Type | Required | Description
X | torch.Tensor | Yes | Input embeddings [num_tokens, hidden_dim]
gather_indices | torch.Tensor | Yes | Indices that sort the routing slots into expert order, as returned by get_routing_indices [num_tokens * topk]
topk | int | Yes | Number of experts per token

Inputs (get_routing_indices)

Name | Type | Required | Description
selected_experts | torch.Tensor | Yes | Expert assignments [num_tokens * topk]
num_experts | int | Yes | Total number of experts
return_scatter_indices | bool | No | Also return inverse permutation (default: False)

Outputs

Name | Type | Description
permute returns | torch.Tensor | Tokens reordered by expert [total_tokens, hidden_dim], where total_tokens = num_tokens * topk
get_routing_indices returns | tuple | (token_counts_by_expert, gather_indices[, scatter_indices])
torch_grouped_gemm returns | torch.Tensor | Y [M, N], the result of the grouped multiplications (M = total rows of X, N = output dimension of W)
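The contract of the reference grouped GEMM can be illustrated with a short loop over experts. This is a sketch under stated assumptions rather than the module's code: `grouped_gemm_sketch` is a hypothetical name, the rows of `X` are assumed to be already grouped by expert, and `W` is assumed to have shape [E, N, K] when `transpose=True` (and [E, K, N] otherwise).

```python
import torch

def grouped_gemm_sketch(X: torch.Tensor, W: torch.Tensor,
                        m_sizes: torch.Tensor, transpose: bool = True):
    """Hypothetical reference loop (not the library code).

    X: [M, K], rows grouped by expert.
    W: [E, N, K] if transpose else [E, K, N] (assumed layout).
    m_sizes: [E], number of rows of X that belong to each expert.
    """
    M = X.shape[0]
    N = W.shape[1] if transpose else W.shape[2]
    Y = torch.zeros((M, N), dtype=X.dtype, device=X.device)
    start = 0
    for e, size in enumerate(m_sizes.tolist()):
        end = start + size
        if size > 0:  # experts with no routed rows are skipped
            W_e = W[e].t() if transpose else W[e]
            Y[start:end] = X[start:end] @ W_e
        start = end
    return Y

torch.manual_seed(0)
m_sizes = torch.tensor([2, 0, 3])  # expert 1 receives no rows
X = torch.randn(5, 4)              # M = 5, K = 4
W = torch.randn(3, 6, 4)           # E = 3, N = 6, K = 4
Y = grouped_gemm_sketch(X, W, m_sizes)
```

Each slice of `Y` depends only on the matching slice of `X` and a single expert's weights, which is why a per-expert loop like this is convenient as a baseline when checking a fused kernel.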

Usage Examples

Token Permutation and Grouped GEMM

import torch

from unsloth.kernels.moe.grouped_gemm.reference.moe_ops import (
    permute, unpermute, get_routing_indices, torch_grouped_gemm,
)

num_tokens, hidden_dim, num_experts, topk = 1024, 4096, 8, 2
X = torch.randn(num_tokens, hidden_dim, device="cuda")
# One weight matrix per expert: [num_experts, out_dim, hidden_dim]
W = torch.randn(num_experts, 14336, hidden_dim, device="cuda")

# Compute routing; random expert ids stand in for a real router here
selected = torch.randint(0, num_experts, (num_tokens * topk,), device="cuda")
counts, indices = get_routing_indices(selected, num_experts)

# Permute tokens to expert order: [num_tokens * topk, hidden_dim]
X_perm = permute(X, indices, topk)

# Grouped GEMM, one matmul per expert: [num_tokens * topk, 14336]
Y = torch_grouped_gemm(X_perm, W, counts)

# Unpermute back to token order (still one row per token-expert pair)
Y_out = unpermute(Y, indices)
