

Implementation:Unslothai Unsloth Llama4 MoE Block

From Leeroopedia


Knowledge Sources
Domains: MoE, Model_Architecture
Last Updated: 2026-02-07 08:40 GMT

Overview

Reference implementations of Llama4 MoE blocks using both PyTorch-native and Triton-optimized grouped GEMM operations.
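A grouped GEMM runs one matrix multiply per expert, where each expert receives a different number of token rows. The sketch below spells out those semantics in plain PyTorch, assuming tokens are already sorted by expert, expert weights are stored as [E, N, K], and each expert's output is x_e @ W_e^T. The function name `grouped_gemm_reference` and the `m_sizes` argument are hypothetical illustrations, not the unsloth API; a fused kernel computes the same result without the Python loop.

```python
import torch


def grouped_gemm_reference(
    x: torch.Tensor,        # [total_rows, K], rows already grouped by expert
    w: torch.Tensor,        # [E, N, K] expert weights (hypothetical layout for this sketch)
    m_sizes: torch.Tensor,  # [E] number of rows routed to each expert
) -> torch.Tensor:
    """Compute y[rows of expert e] = x[rows of expert e] @ w[e].T for every expert e."""
    E, N, K = w.shape
    assert x.shape[1] == K and m_sizes.numel() == E
    assert int(m_sizes.sum()) == x.shape[0]
    y = x.new_empty(x.shape[0], N)
    start = 0
    for e, m in enumerate(m_sizes.tolist()):
        end = start + m
        if m > 0:
            # One ordinary GEMM per expert; a grouped kernel launches all of them together.
            y[start:end] = x[start:end] @ w[e].T
        start = end
    return y


torch.manual_seed(0)
E, N, K = 4, 16, 8
w = torch.randn(E, N, K)
m_sizes = torch.tensor([3, 0, 5, 2])  # expert 1 receives no tokens
x = torch.randn(int(m_sizes.sum()), K)
y = grouped_gemm_reference(x, w, m_sizes)
print(y.shape)  # torch.Size([10, 16])
```

Storing weights as [E, N, K] keeps each expert's matrix contiguous, so every group reads a single [N, K] slab.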

Description

This module provides Llama4GroupedGemmTextMoe (PyTorch-native) and Llama4TritonTextMoe (Triton-fused), both extending HuggingFace's Llama4TextMoe. The PyTorch version rearranges the expert weights into an [E, N, K] layout (experts, output features, input features), implements Llama4's sigmoid-based routing, and can optionally overlap the router and shared-expert computation using CUDA streams. The Triton version replaces the PyTorch GEMM calls with the fused grouped_gemm interface and accepts either explicit kernel configurations or autotuning.
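Sigmoid-based routing differs from the usual softmax gating: each token keeps its top-k experts by raw logit, and each selected expert is gated by an independent sigmoid of its logit, with no renormalization across experts. The sketch below assumes this top-k-then-sigmoid order (which is how HuggingFace's Llama4TextMoe routes, as far as we know); the helper name `sigmoid_topk_route` is hypothetical.

```python
import torch


def sigmoid_topk_route(router_logits: torch.Tensor, top_k: int):
    """Sketch of Llama4-style routing (hypothetical helper).

    router_logits: [T, E] raw router outputs.
    Returns a dense [T, E] score matrix (sigmoid(logit) on each token's top-k
    experts, 0 elsewhere) and the [T, k] selected expert ids.
    """
    top_vals, top_idx = torch.topk(router_logits, top_k, dim=-1)
    # Fill unselected experts with -inf so that sigmoid maps them to exactly 0.
    scores = torch.full_like(router_logits, float("-inf")).scatter(-1, top_idx, top_vals)
    # Unlike softmax gating, each score is independent and the k scores need not sum to 1.
    scores = torch.sigmoid(scores.float()).to(router_logits.dtype)
    return scores, top_idx


logits = torch.tensor([[2.0, -1.0, 0.5, 3.0],
                       [0.2, 1.0, -2.0, 0.0]])
scores, top_idx = sigmoid_topk_route(logits, top_k=2)
```

Because the gates are not normalized, a token whose selected logits are all large passes more signal through the experts than it would under softmax gating.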

Usage

Import these classes when benchmarking or deploying Llama4 MoE models with Triton-optimized grouped GEMM kernels.

Code Reference

Source Location

Signature

class Llama4GroupedGemmTextMoe(Llama4TextMoe):
    def __init__(
        self, config: Llama4TextConfig,
        overlap_router_shared=False, verbose=False, debug=False,
    ):
        """PyTorch-native Llama4 MoE block."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass through expert routing and GEMM."""

class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
    def __init__(
        self, config: Llama4TextConfig,
        overlap_router_shared=False,
        permute_x: bool = False, permute_y: bool = True,
        autotune: bool = True,
        kernel_config_fwd: KernelConfigForward = None,
        kernel_config_bwd_dW: KernelConfigBackward_dW = None,
        kernel_config_bwd_dX: KernelConfigBackward_dX = None,
        dW_only: bool = False, dX_only: bool = False,
        verbose=False,
    ):
        """Triton-optimized Llama4 MoE block."""

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass using fused Triton grouped GEMM."""

Import

from unsloth.kernels.moe.grouped_gemm.reference.layers.llama4_moe import (
    Llama4GroupedGemmTextMoe,
    Llama4TritonTextMoe,
)

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input token representations, shape [batch, seq_len, hidden_dim] |
| config | Llama4TextConfig | Yes | Llama4 model configuration |
| permute_x | bool | No | Fuse the input permutation into the GEMM (default: False) |
| permute_y | bool | No | Fuse the output permutation into the GEMM (default: True) |
| autotune | bool | No | Enable kernel autotuning (default: True) |
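The permute_x and permute_y options refer to the two data movements around the expert GEMMs: gathering token rows into expert-sorted order before the first GEMM, and scattering results back to token order afterwards. The sketch below performs both steps explicitly, assuming a stable sort by expert id; the helper names are hypothetical, routing weights are omitted, and the comments only indicate which step each flag presumably folds into the kernel.

```python
import torch


def permute_by_expert(top_idx: torch.Tensor, num_experts: int):
    """Group (token, expert) assignments by expert id (hypothetical helper).

    top_idx: [T, k] expert ids chosen for each token.
    Returns:
      order     [T*k] positions in the flattened assignment list, sorted by expert
      token_idx [T*k] source token row for each permuted slot (gather index for X)
      m_sizes   [E]   number of rows routed to each expert
    """
    k = top_idx.shape[1]
    flat = top_idx.reshape(-1)                # assignment i belongs to token i // k
    order = torch.argsort(flat, stable=True)  # one contiguous block per expert
    token_idx = order // k
    m_sizes = torch.bincount(flat, minlength=num_experts)
    return order, token_idx, m_sizes


def unpermute(y_perm: torch.Tensor, order: torch.Tensor, num_tokens: int, k: int):
    """Scatter expert-ordered rows back to token order and sum each token's k outputs."""
    y = torch.empty_like(y_perm)
    y[order] = y_perm                            # inverse of the expert-sort permutation
    return y.view(num_tokens, k, -1).sum(dim=1)  # routing weights omitted in this sketch


top_idx = torch.tensor([[1, 0], [2, 1], [0, 2]])  # 3 tokens, top-2 of 3 experts
x = torch.randn(3, 4)
order, token_idx, m_sizes = permute_by_expert(top_idx, num_experts=3)
x_perm = x[token_idx]  # the gather that permute_x would fuse into the GEMM's input loads
y_perm = x_perm        # stand-in for the per-expert GEMMs
out = unpermute(y_perm, order, num_tokens=3, k=2)  # the scatter that permute_y would fuse
```

Done separately, each step costs an extra read and write of a [T*k, hidden] tensor, which is the memory traffic that fusing the permutation into the GEMM avoids.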

Outputs

| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Processed hidden states, shape [batch, seq_len, hidden_dim] |
| routing_weights | torch.Tensor | Expert routing weights (Triton variant) |

Usage Examples

Triton-Optimized Llama4 MoE

import torch
from transformers.models.llama4 import Llama4TextConfig

from unsloth.kernels.moe.grouped_gemm.reference.layers.llama4_moe import (
    Llama4TritonTextMoe,
)

config = Llama4TextConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_local_experts=8,
    num_experts_per_tok=2,
)

# Move the block to the GPU in bfloat16 so its weights match the input dtype
moe = Llama4TritonTextMoe(config, autotune=True, permute_y=True).to(
    device="cuda", dtype=torch.bfloat16
)
hidden_states = torch.randn(1, 512, 4096, dtype=torch.bfloat16, device="cuda")

# The Triton variant returns the processed hidden states and the routing weights
output, routing = moe(hidden_states)
