Implementation:Turboderp org Exllamav2 ExLlamaV2MoEMLP
| Knowledge Sources | |
|---|---|
| Domains | Mixture of Experts, Feed-Forward Network |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
ExLlamaV2MoEMLP implements a Mixture-of-Experts feed-forward layer with dynamic expert routing, supporting both quantized CUDA kernels and a pure PyTorch fallback path.
Description
ExLlamaV2MoEMLP extends ExLlamaV2Module and represents a single MoE MLP block within a transformer layer. It contains:
- post_attention_layernorm: An RMSNorm or LayerNorm applied to the residual stream before routing.
- gate: A linear projection from hidden_size to num_experts that produces router logits.
- w1 (gate projection), w2 (down projection), w3 (up projection): Lists of ExLlamaV2Linear modules, one per expert. Each expert follows the standard gated-MLP pattern:
output = w2(act(w1(x)) * w3(x)).
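The gated-MLP pattern above can be illustrated with plain tensors. This is a minimal sketch, not the library's code: `gated_mlp` is a hypothetical helper, and ordinary weight matrices stand in for the ExLlamaV2Linear modules.

```python
import torch
import torch.nn.functional as F


def gated_mlp(x, w1, w2, w3, act=F.silu):
    """Compute w2(act(w1(x)) * w3(x)) for a single expert.

    x:  (tokens, hidden_dim)
    w1: (intermediate_dim, hidden_dim)  gate projection
    w3: (intermediate_dim, hidden_dim)  up projection
    w2: (hidden_dim, intermediate_dim)  down projection
    """
    gate = F.linear(x, w1)               # (tokens, intermediate_dim)
    up = F.linear(x, w3)                 # (tokens, intermediate_dim)
    return F.linear(act(gate) * up, w2)  # (tokens, hidden_dim)


# Illustrative shapes only; real models use much larger dimensions.
torch.manual_seed(0)
hidden_dim, intermediate_dim = 8, 16
x = torch.randn(3, hidden_dim)
w1 = torch.randn(intermediate_dim, hidden_dim) * 0.1
w3 = torch.randn(intermediate_dim, hidden_dim) * 0.1
w2 = torch.randn(hidden_dim, intermediate_dim) * 0.1
y = gated_mlp(x, w1, w2, w3)
```

The output has the same shape as the input, which is what lets each expert's result be added back into the hidden state.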
The forward() method dispatches to the optimized C++/CUDA kernel (ext_c.q_moe_mlp_forward_) only when all of the following conditions are met:
- the weights are quantized (q_handle is not None),
- no intermediates are requested,
- the input is small (batch_size * seq_len <= 4),
- the number of experts is one of {4, 8, 16, 128}, and
- no LoRA adapters are active.
Otherwise, it falls back to forward_torch().
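The dispatch rule described above can be written as a single predicate. The sketch below is illustrative: `uses_cuda_kernel` and its parameter names are hypothetical and simply restate the listed conditions, with `is_quantized` standing in for the `q_handle is not None` check.

```python
def uses_cuda_kernel(is_quantized, intermediates, batch_size, seq_len,
                     num_experts, loras=None):
    """Return True when every condition for the fused kernel path holds."""
    return bool(
        is_quantized                         # q_handle is not None
        and not intermediates                # no intermediate activations requested
        and batch_size * seq_len <= 4        # small input only
        and num_experts in (4, 8, 16, 128)   # supported expert counts
        and not loras                        # None or empty list: no LoRA active
    )
```

Under this rule a single-token decode step on a quantized 8-expert model takes the kernel path, while a long prompt, an intermediates request, or an active LoRA falls back to forward_torch().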
The forward_torch() path performs: (1) layer normalization, (2) router logit computation via the gate linear, (3) softmax + top-K selection to choose num_experts_per_token experts per token, (4) normalization of routing weights, (5) per-expert gated MLP computation with the configured activation function (SiLU or GELU), and (6) weighted aggregation of expert outputs back into the hidden state with a residual connection.
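Steps (2) to (4) can be sketched as follows. This is a minimal illustration under stated assumptions: `route` is a hypothetical helper that takes router logits as a plain tensor, and the float32 softmax is a common choice for numerical stability rather than something the text above specifies.

```python
import torch
import torch.nn.functional as F


def route(router_logits, num_experts_per_token):
    """router_logits: (tokens, num_experts).

    Returns (weights, selected), each of shape (tokens, num_experts_per_token).
    """
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
    weights, selected = torch.topk(probs, num_experts_per_token, dim=-1)
    # Renormalize so the kept weights sum to 1.0 for every token.
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, selected


# Two tokens, four experts, top-2 routing.
logits = torch.tensor([[2.0, 0.0, 1.0, -1.0],
                       [0.0, 3.0, 0.0, 0.5]])
weights, selected = route(logits, num_experts_per_token=2)
# selected holds the expert indices chosen for each token, highest weight first.
```

Renormalizing after top-K means the discarded experts' probability mass is redistributed over the selected ones instead of shrinking the layer's output.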
The load() method either initializes the quantized CUDA handle (allocating scratch buffers for intermediate computations) or loads plain torch weights. unload() frees the CUDA handle and releases all submodule weights.
Usage
ExLlamaV2MoEMLP is instantiated automatically by the model loader when the architecture specifies a Mixture-of-Experts MLP (e.g. Mixtral, DBRX, Qwen-MoE). It is not typically constructed directly by users. It appears as a layer in the model's module list and is called during the standard forward pass.
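To locate these layers in a loaded model, one option is to scan the module list. The helper below is a hypothetical sketch: it matches on the class name so that it runs without the library installed, and it assumes the module list is available as an ordinary iterable (for example `model.modules`, which is an assumption about the model object).

```python
def find_moe_layers(modules):
    """Return (index, module) pairs whose class is named ExLlamaV2MoEMLP."""
    return [
        (i, m) for i, m in enumerate(modules)
        if type(m).__name__ == "ExLlamaV2MoEMLP"
    ]


# Hypothetical usage with a loaded model:
# moe_layers = find_moe_layers(model.modules)
```

With the library imported, an isinstance check against ExLlamaV2MoEMLP would be the more conventional test.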
Code Reference
Source Location
- Repository: Turboderp_org_Exllamav2
- File: exllamav2/moe_mlp.py
- Lines: 15-365
Signature
class ExLlamaV2MoEMLP(ExLlamaV2Module):
    name: str = "MoE MLP"
    layer_idx: int
    post_attention_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm
    w1: list  # Gate projection per expert
    w2: list  # Down projection per expert
    w3: list  # Up projection per expert
    gate: ExLlamaV2Linear  # Router linear
    num_experts: int
    num_experts_per_token: int
    q_handle: int | None

    def __init__(
        self,
        model: ExLlamaV2,
        key: str,
        layer_idx: int,
        archparams=None
    ): ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache=None,
        attn_params=None,
        past_len=None,
        intermediates: bool = False,
        loras: list[ExLlamaV2Lora] | None = None,
        **kwargs
    ) -> torch.Tensor | dict[str, torch.Tensor]: ...

    def forward_torch(
        self,
        hidden_states: torch.Tensor,
        cache=None,
        attn_params=None,
        past_len=None,
        intermediates=False,
        loras: list[ExLlamaV2Lora] | None = None,
        **kwargs
    ) -> torch.Tensor | dict[str, torch.Tensor]: ...
Import
from exllamav2.moe_mlp import ExLlamaV2MoEMLP
I/O Contract
Inputs (forward)
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input tensor of shape (batch_size, sequence_length, hidden_dim) from the preceding attention layer |
| cache | ExLlamaV2CacheBase or None | No | KV cache object (unused by MLP but accepted for interface consistency) |
| attn_params | Params or None | No | Attention parameters (unused by MLP but accepted for interface consistency) |
| past_len | int or None | No | Past sequence length (unused by MLP) |
| intermediates | bool | No (default False) | If True, return a dict containing intermediate activations (post_norm, pre_down per expert, hidden_states) |
| loras | list[ExLlamaV2Lora] or None | No | LoRA adapters to apply to the gate and expert linear layers (forces PyTorch fallback) |
Outputs
| Name | Type | Description |
|---|---|---|
| hidden_states (when intermediates=False) | torch.Tensor | Output tensor of shape (batch_size, sequence_length, hidden_dim) with the MoE residual added |
| result (when intermediates=True) | dict[str, torch.Tensor] | Dictionary with keys "post_norm", "pre_down.{expert_idx}", and "hidden_states" |
Usage Examples
Basic Usage (within model forward)
# ExLlamaV2MoEMLP is called internally during model.forward().
# Direct usage mirrors any ExLlamaV2Module:
hidden_states = moe_mlp_layer.forward(
    hidden_states,
    cache=cache,
    attn_params=attn_params,
    past_len=past_len
)
Inspecting Intermediate Activations
# Request intermediates to examine routing behaviour
result = moe_mlp_layer.forward(
    hidden_states,
    intermediates=True
)
post_norm = result["post_norm"] # After layernorm, before routing
expert_0_pre_down = result["pre_down.0"] # Expert 0 activations before down-proj
final_output = result["hidden_states"] # Final output with residual
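When working with the returned dictionary, it can be convenient to collect the per-expert entries by index rather than hard-coding keys. The helper below is a hypothetical sketch that relies only on the "pre_down.{expert_idx}" key format; it also tolerates missing experts, on the assumption (not confirmed here) that an expert which receives no tokens may have no entry.

```python
def pre_down_by_expert(result):
    """Map expert index -> value for every "pre_down.{expert_idx}" key present."""
    prefix = "pre_down."
    return {
        int(key[len(prefix):]): value
        for key, value in result.items()
        if key.startswith(prefix)
    }


# Hypothetical usage with the result dict from the example above:
# per_expert = pre_down_by_expert(result)
# for expert_idx in sorted(per_expert):
#     print(expert_idx, per_expert[expert_idx].shape)
```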
Internal Architecture
The forward_torch() execution path follows this sequence:
hidden_states (batch, seq, hidden_dim)
|
v
post_attention_layernorm (RMSNorm or LayerNorm)
|
v
gate linear -> router_logits (batch*seq, num_experts)
|
v
softmax -> routing_weights
|
v
topk(num_experts_per_token) -> selected_experts, routing_weights
|
v
normalize routing_weights (sum to 1.0)
|
v
for each expert:
    gather tokens assigned to this expert
    w1 (gate proj) -> act_fn (SiLU or GELU)
    w3 (up proj)
    element-wise multiply
    w2 (down proj)
    scale by routing_weight
    scatter-add to output
|
v
output + residual -> final_hidden_states
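The sequence above can be sketched end to end with plain tensors. This is a minimal illustration rather than the library's code: `moe_forward` and its arguments are hypothetical names, the normalization step is omitted for brevity, and ordinary weight matrices stand in for the ExLlamaV2Linear modules.

```python
import torch
import torch.nn.functional as F


def moe_forward(x, gate_w, w1, w2, w3, top_k, act=F.silu):
    """x: (batch, seq, hidden); gate_w: (num_experts, hidden).

    w1 / w3: lists of (inter, hidden) matrices; w2: list of (hidden, inter).
    """
    batch, seq, hidden = x.shape
    flat = x.reshape(-1, hidden)  # (tokens, hidden)

    # Router: softmax over experts, keep top_k, renormalize to sum to 1.
    probs = F.softmax(F.linear(flat, gate_w), dim=-1, dtype=torch.float)
    weights, selected = torch.topk(probs, top_k, dim=-1)
    weights = (weights / weights.sum(dim=-1, keepdim=True)).to(x.dtype)

    out = torch.zeros_like(flat)
    for e in range(gate_w.shape[0]):
        # Rows of `selected` that picked expert e, and in which top-k slot.
        token_idx, slot_idx = torch.where(selected == e)
        if token_idx.numel() == 0:
            continue  # no token routed to this expert
        h = flat[token_idx]
        h = F.linear(act(F.linear(h, w1[e])) * F.linear(h, w3[e]), w2[e])
        h = h * weights[token_idx, slot_idx].unsqueeze(-1)
        out.index_add_(0, token_idx, h)  # accumulate weighted expert output

    return out.reshape(batch, seq, hidden) + x  # residual connection


# Illustrative sizes only.
torch.manual_seed(0)
E, H, I = 4, 8, 16
x = torch.randn(2, 3, H)
gate_w = torch.randn(E, H)
w1 = [torch.randn(I, H) * 0.1 for _ in range(E)]
w3 = [torch.randn(I, H) * 0.1 for _ in range(E)]
w2 = [torch.randn(H, I) * 0.1 for _ in range(E)]
y = moe_forward(x, gate_w, w1, w2, w3, top_k=2)
```

Looping over experts instead of tokens lets each expert process all of its assigned tokens in one batched matmul, at the cost of a gather and an index_add_ per expert.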
Key Methods
| Method | Description |
|---|---|
| __init__(model, key, layer_idx, archparams) | Initializes layernorm, gate linear, and per-expert w1/w2/w3 linear layers |
| load() | Loads all submodule weights; creates quantized CUDA handle with scratch buffers if weights are quantized |
| unload() | Frees CUDA handle and unloads all submodule weights |
| forward(hidden_states, ...) | Dispatches to the CUDA kernel or forward_torch based on quantization state, input size, expert count, and whether intermediates or LoRA adapters are requested |
| forward_torch(hidden_states, ...) | Pure PyTorch implementation: layernorm, routing, per-expert gated MLP, weighted aggregation + residual |
| weight_footprint() | Returns total memory footprint of all submodule weights in bytes |
| numel() | Returns total number of elements across w1, w2, w3 expert weights |
| set_device_idx(idx) | Propagates device index to all submodules |
| is_quant() | Returns True if the CUDA quantized handle is active |
| rank_reduce(k) | Applies rank reduction to all expert weight matrices |