Implementation:Unslothai Unsloth Llama4 MoE Block
| Knowledge Sources | |
|---|---|
| Domains | MoE, Model_Architecture |
| Last Updated | 2026-02-07 08:40 GMT |
Overview
Reference implementations of Llama4 MoE blocks using both PyTorch-native and Triton-optimized grouped GEMM operations.
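For readers unfamiliar with the term, a grouped GEMM multiplies several row groups of one activation matrix by different weight matrices in a single logical operation. The sketch below is a plain PyTorch loop that shows the semantics only; it is not taken from the module's source. It assumes the rows of x are already sorted by expert and that weights are stored as [E, N, K]. The name grouped_gemm_reference is hypothetical.

```python
import torch

def grouped_gemm_reference(x, w, group_sizes):
    """Hypothetical reference: x is [M, K] with rows sorted by expert,
    w is [E, N, K], group_sizes is [E] and sums to M.
    Returns y of shape [M, N] where the rows of group e equal x_e @ w[e].T."""
    M, K = x.shape
    E, N, K_w = w.shape
    assert K == K_w and int(group_sizes.sum()) == M
    y = x.new_empty(M, N)
    start = 0
    for e in range(E):
        end = start + int(group_sizes[e])
        if end > start:  # an expert may receive no tokens
            y[start:end] = x[start:end] @ w[e].T
        start = end
    return y

torch.manual_seed(0)
x = torch.randn(6, 4)           # 6 token slots, K = 4
w = torch.randn(3, 5, 4)        # E = 3 experts, N = 5, K = 4
sizes = torch.tensor([2, 0, 4]) # expert 1 receives no tokens
y = grouped_gemm_reference(x, w, sizes)
```

A fused kernel produces the same result without the Python-level loop over experts, which is where the performance benefit comes from.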
Description
This module provides Llama4GroupedGemmTextMoe (torch-native) and Llama4TritonTextMoe (Triton-fused), both extending HuggingFace's Llama4TextMoe. The torch version restructures the expert weights into an [E, N, K] layout and implements sigmoid-based routing, with an option to overlap the shared-expert computation via CUDA streams. The Triton version replaces the torch GEMM calls with the fused grouped_gemm interface and supports explicit kernel configuration and autotuning.
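As an illustration of sigmoid-based routing, the sketch below selects the top-k experts per token and applies a sigmoid to the selected logits. This ordering (top-k first, then sigmoid) is an assumption about how the routing works, and sigmoid_topk_route is a hypothetical helper, not a function from the module.

```python
import torch

def sigmoid_topk_route(router_logits, top_k):
    """Hypothetical helper. router_logits: [T, E].
    Returns (weights [T, top_k], expert ids [T, top_k])."""
    top_vals, top_idx = torch.topk(router_logits, top_k, dim=-1)
    # Sigmoid is applied per selected logit, in float32 for stability.
    weights = torch.sigmoid(top_vals.float()).to(router_logits.dtype)
    return weights, top_idx

logits = torch.tensor([[2.0, -1.0, 0.0],
                       [0.5, 3.0, -2.0]])  # 2 tokens, 3 experts
weights, experts = sigmoid_topk_route(logits, top_k=1)

# Tokens per expert, which a grouped GEMM needs as its group sizes.
counts = torch.bincount(experts.reshape(-1), minlength=logits.shape[1])
```

Unlike softmax routing, sigmoid weights are computed independently per expert and need not sum to 1 across the selected experts.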
Usage
Import these classes when benchmarking or deploying Llama4 MoE models with Triton-optimized grouped GEMM kernels.
Code Reference
Source Location
- Repository: Unslothai_Unsloth
- File: unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
- Lines: 1-437
Signature
class Llama4GroupedGemmTextMoe(Llama4TextMoe):
    def __init__(
        self, config: Llama4TextConfig,
        overlap_router_shared=False, verbose=False, debug=False,
    ):
        """PyTorch-native Llama4 MoE block."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass through expert routing and GEMM."""


class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
    def __init__(
        self, config: Llama4TextConfig,
        overlap_router_shared=False,
        permute_x: bool = False, permute_y: bool = True,
        autotune: bool = True,
        kernel_config_fwd: KernelConfigForward = None,
        kernel_config_bwd_dW: KernelConfigBackward_dW = None,
        kernel_config_bwd_dX: KernelConfigBackward_dX = None,
        dW_only: bool = False, dX_only: bool = False,
        verbose=False,
    ):
        """Triton-optimized Llama4 MoE block."""

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass using fused Triton grouped GEMM."""
Import
from unsloth.kernels.moe.grouped_gemm.reference.layers.llama4_moe import (
    Llama4GroupedGemmTextMoe,
    Llama4TritonTextMoe,
)
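Both constructors accept overlap_router_shared, which this page does not describe beyond its name. One common way to overlap two independent GPU computations is to issue one of them on a side CUDA stream. The sketch below illustrates that general pattern under that assumption; it is not the module's implementation, and run_overlapped, shared_fn, and routed_fn are hypothetical names. On a machine without CUDA it falls back to sequential execution.

```python
import torch

def run_overlapped(x, shared_fn, routed_fn, overlap=True):
    """Hypothetical helper: compute shared_fn(x) + routed_fn(x),
    issuing shared_fn on a side CUDA stream when possible."""
    if overlap and x.is_cuda:
        main = torch.cuda.current_stream()
        side = torch.cuda.Stream()
        side.wait_stream(main)            # x must be ready before the side stream reads it
        with torch.cuda.stream(side):
            shared_out = shared_fn(x)     # e.g. the shared expert
        routed_out = routed_fn(x)         # e.g. routing plus routed experts
        main.wait_stream(side)            # join before combining results
        shared_out.record_stream(main)    # tell the allocator about cross-stream use
    else:
        shared_out = shared_fn(x)
        routed_out = routed_fn(x)
    return shared_out + routed_out

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 8, device=device)
out = run_overlapped(x, lambda t: t * 2.0, lambda t: t + 1.0)
```

Whether overlap helps in practice depends on whether either computation already saturates the GPU, so it is worth measuring with the flag on and off.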
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input token representations of shape [batch, seq_len, hidden_dim]; passed to forward |
| config | Llama4TextConfig | Yes | Llama4 model configuration; passed to the constructor |
| permute_x | bool | No | Triton variant only: fuse the input permutation into the GEMM (default: False) |
| permute_y | bool | No | Triton variant only: fuse the output permutation into the GEMM (default: True) |
| autotune | bool | No | Triton variant only: enable kernel autotuning (default: True) |
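To make the permute_x and permute_y flags concrete, the sketch below shows unfused equivalents of the two permutations. It assumes that "permutation" means reordering token slots from token order into expert-grouped order before the GEMM and restoring token order afterwards. The helpers permute_tokens and unpermute_tokens are hypothetical and do not come from the module.

```python
import torch

def permute_tokens(x, expert_ids):
    """Hypothetical helper. x: [T, K]; expert_ids: [T, top_k].
    Returns the rows of x gathered into expert-sorted slot order, plus the index used."""
    top_k = expert_ids.shape[1]
    flat = expert_ids.reshape(-1)                   # one entry per (token, slot)
    _, gather_idx = torch.sort(flat, stable=True)   # slots grouped by expert id
    token_idx = gather_idx // top_k                 # source token of each slot
    return x[token_idx], gather_idx

def unpermute_tokens(y_sorted, gather_idx):
    """Hypothetical helper: scatter expert-sorted rows back to token-major slot order."""
    y = torch.empty_like(y_sorted)
    y[gather_idx] = y_sorted
    return y

x = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # 4 tokens, K = 2
expert_ids = torch.tensor([[1], [0], [1], [0]])         # top_k = 1
x_sorted, gather_idx = permute_tokens(x, expert_ids)
restored = unpermute_tokens(x_sorted, gather_idx)
```

Fusing these steps into the GEMM avoids materializing the intermediate permuted tensors, which is presumably what the two flags control.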
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Processed hidden states [batch, seq_len, hidden_dim] |
| routing_weights | torch.Tensor | Expert routing weights (Triton variant) |
Usage Examples
Triton-Optimized Llama4 MoE
import torch
from transformers.models.llama4 import Llama4TextConfig

from unsloth.kernels.moe.grouped_gemm.reference.layers.llama4_moe import (
    Llama4TritonTextMoe,
)

config = Llama4TextConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_local_experts=8,
    num_experts_per_tok=2,
)
# Cast the module to bfloat16 so its parameters match the input dtype.
moe = Llama4TritonTextMoe(config, autotune=True, permute_y=True).to(
    device="cuda", dtype=torch.bfloat16
)
hidden_states = torch.randn(1, 512, 4096, dtype=torch.bfloat16, device="cuda")
# The Triton variant returns the output together with the routing weights.
output, routing = moe(hidden_states)