Implementation:Unslothai Unsloth Qwen3 MoE Block
| Knowledge Sources | |
|---|---|
| Domains | MoE, Model_Architecture |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
Reference implementations of Qwen3 MoE blocks using both PyTorch-native and Triton-optimized fused grouped GEMM operations.
Description
This module provides two implementations compatible with HuggingFace's Qwen3MoeSparseMoeBlock: Qwen3MoeGroupedGEMMBlock (PyTorch-native) and Qwen3MoeFusedGroupedGEMMBlock (Triton-fused). The PyTorch version extracts the per-expert weights and stacks them into [E, N, K] tensors (experts, output features, input features). It routes tokens with a softmax over the router logits, followed by top-k selection and normalization of the selected weights. The fused version replaces the two GEMM calls with the fused grouped_gemm interface, which supports fusing permute_x into the first GEMM and permute_y into the second. Its gradient computation is configurable through the dW_only and dX_only flags.
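The routing and grouped GEMM flow described above can be sketched in plain PyTorch. This is a minimal illustration on tiny CPU tensors, not the module's code: the sizes, the gate-then-up split order, and the SiLU activation are assumptions made for the example, and the per-expert loop stands in for what a grouped GEMM kernel would do in one launch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative sizes: tokens, hidden, intermediate, experts, top-k
T, H, I, E, K = 6, 8, 4, 4, 2

x = torch.randn(T, H)
gate = torch.randn(E, H)                 # router weight [E, H]
gate_up_proj = torch.randn(E, 2 * I, H)  # stacked gate/up weights [E, 2*I, H]
down_proj = torch.randn(E, H, I)         # stacked down weights [E, H, I]

# 1. Routing: softmax over experts, keep the top-k, renormalize the kept weights.
router_logits = x @ gate.T                           # [T, E]
probs = F.softmax(router_logits, dim=-1)
topk_w, topk_idx = torch.topk(probs, K, dim=-1)      # [T, K] each
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)

# 2. Permute: sort the T*K (token, slot) pairs so each expert's rows are contiguous.
flat_expert = topk_idx.reshape(-1)                   # [T*K]
order = torch.argsort(flat_expert)
counts = torch.bincount(flat_expert, minlength=E)    # rows handled by each expert
x_perm = x[order // K]                               # [T*K, H]

# 3. Two GEMMs per expert group (assumed: gate half first, SiLU activation).
y_perm = torch.empty(T * K, H)
start = 0
for e in range(E):
    end = start + int(counts[e])
    if end > start:
        g, u = (x_perm[start:end] @ gate_up_proj[e].T).chunk(2, dim=-1)
        y_perm[start:end] = (F.silu(g) * u) @ down_proj[e].T
    start = end

# 4. Unpermute back to (token, slot) order, then take the weighted sum over slots.
y = torch.empty_like(y_perm)
y[order] = y_perm
out = (y.view(T, K, H) * topk_w.unsqueeze(-1)).sum(dim=1)   # [T, H]

# Sanity check against a naive per-token, per-expert evaluation.
ref = torch.zeros(T, H)
for t in range(T):
    for k in range(K):
        e = int(topk_idx[t, k])
        g, u = (gate_up_proj[e] @ x[t]).chunk(2)
        ref[t] += topk_w[t, k] * (down_proj[e] @ (F.silu(g) * u))
max_err = (out - ref).abs().max().item()
```

In this sketch, steps 2 and 4 are the explicit permutations that permute_x and permute_y are described as fusing into the first and second GEMM.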
Usage
Import these classes when benchmarking or deploying Qwen3 MoE models with Triton-optimized grouped GEMM kernels, or use from_hf to convert existing HuggingFace MoE blocks.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py
- Lines: 1-348
Signature
class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
    def __init__(
        self, config: Qwen3MoeConfig,
        gate: torch.Tensor, gate_up_proj: torch.Tensor,
        down_proj: torch.Tensor,
    ):
        """PyTorch-native Qwen3 MoE block."""

    @classmethod
    def from_hf(cls, moe_block: Qwen3MoeSparseMoeBlock) -> "Qwen3MoeGroupedGEMMBlock":
        """Create from HuggingFace MoE block."""

    def forward(self, hidden_states: torch.Tensor) -> Tuple[GroupedGEMMResult, torch.Tensor]:
        """Forward pass with full intermediate result tracking."""

class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
    def __init__(
        self, config: Qwen3MoeConfig,
        gate: torch.Tensor, gate_up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        permute_x: bool = True, permute_y: bool = True,
        autotune: bool = True,
        kernel_config_fwd: KernelConfigForward = None,
        kernel_config_bwd_dW: KernelConfigBackward_dW = None,
        kernel_config_bwd_dX: KernelConfigBackward_dX = None,
        dW_only: bool = False, dX_only: bool = False,
    ):
        """Triton-optimized Qwen3 MoE block."""

    @classmethod
    def from_hf(cls, moe_block, **kwargs) -> "Qwen3MoeFusedGroupedGEMMBlock":
        """Create fused block from HuggingFace MoE block."""

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass using fused Triton grouped GEMM."""
Import
from unsloth.kernels.moe.grouped_gemm.reference.layers.qwen3_moe import (
    Qwen3MoeGroupedGEMMBlock,
    Qwen3MoeFusedGroupedGEMMBlock,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input token representations [batch, seq_len, hidden_size] |
| config | Qwen3MoeConfig | Yes | Qwen3 MoE model configuration |
| gate | torch.Tensor | Yes | Router gate weight matrix |
| gate_up_proj | torch.Tensor | Yes | Stacked expert gate/up projections [E, 2*intermediate, hidden] |
| down_proj | torch.Tensor | Yes | Expert down projections [E, hidden, intermediate] |
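The stacked shapes in the table can be produced from per-expert linear layers as sketched below. This is an assumption-laden illustration, not the body of from_hf: TinyExpert and stack_expert_weights are hypothetical names, the expert layout (gate_proj, up_proj, down_proj without bias) mirrors the usual HuggingFace expert MLP, and placing the gate rows before the up rows is an assumed ordering.

```python
import torch
from torch import nn


class TinyExpert(nn.Module):
    """Hypothetical stand-in for a single expert MLP."""

    def __init__(self, hidden: int, intermediate: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)


def stack_expert_weights(experts):
    """Return (gate_up_proj [E, 2*I, H], down_proj [E, H, I])."""
    # nn.Linear stores weights as [out_features, in_features], so concatenating
    # gate and up along dim 0 gives [2*I, H] per expert.
    gate_up = torch.stack(
        [torch.cat([e.gate_proj.weight, e.up_proj.weight], dim=0) for e in experts]
    )
    down = torch.stack([e.down_proj.weight for e in experts])
    return gate_up.detach().contiguous(), down.detach().contiguous()


hidden, intermediate, num_experts = 8, 4, 3
experts = nn.ModuleList([TinyExpert(hidden, intermediate) for _ in range(num_experts)])
gate_up_proj, down_proj = stack_expert_weights(experts)
```

Stacking along a leading expert dimension is what lets a single grouped GEMM index every expert's weights from one contiguous tensor.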
Outputs
| Name | Type | Description |
|---|---|---|
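| hidden_states | torch.Tensor (fused block) or GroupedGEMMResult (PyTorch-native block) | Processed hidden states [batch, seq_len, hidden_size]; the PyTorch-native block returns them together with its tracked intermediate results |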
| router_logits | torch.Tensor | Raw router logits for auxiliary loss |
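One common use of the returned router logits is a load-balancing auxiliary loss. The sketch below uses a Switch-Transformer-style formulation as an assumption; load_balancing_loss is a hypothetical helper, it expects logits flattened to [num_tokens, num_experts], and it is not claimed to match the scaling used by any particular training stack.

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Auxiliary loss from router logits of shape [num_tokens, num_experts]."""
    num_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits.float(), dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)            # [T, k]
    one_hot = F.one_hot(selected, num_experts).float()        # [T, k, E]
    load = one_hot.mean(dim=(0, 1))        # share of assignments per expert, sums to 1
    importance = probs.mean(dim=0)         # mean router probability per expert, sums to 1
    return num_experts * torch.sum(load * importance)


# Uniform logits give a perfectly balanced router probability.
loss_uniform = load_balancing_loss(torch.zeros(16, 4), top_k=2)

# Every token strongly prefers experts 0 and 1, so the load is concentrated.
skewed = torch.tensor([[10.0, 5.0, 0.0, 0.0]]).repeat(16, 1)
loss_skewed = load_balancing_loss(skewed, top_k=2)
```

With this normalization the loss equals 1 when the router probabilities are uniform and grows as routing concentrates on a few experts.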
Usage Examples
Convert HuggingFace Block to Fused
import torch
from transformers import AutoModel

from unsloth.kernels.moe.grouped_gemm.reference.layers.qwen3_moe import (
    Qwen3MoeFusedGroupedGEMMBlock,
)

# Load in bfloat16 on the GPU so the weights match the input created below
model = AutoModel.from_pretrained(
    "Qwen/Qwen3-30B-A3B", torch_dtype=torch.bfloat16
).to("cuda")
hf_moe_block = model.layers[0].mlp

# Convert to Triton-fused block
fused_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
    hf_moe_block,
    permute_x=True,
    permute_y=True,
    autotune=True,
)

# [batch, seq_len, hidden_size]
hidden = torch.randn(1, 512, 2048, dtype=torch.bfloat16, device="cuda")
output, router_logits = fused_block(hidden)