Implementation:Unslothai Unsloth MoE Fused Block
| Knowledge Sources | |
|---|---|
| Domains | MoE, Model_Architecture |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
A production-oriented Qwen3 MoE block that replaces the torch-native grouped GEMM with fused Triton kernels.
Description
The moe_block module provides Qwen3MoeFusedGroupedGEMMBlock, a cleaner, production-oriented subclass of Qwen3MoeGroupedGEMMBlock. It overrides the forward method to call the Triton grouped_gemm interface for both the gate_up_proj and down_proj GEMMs. When kernel-fused permutation is disabled (permute_x or permute_y set to False), the block performs the corresponding permutation outside the kernel instead. A from_hf factory method converts a HuggingFace MoE block into the fused block.
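To make the "external permutation" path concrete, the following is a minimal sketch of permute, grouped GEMM, and unpermute for a single projection. It is not the Unsloth implementation: plain Python lists stand in for tensors so it runs without a GPU, and the helper names (`matmul_t`, `permute_by_expert`, `grouped_gemm`, `moe_linear`) are hypothetical. The real block runs two GEMMs with an activation in between and applies routing weights, which this sketch omits.

```python
from typing import List

Matrix = List[List[float]]


def matmul_t(x: Matrix, w: Matrix) -> Matrix:
    """Return x @ w.T for row-major lists (x: [M, K], w: [N, K])."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def permute_by_expert(topk_ids: List[List[int]], num_experts: int):
    """Sort flattened (token, slot) pairs by expert id.

    Returns gather_idx (permuted position -> flat pair index) and
    m_sizes (number of rows assigned to each expert).
    """
    flat = [e for row in topk_ids for e in row]
    gather_idx = sorted(range(len(flat)), key=lambda i: flat[i])  # stable sort
    m_sizes = [flat.count(e) for e in range(num_experts)]
    return gather_idx, m_sizes


def grouped_gemm(x_perm: Matrix, weights: List[Matrix], m_sizes: List[int]) -> Matrix:
    """Apply each expert's weight to its contiguous slice of x_perm."""
    out, start = [], 0
    for w, m in zip(weights, m_sizes):
        out.extend(matmul_t(x_perm[start:start + m], w))
        start += m
    return out


def moe_linear(x: Matrix, topk_ids: List[List[int]], weights: List[Matrix]) -> Matrix:
    """External permute -> grouped GEMM -> unpermute; one output row per (token, slot)."""
    top_k = len(topk_ids[0])
    gather_idx, m_sizes = permute_by_expert(topk_ids, len(weights))
    # Input permutation done outside the kernel (the permute_x=False case).
    x_perm = [x[i // top_k] for i in gather_idx]
    y_perm = grouped_gemm(x_perm, weights, m_sizes)
    # Output permutation done outside the kernel (the permute_y=False case).
    y = [None] * len(gather_idx)
    for pos, i in enumerate(gather_idx):
        y[i] = y_perm[pos]
    return y


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # 3 tokens, hidden size 2
topk_ids = [[1, 0], [0, 1], [1, 1]]        # top-2 expert ids per token
weights = [[[1.0, 0.0], [0.0, 1.0]],       # expert 0: identity
           [[2.0, 0.0], [0.0, 2.0]]]       # expert 1: scale by 2
y = moe_linear(x, topk_ids, weights)
```

Sorting by expert makes each expert's rows contiguous, so a single grouped call can process all experts. A kernel that fuses the permutation performs the same gather and scatter inside the GEMM rather than materializing `x_perm` and `y_perm`.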
Usage
Import this class as a production-ready integration reference for deploying Qwen3 MoE models with Triton-optimized kernels. Prefer this over the debug-heavy version in reference/layers/qwen3_moe.py.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
- Lines: 1-161
Signature
class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
    def __init__(
        self,
        config: Qwen3MoeConfig,
        gate: torch.Tensor,
        gate_up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        permute_x: bool = True,
        permute_y: bool = True,
        autotune: bool = True,
        kernel_config_fwd: KernelConfigForward = None,
        kernel_config_bwd_dW: KernelConfigBackward_dW = None,
        kernel_config_bwd_dX: KernelConfigBackward_dX = None,
        dW_only: bool = False,
        dX_only: bool = False,
    ):
        """Production Qwen3 MoE block with fused Triton kernels."""

    @classmethod
    def from_hf(cls, moe_block: Qwen3MoeSparseMoeBlock, **kwargs):
        """Factory method from HuggingFace block."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass with fused Triton grouped GEMM."""
Import
from unsloth.kernels.moe.grouped_gemm.reference.moe_block import (
    Qwen3MoeFusedGroupedGEMMBlock,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input token representations [batch, seq_len, hidden_dim] |
| config | Qwen3MoeConfig | Yes | Qwen3 MoE model configuration |
| permute_x | bool | No | Fuse input permutation (default: True) |
| permute_y | bool | No | Fuse output permutation (default: True) |
| autotune | bool | No | Enable kernel autotuning (default: True) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tuple[torch.Tensor, torch.Tensor] | (hidden_states, router_logits) |
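The returned router_logits can be turned into per-token expert assignments. The sketch below assumes the softmax-then-top-k routing used by the HuggingFace Qwen3 MoE block, with optional renormalization controlled by the `norm_topk_prob` config field. The `route` helper is hypothetical and operates on one token's logits as a plain list; it is not part of the Unsloth API.

```python
import math
from typing import List, Tuple


def route(
    logits: List[float], top_k: int, norm_topk_prob: bool = True
) -> Tuple[List[int], List[float]]:
    """Softmax over experts, keep the top_k, optionally renormalize the kept weights."""
    m = max(logits)                                  # subtract max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ids = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    weights = [probs[i] for i in ids]
    if norm_topk_prob:
        s = sum(weights)
        weights = [w / s for w in weights]
    return ids, weights


# One token, four experts, top-2 routing.
ids, weights = route([0.0, 2.0, 1.0, -1.0], top_k=2)
```

Here the two experts with the largest logits are selected, and with renormalization enabled their weights sum to 1.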
Usage Examples
Production Deployment
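import torch
from transformers import AutoModel

from unsloth.kernels.moe.grouped_gemm.reference.moe_block import (
    Qwen3MoeFusedGroupedGEMMBlock,
)

# Checkpoint ID shown for illustration; use the Qwen3 MoE checkpoint you deploy.
model = AutoModel.from_pretrained("Qwen/Qwen3-30B-A3B")
# Assumes a CUDA device is available, since the fused path runs Triton kernels.
hf_block = model.layers[0].mlp.to("cuda")

# Convert the HuggingFace block to the fused Triton block
fused = Qwen3MoeFusedGroupedGEMMBlock.from_hf(hf_block, autotune=True)

# Placeholder input of shape [batch, seq_len, hidden_dim]
hidden_states = torch.randn(
    1, 16, model.config.hidden_size, device="cuda", dtype=model.dtype
)
output, logits = fused(hidden_states)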