Implementation:Lucidrains X transformers Paired Sequence Generator Pattern
| Field | Value |
|---|---|
| Repo | x-transformers |
| Domains | Data_Engineering, NLP |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Pattern specification for creating paired source-target data generators for encoder-decoder training, as demonstrated in the x-transformers training examples.
Description
This is a Pattern Doc. Users create a generator function that yields (src, tgt, src_mask) tuples for encoder-decoder training. The reference implementation is the cycle() generator from train_copy.py.
The pattern requires:
- A generator (or iterator) that yields 3-tuples.
- The source tensor src contains encoder input token IDs.
- The target tensor tgt contains decoder input/output token IDs, typically with a prefix token prepended.
- The source mask src_mask is a boolean tensor indicating valid (non-padding) source positions.
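As a minimal sketch of these requirements, the generator below yields padded, variable-length batches for a toy sequence-reversal task. The task itself, the name `padded_pairs`, and the constants `PAD_ID` and `BOS_ID` are illustrative assumptions; none of them come from x-transformers or train_copy.py.

```python
import torch

PAD_ID = 0   # assumed padding token ID
BOS_ID = 1   # assumed decoder prefix (start) token ID


def padded_pairs(batch_size=4, max_len=8, num_tokens=16, device="cpu", seed=0):
    """Yield (src, tgt, src_mask) batches for a toy sequence-reversal task."""
    gen = torch.Generator().manual_seed(seed)
    while True:
        # Random true length per sequence, between 1 and max_len inclusive.
        lengths = torch.randint(1, max_len + 1, (batch_size,), generator=gen)
        # Random content tokens; IDs 0 and 1 stay reserved.
        src = torch.randint(2, num_tokens, (batch_size, max_len), generator=gen)
        # True where the position index is below that sequence's length.
        src_mask = torch.arange(max_len)[None, :] < lengths[:, None]
        src = src.masked_fill(~src_mask, PAD_ID)

        # Target: prefix token, then the reversed valid tokens, then padding.
        tgt = torch.full((batch_size, max_len + 1), PAD_ID, dtype=torch.long)
        tgt[:, 0] = BOS_ID
        for row, n in enumerate(lengths.tolist()):
            tgt[row, 1:n + 1] = src[row, :n].flip(0)

        yield src.to(device), tgt.to(device), src_mask.to(device)


src, tgt, src_mask = next(padded_pairs())
```

Unlike the synthetic copy task, the mask here actually varies per row. Whether padded target positions are excluded from the loss depends on how the model is configured; this sketch does not address that.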
Code Reference
File: train_copy.py, Lines: L19-25
Interface Specification
cycle() Generator (Reference Implementation)
def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().to(DEVICE)
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().to(DEVICE)
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().to(DEVICE)
        yield (src, tgt, src_mask)
Key details:
- The prefix is a tensor of ones (token ID 1), serving as the start-of-sequence token for the decoder.
- In this copy task, the target is the source concatenated with itself, preceded by the prefix: [prefix, src, src].
- The source mask is all True since there is no padding in this synthetic example.
- Token IDs start from 2 (IDs 0 and 1 are reserved for padding and the prefix token, respectively).
- The generator yields batches directly (not individual samples), so no DataLoader is needed.
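To make the target layout concrete, the following sketch rebuilds one batch on the CPU and slices tgt back into its parts. The constant values are small illustrative assumptions, not the ones used in train_copy.py.

```python
import torch

BATCH_SIZE, ENC_SEQ_LEN, NUM_TOKENS = 2, 3, 10  # small illustrative values

prefix = torch.ones((BATCH_SIZE, 1)).long()
src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long()
tgt = torch.cat((prefix, src, src), 1)

# The decoder sequence is one prefix token plus two copies of the source.
dec_seq_len = tgt.shape[1]                 # 1 + 2 * ENC_SEQ_LEN
first_copy = tgt[:, 1:1 + ENC_SEQ_LEN]     # columns right after the prefix
second_copy = tgt[:, 1 + ENC_SEQ_LEN:]     # remaining columns
```

Both slices equal src, and column 0 of tgt holds the prefix ID 1 in every row.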
Usage in Training Loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()
    src, tgt, src_mask = next(cycle())
    loss = model(src, tgt, mask=src_mask)
    loss.backward()
    optim.step()
    optim.zero_grad()
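Note that the loop calls cycle() on every step, which builds a fresh generator each time. That is harmless for cycle() because every batch is random, but a generator that walks through stored data would return its first batch on every step. The sketch below illustrates the difference; `stored_pairs` and its toy data are hypothetical and exist only for this demonstration.

```python
import torch


def stored_pairs(data, batch_size):
    """Cycle forever over a fixed (N, L) tensor of token IDs, in order."""
    while True:
        for start in range(0, data.shape[0], batch_size):
            src = data[start:start + batch_size]
            prefix = torch.ones((src.shape[0], 1), dtype=torch.long)
            tgt = torch.cat((prefix, src), 1)
            src_mask = torch.ones_like(src, dtype=torch.bool)
            yield src, tgt, src_mask


data = torch.arange(2, 14).reshape(4, 3)   # four toy sequences of length 3

# Re-creating the generator restarts it, so the first batch repeats.
restarted = [next(stored_pairs(data, 2))[0] for _ in range(2)]

# Creating it once lets each next() advance to the following batch.
batches = stored_pairs(data, 2)
advanced = [next(batches)[0] for _ in range(2)]
```

For data-backed generators, creating the iterator once before the training loop and calling next() on it inside the loop is the safer form.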
Input / Output
| Direction | Name | Type | Shape | Description |
|---|---|---|---|---|
| Output | src | LongTensor | (B, enc_seq_len) | Source token IDs for the encoder |
| Output | tgt | LongTensor | (B, dec_seq_len) | Target token IDs for the decoder (with prefix) |
| Output | src_mask | BoolTensor | (B, enc_seq_len) | Boolean mask; True = valid source position |
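The contract in this table can be checked mechanically before a batch reaches the model. The helper below is a hypothetical sketch (the name `check_batch` is not part of x-transformers) that asserts the dtypes and shape relationships listed above.

```python
import torch


def check_batch(src, tgt, src_mask):
    """Raise AssertionError if a batch violates the documented contract."""
    assert src.dtype == torch.long and src.dim() == 2, "src must be a 2-D LongTensor"
    assert tgt.dtype == torch.long and tgt.dim() == 2, "tgt must be a 2-D LongTensor"
    assert src_mask.dtype == torch.bool, "src_mask must be a BoolTensor"
    assert src_mask.shape == src.shape, "src_mask must match src in shape"
    assert tgt.shape[0] == src.shape[0], "src and tgt must share the batch size"
    return True


# Example batch shaped like the copy task: B = 4, enc_seq_len = 6.
src = torch.randint(2, 18, (4, 6))
tgt = torch.cat((torch.ones((4, 1), dtype=torch.long), src, src), 1)
ok = check_batch(src, tgt, torch.ones(4, 6).bool())
```

Running such a check on the first batch of a new generator catches the most common mistake, a float or integer mask passed where a boolean one is expected.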