

Implementation:Lucidrains X transformers NeoMLP

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, Model_Architecture, Neural_Networks
Last Updated 2026-02-08 18:00 GMT

Overview

A concrete tool, provided by the x-transformers library, for replacing a traditional MLP with a transformer that performs message passing over input, hidden, and output node embeddings.

Description

The NeoMLP class recasts the standard multi-layer perceptron as a transformer operating on a graph of "nodes". Each input dimension, hidden unit, and output dimension is represented by a learnable token (node embedding). Each scalar input feature is encoded with random Fourier features and added to its input node embedding, while the hidden and output node embeddings are learned parameters repeated across the batch. All nodes (input, hidden, output) are concatenated into a single sequence and processed by a transformer encoder, whose full self-attention acts as message passing on a fully connected graph. Finally, each output node embedding is projected to a single scalar, giving one output value per output dimension. In this way, the fixed weight matrices of a conventional MLP are replaced by learned, input-dependent attention between nodes.
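The pipeline described above can be sketched from scratch in plain PyTorch. This is a minimal illustration, not the x-transformers source: `TinyNeoMLP` and `RandomFourierEncode` are hypothetical names, PyTorch's `nn.TransformerEncoder` is swapped in for the library's `Encoder`, and details such as the cosine-of-a-frozen-projection Fourier encoding and the 0.02 initialization scale are assumptions.

```python
import math

import torch
from torch import nn


class RandomFourierEncode(nn.Module):
    """Hypothetical encoder: maps each scalar feature to a dim_model vector."""

    def __init__(self, dim_model: int):
        super().__init__()
        # Assumed design: a frozen random projection of each scalar, then cos(2*pi*.)
        self.proj = nn.Linear(1, dim_model)
        self.proj.requires_grad_(False)
        # Learnable layer applied after the fixed Fourier features.
        self.mix = nn.Linear(dim_model, dim_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (b, dim_in) -> (b, dim_in, 1) -> (b, dim_in, dim_model)
        feats = torch.cos(2 * math.pi * self.proj(x.unsqueeze(-1)))
        return self.mix(feats)


class TinyNeoMLP(nn.Module):
    """Hypothetical sketch of the NeoMLP idea; not the library implementation."""

    def __init__(self, dim_in, dim_hidden, dim_out, dim_model, depth, heads=4):
        super().__init__()
        self.sizes = [dim_in, dim_hidden, dim_out]
        # One learnable token per input, hidden, and output node.
        self.input_nodes = nn.Parameter(torch.randn(dim_in, dim_model) * 0.02)
        self.hidden_nodes = nn.Parameter(torch.randn(dim_hidden, dim_model) * 0.02)
        self.output_nodes = nn.Parameter(torch.randn(dim_out, dim_model) * 0.02)
        self.encode = RandomFourierEncode(dim_model)
        # Stand-in for x-transformers' Encoder: PyTorch's built-in encoder stack.
        layer = nn.TransformerEncoderLayer(
            dim_model, heads, dim_feedforward=4 * dim_model, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=depth)
        # One weight vector and bias per output node, reducing it to a scalar.
        self.out_weight = nn.Parameter(torch.randn(dim_out, dim_model) * 0.02)
        self.out_bias = nn.Parameter(torch.zeros(dim_out))

    def forward(self, x: torch.Tensor):
        b = x.shape[0]
        # Input nodes carry the data: node token + Fourier encoding of its feature.
        inp = self.input_nodes + self.encode(x)       # (b, dim_in, d)
        hid = self.hidden_nodes.expand(b, -1, -1)     # (b, dim_hidden, d)
        out = self.output_nodes.expand(b, -1, -1)     # (b, dim_out, d)
        # One sequence of all nodes; full self-attention = complete-graph message passing.
        tokens = self.transformer(torch.cat([inp, hid, out], dim=1))
        inp, hid, out = tokens.split(self.sizes, dim=1)
        # Project each output node embedding to one scalar.
        y = (out * self.out_weight).sum(dim=-1) + self.out_bias  # (b, dim_out)
        return y, (inp, hid, out)


model = TinyNeoMLP(dim_in=16, dim_hidden=32, dim_out=8, dim_model=64, depth=2)
y, (inp, hid, out) = model(torch.randn(4, 16))
print(y.shape)  # torch.Size([4, 8])
```

The key design choice shown here is that the number of hidden units becomes a sequence length rather than a matrix width, so the hidden "layer" is updated by attention over all nodes instead of by a fixed weight matrix.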

Usage

Import this class when you want to experiment with replacing standard MLPs in your architecture with transformer-based alternatives. This is relevant for settings where the MLP needs to learn complex input-output relationships and the inductive bias of attention-based message passing may be beneficial, such as in novel view synthesis (LVSM) or function approximation tasks.

Code Reference

Source Location

Signature

class NeoMLP(Module):
    def __init__(
        self,
        *,
        dim_in,
        dim_hidden,
        dim_out,
        dim_model,
        depth,
        encoder_kwargs: dict = dict(
            attn_dim_head = 16,
            heads = 4
        )
    ):
        """
        Args:
            dim_in: Number of input dimensions (each becomes a node).
            dim_hidden: Number of hidden nodes.
            dim_out: Number of output dimensions (each becomes a node).
            dim_model: Transformer model dimension for node embeddings.
            depth: Number of layers in the transformer encoder.
            encoder_kwargs: Additional kwargs for the Encoder (attn_dim_head, heads, etc.).
        """

Import

from x_transformers.neo_mlp import NeoMLP

I/O Contract

Inputs

| Name | Type / shape | Required | Description |
|---|---|---|---|
| x | Tensor, (b, dim_in) or (dim_in,) | Yes | Input feature vector |
| return_embeds | bool | No | Also return the node embeddings after processing |

Outputs

| Name | Type | Description |
|---|---|---|
| forward(x) | Tensor, (b, dim_out) | Output prediction |
| forward(x, return_embeds=True) | (Tensor, (Tensor, Tensor, Tensor)) | Output plus (input_embed, hidden_embed, output_embed) |

Usage Examples

Basic Usage

import torch
from x_transformers.neo_mlp import NeoMLP

mlp = NeoMLP(
    dim_in=64,
    dim_hidden=128,
    dim_out=32,
    dim_model=64,
    depth=3
)

x = torch.randn(8, 64)
output = mlp(x)
# output.shape == (8, 32)

Inspect Node Embeddings

output, (input_embed, hidden_embed, output_embed) = mlp(x, return_embeds=True)
# input_embed.shape == (8, 64, 64)   # batch, dim_in nodes, dim_model
# hidden_embed.shape == (8, 128, 64) # batch, dim_hidden nodes, dim_model
# output_embed.shape == (8, 32, 64)  # batch, dim_out nodes, dim_model
