Implementation: LLMBook-zh (LLMBook-zh.github.io) MoeLayer
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Model_Architecture |
| Last Updated | 2026-02-08 04:29 GMT |
Overview
A custom PyTorch `nn.Module`, from the LLMBook-zh code, that implements Mixture of Experts (MoE) routing and expert computation.
Description
The MoeLayer class implements a Mixture of Experts layer as a PyTorch module. It takes a list of expert modules (typically feed-forward networks), a gating network, and a top-k parameter. In the forward pass, the gating network computes one logit per expert for each token, and the k highest-scoring experts are selected per token. A softmax over those k logits turns them into weights that sum to 1, and the weighted sum of the selected experts' outputs is the result. The implementation loops over all experts and uses `torch.where` to find the tokens that selected each expert, so each expert runs only on the tokens routed to it rather than on the whole batch (sparse computation).
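The routing steps can be traced on a tiny example. The gate logits below are made-up values for 3 tokens and 4 experts (no ties, so the `torch.topk` ordering is deterministic); the snippet shows only the selection, the softmax over the selected logits, and the per-expert `torch.where` lookup, not the expert computation itself.

```python
import torch
import torch.nn.functional as F

# Made-up gate logits: 3 tokens x 4 experts.
gate_logits = torch.tensor([
    [2.0, 1.0, 0.1, -1.0],
    [0.0, 3.0, 2.5, 0.5],
    [1.0, -2.0, 0.0, 4.0],
])
top_k = 2

# Keep the k largest logits per token and the indices of the chosen experts.
weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
# Softmax over the k selected logits only, so each token's weights sum to 1.
weights = F.softmax(weights, dim=-1)

# For each expert, torch.where returns the tokens routed to it and the
# top-k slot (0..k-1) in which that expert was chosen.
for i in range(gate_logits.shape[1]):
    token_idx, slot = torch.where(selected_experts == i)
    print(f"expert {i}: tokens {token_idx.tolist()}, slots {slot.tolist()}")
# expert 0: tokens [0, 2], slots [0, 1]
# expert 1: tokens [0, 1], slots [1, 0]
# expert 2: tokens [1], slots [1]
# expert 3: tokens [2], slots [0]
```

The `slot` index is what lets each expert look up the matching weight for every token it serves.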
Usage
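Use this class when building MoE-based Transformer architectures. MoeLayer replaces the standard dense feed-forward network in a Transformer block, so different tokens can be routed to different experts. Because `forward` expects 2-D input, activations of shape (batch, seq_len, hidden_dim) must be flattened to (batch * seq_len, hidden_dim) before the call and reshaped afterwards.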
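A sketch of where the layer sits in a block follows. `TokenwiseFFN` and `Block` are hypothetical names, and the pre-norm layout with `nn.MultiheadAttention` is an assumption made for illustration. Since the signature's `forward` takes (batch * seq_len, hidden_dim), the adapter flattens 3-D activations first. Any token-wise module fits the FFN slot, so a plain `nn.Linear` stands in for MoeLayer to keep the snippet runnable on its own.

```python
import torch
from torch import nn


class TokenwiseFFN(nn.Module):
    """Hypothetical adapter: calls a module that expects (num_tokens, hidden_dim)
    input, such as MoeLayer, on (batch, seq_len, hidden_dim) activations."""

    def __init__(self, ffn: nn.Module):
        super().__init__()
        self.ffn = ffn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, hidden_dim = x.shape
        # Flatten batch and sequence so every token is routed independently.
        flat = x.reshape(batch * seq_len, hidden_dim)
        return self.ffn(flat).reshape(batch, seq_len, hidden_dim)


class Block(nn.Module):
    """Hypothetical pre-norm Transformer block whose FFN slot holds a
    token-wise module (e.g. a MoeLayer)."""

    def __init__(self, hidden_dim: int, num_heads: int, ffn: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = TokenwiseFFN(ffn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        # The MoE layer takes the place of the dense feed-forward network.
        return x + self.ffn(self.norm2(x))


block = Block(hidden_dim=64, num_heads=4, ffn=nn.Linear(64, 64))  # stand-in FFN
out = block(torch.randn(2, 5, 64))
```

To use the real layer, pass a MoeLayer instance as `ffn`; the residual connections and normalization stay the same as in a dense block.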
Code Reference
Source Location
- Repository: LLMBook-zh
- File: code/5.4 MoE.py
- Lines: 1-24
Signature
```python
class MoeLayer(nn.Module):
    def __init__(
        self,
        experts: List[nn.Module],
        gate: nn.Module,
        num_experts_per_tok: int
    ):
        """
        Args:
            experts: List of expert modules (e.g., feed-forward networks).
            gate: Gating network that produces logits over experts.
            num_experts_per_tok: Number of experts selected per token (top-k).
        """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs: Input tensor of shape (batch * seq_len, hidden_dim).
        Returns:
            Weighted sum of selected expert outputs, same shape as input.
        """
```
Import
```python
from typing import List

import torch
import torch.nn.functional as F
from torch import nn

# MoeLayer is defined locally in code/5.4 MoE.py
```
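The file name `5.4 MoE.py` contains a space and an extra dot, so it cannot be imported with a plain `import` statement. One option, sketched here under the assumption that the file defines MoeLayer at top level, is to load it by path with the standard library's `importlib.util`. The helper name `load_module_from_path` is hypothetical.

```python
import importlib.util
from types import ModuleType


def load_module_from_path(path: str, module_name: str) -> ModuleType:
    """Hypothetical helper: load a source file whose name is not a valid
    identifier (e.g. 'code/5.4 MoE.py') and return it as a module object."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create a module spec for {path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module
```

For example, run from the repository root: `MoeLayer = load_module_from_path("code/5.4 MoE.py", "moe_5_4").MoeLayer`. Any top-level code in the file (such as demo calls) runs on load.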
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| experts | List[nn.Module] | Yes | List of expert networks (constructor) |
| gate | nn.Module | Yes | Gating/routing network (constructor) |
| num_experts_per_tok | int | Yes | Top-k experts per token (constructor) |
| inputs | torch.Tensor | Yes | 2-D input tensor of shape (batch*seq_len, hidden_dim) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | torch.Tensor | Weighted expert outputs, same shape as input |
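Per token, the return value following the Description can be written as below. The notation is introduced here and not taken from the source: $x_n$ is row $n$ of `inputs`, $G$ is `gate`, $E_i$ is `experts[i]`, and $k$ is `num_experts_per_tok`.

```latex
\mathcal{T}_n = \operatorname{TopK}\big(G(x_n),\, k\big), \qquad
y_n = \sum_{i \in \mathcal{T}_n}
      \frac{\exp\big(G(x_n)_i\big)}{\sum_{j \in \mathcal{T}_n} \exp\big(G(x_n)_j\big)}\,
      E_i(x_n)
```

Experts outside $\mathcal{T}_n$ contribute nothing to $y_n$ and are never evaluated on $x_n$.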
Usage Examples
```python
import torch
from torch import nn

# Assumes MoeLayer from code/5.4 MoE.py is already defined in scope.
hidden_dim = 512
num_experts = 8
top_k = 2

# Create the expert networks and a linear gate over num_experts
experts = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)]
gate = nn.Linear(hidden_dim, num_experts)

# Instantiate the MoE layer
moe = MoeLayer(experts=experts, gate=gate, num_experts_per_tok=top_k)

# Forward pass on a flattened batch of 32 tokens
inputs = torch.randn(32, hidden_dim)
outputs = moe(inputs)
assert outputs.shape == (32, hidden_dim)  # (32, 512)
```