
Implementation:Lucidrains X transformers MultiInputTransformerWrapper

Knowledge Sources
Domains: Deep_Learning, Multi_Modal, Model_Architecture
Last Updated: 2026-02-08 18:00 GMT

Overview

A concrete wrapper provided by the x-transformers library that lets transformer attention layers accept multiple named input token streams, each with its own embedding table and output head.

Description

The MultiInputTransformerWrapper accepts a dictionary of named token tensors as input instead of a single token tensor. Each named input has its own embedding table, and the embeddings are summed (similar to BERT's token + segment embeddings). Positional embeddings are added on top. The output can be returned as embeddings or as a dictionary of logits, one per named input. This is useful for multi-modal or multi-stream architectures where multiple types of tokens are processed jointly. It supports memory tokens, KV caching, gradient fraction scaling (from CogView/GLM-130B), and all standard TransformerWrapper features.
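
To make the summing behavior concrete, below is a minimal standalone sketch of the idea. This is illustrative code, not the library's internals; the table sizes and variable names are placeholders.

import torch
import torch.nn as nn

# One embedding table per named input stream (sizes are illustrative)
embs = nn.ModuleDict({
    "token": nn.Embedding(30000, 256),
    "segment": nn.Embedding(2, 256),
})

x = {
    "token": torch.randint(0, 30000, (4, 128)),
    "segment": torch.randint(0, 2, (4, 128)),
}

# Per-stream embeddings are summed elementwise, echoing BERT's
# token + segment embedding sum; positional embeddings are added afterwards
summed = sum(embs[name](ids) for name, ids in x.items())  # (4, 128, 256)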

Usage

Import this class when your model needs to accept multiple types of token inputs simultaneously (e.g., token IDs and type IDs, or text and image tokens). Each input stream gets its own embedding and output head, while sharing the same attention layers.

Code Reference

Source Location

x_transformers/multi_input.py (lucidrains/x-transformers)

Signature

class MultiInputTransformerWrapper(Module):
    def __init__(
        self,
        *,
        num_tokens: Dict[str, int] = dict(),
        max_seq_len,
        attn_layers: AttentionLayers,
        emb_dim = None,
        max_mem_len = 0,
        shift_mem_down = 0,
        emb_dropout = 0.,
        post_emb_norm = False,
        num_memory_tokens = None,
        memory_tokens_interspersed_every = None,
        return_only_embed = False,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        emb_frac_gradient = 1.,
        attn_z_loss_weight = 1e-4,
    ):
        """
        Args:
            num_tokens: Dict mapping input name to vocabulary size (e.g., {"token": 30000, "type": 4}).
            max_seq_len: Maximum sequence length.
            attn_layers: AttentionLayers instance (Encoder or Decoder).
            emb_dim: Embedding dimension override.
            max_mem_len: Maximum memory length for Transformer-XL recurrence.
            emb_dropout: Dropout rate on embeddings.
            post_emb_norm: Apply LayerNorm after embedding.
            num_memory_tokens: Number of learnable memory tokens.
            return_only_embed: Only return embeddings, no logit heads.
            emb_frac_gradient: Fraction of gradient to embedding (CogView technique).
        """

Import

from x_transformers.multi_input import MultiInputTransformerWrapper

I/O Contract

Inputs

Name               Type                       Required  Description
x                  Dict[str, Tensor (b, n)]   Yes       Dictionary of named token ID tensors, one per input stream
mask               Tensor (b, n)              No        Boolean attention mask
return_embeddings  bool                       No        Return raw embeddings instead of logits

Outputs

Name                              Type                             Description
forward() default                 Dict[str, Tensor (b, n, vocab)]  Dictionary of logits, one per named input stream
forward() with return_embeddings  Tensor (b, n, d)                 Raw transformer embeddings
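
Both optional arguments are keyword arguments to forward. A short sketch, assuming the model and inputs constructed in the Usage Examples below:

# assuming `model` and `x` as defined in the Usage Examples section
mask = torch.ones(4, 128, dtype=torch.bool)
mask[:, 100:] = False  # mark trailing positions as padding

hidden = model(x, mask=mask, return_embeddings=True)
# hidden: (4, 128, 256)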

Usage Examples

BERT-style Token + Type Inputs

import torch
from x_transformers import Encoder
from x_transformers.multi_input import MultiInputTransformerWrapper

model = MultiInputTransformerWrapper(
    num_tokens={"token": 30000, "segment": 2},
    max_seq_len=512,
    attn_layers=Encoder(dim=256, depth=6, heads=8)
)

# Provide multiple named inputs
x = {
    "token": torch.randint(0, 30000, (4, 128)),
    "segment": torch.randint(0, 2, (4, 128))
}

logits = model(x)
# logits["token"]: (4, 128, 30000)
# logits["segment"]: (4, 128, 2)

Embedding-Only Mode

embed_model = MultiInputTransformerWrapper(
    num_tokens={"token": 30000, "segment": 2},  # keys must match the input dict above
    max_seq_len=512,
    attn_layers=Encoder(dim=256, depth=6, heads=8),
    return_only_embed=True
)

embeddings = embed_model(x)
# embeddings: (4, 128, 256)
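
The per-token embeddings can then be pooled into a fixed-size vector per sequence, for example by mean pooling (a sketch):

# mean-pool over the sequence dimension for a sequence-level representation
pooled = embeddings.mean(dim=1)  # pooled: (4, 256)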
