Implementation: Lucidrains x-transformers NeoMLP
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Model_Architecture, Neural_Networks |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Concrete tool from the x-transformers library for replacing traditional MLPs with a transformer-based message-passing network over input, hidden, and output node embeddings.
Description
The NeoMLP class reimagines the standard multi-layer perceptron as a transformer operating on a graph of "nodes". Each input dimension, hidden dimension, and output dimension is represented as a learnable token (node embedding). The input features are encoded with random Fourier features and added to the input node embeddings. All nodes (input, hidden, output) are concatenated into a single sequence and processed by a transformer encoder, which acts as message passing on a fully connected graph. The output node embeddings are then projected to produce the final output. This architecture replaces fixed-weight matrix multiplications with learned attention-based interactions.
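The mechanism described above can be sketched in plain PyTorch. This is an illustrative reimplementation, not the x-transformers code: the class name, the Fourier-feature width, and the use of `nn.TransformerEncoder` in place of the library's `Encoder` are all assumptions made for clarity.

```python
# Illustrative sketch of the NeoMLP idea (NOT the x-transformers implementation):
# each input/hidden/output dimension is a learnable token, scalar inputs are
# injected via random Fourier features, and a transformer encoder performs
# message passing over the full set of tokens.
import torch
import torch.nn as nn

class NeoMLPSketch(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out, dim_model, depth, num_fourier=16):
        super().__init__()
        self.dim_in, self.dim_out = dim_in, dim_out
        # one learnable embedding per input / hidden / output "node"
        self.nodes = nn.Parameter(torch.randn(dim_in + dim_hidden + dim_out, dim_model) * 0.02)
        # fixed random frequencies for encoding each scalar input feature
        self.register_buffer('freqs', torch.randn(num_fourier))
        self.to_model = nn.Linear(2 * num_fourier, dim_model)
        layer = nn.TransformerEncoderLayer(dim_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.to_out = nn.Linear(dim_model, 1)

    def forward(self, x):
        b = x.shape[0]
        # random Fourier features per scalar input: (b, dim_in, 2 * num_fourier)
        angles = x[..., None] * self.freqs * 2 * torch.pi
        fourier = torch.cat([angles.sin(), angles.cos()], dim=-1)
        # add encoded inputs onto the input-node embeddings
        inp = self.nodes[:self.dim_in] + self.to_model(fourier)
        rest = self.nodes[self.dim_in:].expand(b, -1, -1)
        tokens = torch.cat([inp, rest], dim=1)
        # self-attention acts as message passing on a fully connected node graph
        tokens = self.encoder(tokens)
        # project each output-node embedding down to one scalar
        return self.to_out(tokens[:, -self.dim_out:]).squeeze(-1)
```

The key contrast with a standard MLP: there are no fixed weight matrices between layers; all interaction between dimensions happens through attention over the node tokens.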
Usage
Import this class when you want to experiment with replacing standard MLPs in your architecture with transformer-based alternatives. This is relevant for settings where the MLP needs to learn complex input-output relationships and the inductive bias of attention-based message passing may be beneficial, such as in novel view synthesis (LVSM) or function approximation tasks.
Code Reference
Source Location
- Repository: Lucidrains_X_transformers
- File: x_transformers/neo_mlp.py
- Lines: 42-135
Signature
```python
class NeoMLP(Module):
    def __init__(
        self,
        *,
        dim_in,
        dim_hidden,
        dim_out,
        dim_model,
        depth,
        encoder_kwargs: dict = dict(
            attn_dim_head = 16,
            heads = 4
        )
    ):
        """
        Args:
            dim_in: Number of input dimensions (each becomes a node).
            dim_hidden: Number of hidden nodes.
            dim_out: Number of output dimensions (each becomes a node).
            dim_model: Transformer model dimension for node embeddings.
            depth: Depth of the transformer encoder.
            encoder_kwargs: Additional kwargs for the Encoder (attn_dim_head, heads, etc.).
        """
```
Import
```python
from x_transformers.neo_mlp import NeoMLP
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | Tensor (b, dim_in) or (dim_in,) | Yes | Input feature vector |
| return_embeds | bool | No | Also return node embeddings after processing |
Outputs
| Name | Type | Description |
|---|---|---|
| forward() returns | Tensor (b, dim_out) | Output prediction |
| forward() with return_embeds | (Tensor, (Tensor, Tensor, Tensor)) | Output plus (input_embed, hidden_embed, output_embed) |
Usage Examples
Basic Usage
```python
import torch
from x_transformers.neo_mlp import NeoMLP

mlp = NeoMLP(
    dim_in = 64,
    dim_hidden = 128,
    dim_out = 32,
    dim_model = 64,
    depth = 3
)

x = torch.randn(8, 64)
output = mlp(x)
# output.shape == (8, 32)
```
Inspect Node Embeddings
```python
output, (input_embed, hidden_embed, output_embed) = mlp(x, return_embeds=True)
# input_embed.shape  == (8, 64, 64)   # batch, dim_in nodes, dim_model
# hidden_embed.shape == (8, 128, 64)  # batch, dim_hidden nodes, dim_model
# output_embed.shape == (8, 32, 64)   # batch, dim_out nodes, dim_model
```