

Implementation:Unslothai Unsloth Qwen3 MoE Block

From Leeroopedia


Knowledge Sources
Domains: MoE, Model_Architecture
Last Updated: 2026-02-07 08:40 GMT

Overview

Reference implementations of the Qwen3 MoE block in two variants: one built on PyTorch-native grouped GEMM and one built on Triton-optimized fused grouped GEMM kernels.

Description

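This module provides Qwen3MoeGroupedGEMMBlock (PyTorch-native) and Qwen3MoeFusedGroupedGEMMBlock (Triton-fused), both compatible with HuggingFace's Qwen3MoeSparseMoeBlock. The PyTorch version extracts the per-expert weights and stacks them into [E, N, K] tensors (E experts, N output features, K input features). Tokens are routed with a softmax over the router logits, top-k expert selection, and, when config.norm_topk_prob is set, renormalization of the selected top-k weights. The fused version replaces both expert GEMMs with calls to the fused grouped_gemm interface: permute_x fuses the token permutation (gathering tokens into expert-sorted order) into the first GEMM (gate_up_proj), and permute_y fuses the inverse permutation (back to token order) into the second GEMM (down_proj). The dW_only and dX_only flags restrict the backward pass to weight gradients or input gradients respectively, and kernel configurations are either autotuned or passed explicitly through kernel_config_fwd, kernel_config_bwd_dW, and kernel_config_bwd_dX.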
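To make the routing and permutation steps concrete, here is a minimal pure-PyTorch sketch of an MoE forward pass over stacked [E, N, K] weights. It is not Unsloth's code: the helper names route and moe_forward are hypothetical, the expert MLP is assumed to be Qwen3's SiLU-gated form, gate_up_proj is assumed to hold the gate rows before the up rows, and a Python loop over experts stands in for the grouped GEMM. The explicit gather before the first GEMM and scatter after the second are the steps that permute_x and permute_y fold into the fused kernels.

```python
import torch
import torch.nn.functional as F


def route(x: torch.Tensor, gate: torch.Tensor, top_k: int, norm_topk_prob: bool = True):
    """Softmax routing followed by top-k selection (hypothetical helper).

    x: [T, H] flattened tokens; gate: [E, H] router weight.
    Returns router_logits [T, E], topk_weights [T, k], topk_ids [T, k].
    """
    router_logits = F.linear(x, gate)  # [T, E]
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if norm_topk_prob:
        # Renormalize so each token's k selected weights sum to 1
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return router_logits, topk_weights.to(x.dtype), topk_ids


def moe_forward(hidden_states, gate, gate_up_proj, down_proj, top_k, norm_topk_prob=True):
    """Reference MoE forward (sketch). gate_up_proj: [E, 2*I, H]; down_proj: [E, H, I]."""
    batch, seq_len, H = hidden_states.shape
    x = hidden_states.reshape(-1, H)  # [T, H]
    T, E = x.shape[0], gate_up_proj.shape[0]
    router_logits, topk_weights, topk_ids = route(x, gate, top_k, norm_topk_prob)

    # Permute (what permute_x fuses): sort the T*k (token, slot) pairs by expert id
    flat_ids = topk_ids.reshape(-1)            # [T*k], flat index p = token * k + slot
    sort_idx = torch.argsort(flat_ids)         # rows of the same expert become contiguous
    x_perm = x[sort_idx // top_k]              # gather source tokens: [T*k, H]
    counts = torch.bincount(flat_ids, minlength=E).tolist()  # rows per expert

    # Grouped GEMMs, written as a loop over experts
    out_perm = torch.empty_like(x_perm)
    start = 0
    for e, n in enumerate(counts):
        if n == 0:
            continue
        rows = x_perm[start:start + n]
        g, u = (rows @ gate_up_proj[e].T).chunk(2, dim=-1)            # [n, I] each
        out_perm[start:start + n] = (F.silu(g) * u) @ down_proj[e].T  # [n, H]
        start += n

    # Unpermute (what permute_y fuses): scatter rows back to (token, slot) order
    out = torch.empty_like(out_perm)
    out[sort_idx] = out_perm
    out = (out.view(T, top_k, H) * topk_weights.unsqueeze(-1)).sum(dim=1)
    return out.view(batch, seq_len, H), router_logits
```

Sorting the (token, expert) pairs by expert id is what makes each expert's rows contiguous, so a grouped GEMM can cover every expert in one launch instead of one matmul per expert.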

Usage

Import these classes when benchmarking or deploying Qwen3 MoE models with Triton-optimized grouped GEMM kernels, or use from_hf to convert existing HuggingFace MoE blocks.
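A conversion like from_hf has to turn the HuggingFace block's list of per-expert MLPs into the stacked tensors the grouped GEMM expects. The sketch below shows one plausible way to do that stacking; it is not the body of from_hf. ToyExpert is a hypothetical stand-in mirroring the gate_proj/up_proj/down_proj Linear layout of HuggingFace's Qwen3 expert MLP, stack_expert_weights is a hypothetical helper, and concatenating gate rows before up rows is an assumption.

```python
import torch
from torch import nn


class ToyExpert(nn.Module):
    """Hypothetical stand-in for one HF Qwen3 MoE expert (three bias-free Linear layers)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)


def stack_expert_weights(experts):
    """Stack per-expert Linear weights (stored as [out, in]) into [E, N, K] tensors.

    Returns gate_up_proj [E, 2*I, H] (gate rows first, then up rows) and down_proj [E, H, I].
    """
    gate_proj = torch.stack([e.gate_proj.weight.detach() for e in experts])  # [E, I, H]
    up_proj = torch.stack([e.up_proj.weight.detach() for e in experts])      # [E, I, H]
    down_proj = torch.stack([e.down_proj.weight.detach() for e in experts])  # [E, H, I]
    gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)                    # [E, 2*I, H]
    return gate_up_proj, down_proj


# Usage: three toy experts with hidden_size=16, intermediate_size=4
experts = nn.ModuleList([ToyExpert(16, 4) for _ in range(3)])
gate_up_proj, down_proj = stack_expert_weights(experts)
print(gate_up_proj.shape, down_proj.shape)  # torch.Size([3, 8, 16]) torch.Size([3, 16, 4])
```

Keeping gate and up in one [E, 2*I, H] tensor lets the first grouped GEMM produce both projections in a single pass; the output is then split in half before the activation.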

Code Reference

Source Location

Signature

class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
    def __init__(
        self, config: Qwen3MoeConfig,
        gate: torch.Tensor, gate_up_proj: torch.Tensor,
        down_proj: torch.Tensor,
    ):
        """PyTorch-native Qwen3 MoE block."""

    @classmethod
    def from_hf(cls, moe_block: Qwen3MoeSparseMoeBlock) -> "Qwen3MoeGroupedGEMMBlock":
        """Create from HuggingFace MoE block."""

    def forward(self, hidden_states: torch.Tensor) -> Tuple[GroupedGEMMResult, torch.Tensor]:
        """Forward pass with full intermediate result tracking."""

class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
    def __init__(
        self, config: Qwen3MoeConfig,
        gate: torch.Tensor, gate_up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        permute_x: bool = True, permute_y: bool = True,
        autotune: bool = True,
        kernel_config_fwd: KernelConfigForward = None,
        kernel_config_bwd_dW: KernelConfigBackward_dW = None,
        kernel_config_bwd_dX: KernelConfigBackward_dX = None,
        dW_only: bool = False, dX_only: bool = False,
    ):
        """Triton-optimized Qwen3 MoE block."""

    @classmethod
    def from_hf(cls, moe_block, **kwargs) -> "Qwen3MoeFusedGroupedGEMMBlock":
        """Create fused block from HuggingFace MoE block."""

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass using fused Triton grouped GEMM."""

Import

from unsloth.kernels.moe.grouped_gemm.reference.layers.qwen3_moe import (
    Qwen3MoeGroupedGEMMBlock,
    Qwen3MoeFusedGroupedGEMMBlock,
)

I/O Contract

Inputs

Name | Type | Required | Description
hidden_states | torch.Tensor | Yes | forward input: token representations [batch, seq_len, hidden_size]
config | Qwen3MoeConfig | Yes | Constructor argument: Qwen3 MoE model configuration
gate | torch.Tensor | Yes | Constructor argument: router gate weight matrix [E, hidden_size]
gate_up_proj | torch.Tensor | Yes | Constructor argument: stacked expert gate/up projections [E, 2*intermediate, hidden_size]
down_proj | torch.Tensor | Yes | Constructor argument: stacked expert down projections [E, hidden_size, intermediate]

Outputs

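Name | Type | Description
hidden_states | torch.Tensor | Processed hidden states [batch, seq_len, hidden_size]
router_logits | torch.Tensor | Raw (pre-softmax) router logits, used for the auxiliary load-balancing loss

These are the outputs of Qwen3MoeFusedGroupedGEMMBlock.forward. Qwen3MoeGroupedGEMMBlock.forward instead returns a (GroupedGEMMResult, torch.Tensor) tuple that also tracks intermediate results.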
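The shape contract can be checked up front before a block is constructed. The following is a minimal sketch under stated assumptions: MoeDims is a hypothetical stand-in holding the three Qwen3MoeConfig fields that determine the shapes (num_experts, hidden_size, moe_intermediate_size), check_inputs and expected_shapes are hypothetical helpers, and the flattened [batch*seq_len, num_experts] layout for router_logits is assumed from the HuggingFace Qwen3MoeSparseMoeBlock convention rather than confirmed for these classes.

```python
from dataclasses import dataclass

import torch


@dataclass
class MoeDims:
    """Hypothetical stand-in for the Qwen3MoeConfig fields that fix the tensor shapes."""
    num_experts: int
    hidden_size: int
    moe_intermediate_size: int


def expected_shapes(cfg: MoeDims, batch: int, seq_len: int) -> dict:
    """Shapes implied by the I/O contract for a given batch and sequence length."""
    E, H, I = cfg.num_experts, cfg.hidden_size, cfg.moe_intermediate_size
    return {
        "hidden_states": (batch, seq_len, H),
        "gate": (E, H),
        "gate_up_proj": (E, 2 * I, H),
        "down_proj": (E, H, I),
        "output": (batch, seq_len, H),
        "router_logits": (batch * seq_len, E),  # assumption: flattened over tokens (HF layout)
    }


def check_inputs(cfg: MoeDims, hidden_states, gate, gate_up_proj, down_proj) -> None:
    """Raise ValueError if any input tensor disagrees with the contract."""
    batch, seq_len, _ = hidden_states.shape
    want = expected_shapes(cfg, batch, seq_len)
    given = {
        "hidden_states": hidden_states,
        "gate": gate,
        "gate_up_proj": gate_up_proj,
        "down_proj": down_proj,
    }
    for name, tensor in given.items():
        if tuple(tensor.shape) != want[name]:
            raise ValueError(f"{name}: expected {want[name]}, got {tuple(tensor.shape)}")


# Usage: 4 experts, hidden_size=16, moe_intermediate_size=8 (passes silently)
cfg = MoeDims(num_experts=4, hidden_size=16, moe_intermediate_size=8)
check_inputs(cfg, torch.randn(2, 5, 16), torch.randn(4, 16),
             torch.randn(4, 16, 16), torch.randn(4, 16, 8))
```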

Usage Examples

Convert HuggingFace Block to Fused

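import torch
from transformers import AutoModel

from unsloth.kernels.moe.grouped_gemm.reference.layers.qwen3_moe import (
    Qwen3MoeFusedGroupedGEMMBlock,
)

# Load the base model in bf16 so the block's weights match the input dtype;
# layers[0].mlp is a Qwen3MoeSparseMoeBlock (hidden_size = 2048 for this checkpoint)
model = AutoModel.from_pretrained("Qwen/Qwen3-30B-A3B", torch_dtype=torch.bfloat16)
hf_moe_block = model.layers[0].mlp.to("cuda")

# Convert to Triton-fused block: fused permutation in both GEMMs, autotuned kernels
fused_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
    hf_moe_block,
    permute_x=True,
    permute_y=True,
    autotune=True,
)

hidden = torch.randn(1, 512, 2048, dtype=torch.bfloat16, device="cuda")
output, router_logits = fused_block(hidden)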
