Implementation: Lucidrains x-transformers MultiInputTransformerWrapper
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Multi_Modal, Model_Architecture |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A wrapper from the x-transformers library that lets transformer attention layers accept multiple named input token streams, each with its own embedding table and output head.
Description
The MultiInputTransformerWrapper accepts a dictionary of named token tensors as input instead of a single token tensor. Each named input has its own embedding table, and the embeddings are summed (similar to BERT's token + segment embeddings). Positional embeddings are added on top. The output can be returned as embeddings or as a dictionary of logits, one per named input. This is useful for multi-modal or multi-stream architectures where multiple types of tokens are processed jointly. It supports memory tokens, KV caching, gradient fraction scaling (from CogView/GLM-130B), and all standard TransformerWrapper features.
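The summed-embedding scheme described above can be sketched in plain PyTorch. This is a minimal illustration of the mechanism, not the library's implementation; the stream names, vocabulary sizes, and dimensions are made up for the example.

```python
import torch
import torch.nn as nn

# Illustrative sizes (not from the library)
num_tokens = {"token": 100, "segment": 2}
dim, max_seq_len = 16, 32

# One embedding table per named input stream
embeds = nn.ModuleDict({name: nn.Embedding(v, dim) for name, v in num_tokens.items()})
pos_emb = nn.Embedding(max_seq_len, dim)  # absolute positional embedding

x = {
    "token": torch.randint(0, 100, (2, 8)),
    "segment": torch.randint(0, 2, (2, 8)),
}

# Sum the per-stream embeddings (BERT-style), then add positions
summed = sum(embeds[name](ids) for name, ids in x.items())
h = summed + pos_emb(torch.arange(8))  # (2, 8, 16), fed to the shared attention layers
```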
Usage
Import this class when your model needs to accept multiple types of token inputs simultaneously (e.g., token IDs and type IDs, or text and image tokens). Each input stream gets its own embedding and output head, while sharing the same attention layers.
Code Reference
Source Location
- Repository: Lucidrains_X_transformers
- File: x_transformers/multi_input.py
- Lines: 34-277
Signature
class MultiInputTransformerWrapper(Module):
    def __init__(
        self,
        *,
        num_tokens: Dict[str, int] = dict(),
        max_seq_len,
        attn_layers: AttentionLayers,
        emb_dim = None,
        max_mem_len = 0,
        shift_mem_down = 0,
        emb_dropout = 0.,
        post_emb_norm = False,
        num_memory_tokens = None,
        memory_tokens_interspersed_every = None,
        return_only_embed = False,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        emb_frac_gradient = 1.,
        attn_z_loss_weight = 1e-4,
    ):
"""
Args:
num_tokens: Dict mapping input name to vocabulary size (e.g., {"token": 30000, "type": 4}).
max_seq_len: Maximum sequence length.
attn_layers: AttentionLayers instance (Encoder or Decoder).
emb_dim: Embedding dimension override.
max_mem_len: Maximum memory length for Transformer-XL recurrence.
emb_dropout: Dropout rate on embeddings.
post_emb_norm: Apply LayerNorm after embedding.
num_memory_tokens: Number of learnable memory tokens.
return_only_embed: Only return embeddings, no logit heads.
emb_frac_gradient: Fraction of gradient to embedding (CogView technique).
"""
Import
from x_transformers.multi_input import MultiInputTransformerWrapper
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | Dict[str, Tensor (b, n)] | Yes | Dictionary of named token ID tensors, one per input stream |
| mask | Tensor (b, n) | No | Boolean attention mask |
| return_embeddings | bool | No | Return raw embeddings instead of logits |
Outputs
| Name | Type | Description |
|---|---|---|
| forward() default | Dict[str, Tensor (b, n, vocab)] | Dictionary of logits, one per named input stream |
| forward() with return_embeddings | Tensor (b, n, d) | Raw transformer embeddings |
Usage Examples
BERT-style Token + Type Inputs
import torch
from x_transformers import Encoder
from x_transformers.multi_input import MultiInputTransformerWrapper
model = MultiInputTransformerWrapper(
    num_tokens={"token": 30000, "segment": 2},
    max_seq_len=512,
    attn_layers=Encoder(dim=256, depth=6, heads=8)
)

# Provide multiple named inputs
x = {
    "token": torch.randint(0, 30000, (4, 128)),
    "segment": torch.randint(0, 2, (4, 128))
}

logits = model(x)
# logits["token"]: (4, 128, 30000)
# logits["segment"]: (4, 128, 2)
Embedding-Only Mode
embed_model = MultiInputTransformerWrapper(
    num_tokens={"token": 30000, "segment": 2},  # keys must match the input dict's keys
    max_seq_len=512,
    attn_layers=Encoder(dim=256, depth=6, heads=8),
    return_only_embed=True
)

embeddings = embed_model(x)  # reuses x from the example above
# embeddings: (4, 128, 256)
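In embedding-only mode, the (batch, seq, dim) output is typically pooled into one vector per sequence for downstream tasks. A common approach is mask-aware mean pooling; this sketch uses a random stand-in for the model output and an invented padding mask.

```python
import torch

# Stand-in for embed_model(x): (batch, seq, dim)
embeddings = torch.randn(4, 128, 256)

# Hypothetical padding mask: True = real token, False = padding
mask = torch.ones(4, 128, dtype=torch.bool)
mask[:, 100:] = False  # pretend the last 28 positions are padding

# Zero out padded positions, then average over the real tokens only
masked = embeddings.masked_fill(~mask.unsqueeze(-1), 0.)
pooled = masked.sum(dim=1) / mask.sum(dim=1, keepdim=True)  # (4, 256)
```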