Heuristic:Lucidrains X transformers Numerical Stability Techniques
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Collection of numerical stability techniques implemented in x-transformers to prevent NaN gradients, attention overflow, and training instability.
Description
x-transformers implements multiple numerical stability measures across its attention and normalization modules. These techniques address common failure modes in transformer training: attention logit overflow, NaN from fully-masked rows, precision loss in mixed-precision training, and gradient saturation. Understanding when and why each technique is applied helps avoid destabilizing model configurations.
Usage
Use this heuristic when encountering NaN losses, training instability, or overflow errors during training, or when designing a custom configuration that may bypass default safety measures.
The Insight (Rule of Thumb)
- QK Normalization: Set `qk_norm=True` (passed through `Decoder`/`Encoder` as `attn_qk_norm=True`) to L2-normalize queries and keys before their dot product, which bounds the attention logits and prevents overflow. Validated at 3B scale (SwinV2) and 8B scale (Persimmon). When enabled, softmax runs in native precision; when disabled, softmax is forced to float32.
- Z-Loss: Apply an auxiliary z-loss to the attention logits for stabilization. The same loss is applied to mixture-of-experts router logits (ST-MoE), and PaLM uses a z-loss as one of its stability measures.
- Zero KV Token: Set `add_zero_kv=True` to prepend an all-zero key/value pair to the attention, which controls attention outliers by giving each head a way to attend to nothing (Evan Miller's "Attention Is Off By One").
- Logit Softclamp: Set `softclamp_logits=True` with `logit_softclamp_value=50` to prevent extreme logit values via tanh-based clamping. Not compatible with flash attention.
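- QK Clipping: Enable `qk_clip` to counter the attention-logit growth that destabilizes training with the Muon optimizer (proposed by Moonshot AI); the default threshold tau=100 caps how large the attention logits can become.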
- Masked Row Protection: Automatic detection and zeroing of outputs for fully-masked attention rows to prevent NaN gradients.
- Unit Offset Norm: Set `norm_add_unit_offset=True` to allow safe weight decay on LayerNorm gamma parameters.
- Gate Initialization: Value-gate weights are initialized to 0 and biases to +10, so every gate starts at sigmoid(10) ≈ 0.99995 and initially passes values through almost unchanged.
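The zero key/value bullet can be checked with a small sketch, assuming single-head attention on plain tensors rather than the x-transformers modules. An all-zero key gets a logit of exactly 0, which adds exp(0) = 1 to the softmax denominator, and its all-zero value adds nothing to the output. The result matches Evan Miller's proposed "softmax1", exp(x_i) / (1 + sum_j exp(x_j)); `softmax_one` is a hypothetical helper written only for this comparison.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, n = 16, 5
q = torch.randn(1, d)      # one query (single head, no batch)
k = torch.randn(n, d)      # n keys
v = torch.randn(n, d)      # n values
scale = d ** -0.5

# Prepend one all-zero key and value (the add_zero_kv idea)
k_zero = F.pad(k, (0, 0, 1, 0), value=0.0)   # row 0 is zeros, so its logit is exactly 0
v_zero = F.pad(v, (0, 0, 1, 0), value=0.0)   # row 0 is zeros, so it contributes nothing
attn_zero = F.softmax((q @ k_zero.t()) * scale, dim=-1)
out_zero_kv = attn_zero @ v_zero

def softmax_one(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: exp(x_i) / (1 + sum_j exp(x_j)), shifted by m >= 0 for stability."""
    m = x.amax(dim=dim, keepdim=True).clamp(min=0.0)
    e = torch.exp(x - m)
    return e / (torch.exp(-m) + e.sum(dim=dim, keepdim=True))

probs_one = softmax_one((q @ k.t()) * scale)
out_softmax_one = probs_one @ v

print(torch.allclose(out_zero_kv, out_softmax_one, atol=1e-6))  # True
print(probs_one.sum().item() < 1.0)  # True: the remaining mass went to the zero slot
```

Because the weights on real tokens sum to less than 1, a head with no useful key can park most of its mass on the zero slot instead of amplifying an arbitrary token, which is the outlier-control argument.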
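The soft clamp can be sketched as below, assuming the Gemma 2-style formula `value * tanh(logits / value)`; the bullet only says "tanh-based clamping", so the exact expression is an assumption. Logits well below `value` pass through nearly unchanged, while extreme logits saturate near ±`value` and receive almost no gradient.

```python
import torch

def softclamp(logits: torch.Tensor, value: float = 50.0) -> torch.Tensor:
    """Smoothly bound logits to (-value, value); close to the identity when |logits| << value."""
    return value * torch.tanh(logits / value)

logits = torch.tensor([-1e4, -100.0, -1.0, 0.0, 1.0, 100.0, 1e4], requires_grad=True)
clamped = softclamp(logits)
clamped.sum().backward()

print(clamped.detach())  # extremes saturate at about +/-50, small logits barely change
print(logits.grad)       # 1 - tanh(x / value)^2: about 1 near 0, about 0 when saturated
```

The incompatibility with flash attention presumably follows from fused kernels never materializing the logit matrix, which leaves no place to apply an elementwise transform before the softmax.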
Reasoning
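Transformer attention computes `softmax(QK^T / sqrt(d))`, and the logits `QK^T / sqrt(d)` can overflow float16 (maximum finite value 65504) when query and key norms grow large. QK normalization L2-normalizes each query and key, so every dot product is a cosine similarity in [-1, 1]; the logits are then bounded by the attention scale instead of growing with the activations. Forcing float32 softmax when QK norm is disabled is a fallback safety measure.
With `-inf` masking, a fully masked row computes `softmax([-inf, -inf, ...])`, which evaluates to 0/0 = NaN; the NaN then spreads to every downstream activation and, in the backward pass, to the gradients. With a large finite mask value such as `-torch.finfo(dtype).max`, the same row does not produce NaN but attends uniformly to keys that should all be ignored, which is equally meaningless. The row-masking protection detects rows with no unmasked key and zeros their output.
The unit offset trick initializes the LayerNorm gamma to 0 and adds 1 in the forward pass, so the effective scale still starts at 1. Weight decay then pulls the stored gamma toward 0, which corresponds to an effective scale of 1 (the identity), instead of pulling the effective scale toward 0 and shrinking the normalized activations.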
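The bounding argument can be checked numerically. This sketch assumes deliberately large, constant query and key vectors and a hypothetical fixed scale of 10 for the normalized path; x-transformers applies its own scale, which this example does not reproduce.

```python
import torch
import torch.nn.functional as F

def attention_logits(q: torch.Tensor, k: torch.Tensor, qk_norm: bool, scale: float = 10.0) -> torch.Tensor:
    """Attention logits with or without QK normalization.

    `scale` is a hypothetical fixed temperature for the normalized case.
    """
    if qk_norm:
        q = F.normalize(q, dim=-1)  # unit-length queries
        k = F.normalize(k, dim=-1)  # unit-length keys
        return (q @ k.t()) * scale  # cosine similarity in [-1, 1], times the scale
    return (q @ k.t()) / q.shape[-1] ** 0.5  # standard QK^T / sqrt(d)

d = 64
q = torch.full((2, d), 300.0)  # deliberately large activations
k = torch.full((3, d), 300.0)

raw = attention_logits(q, k, qk_norm=False)
normed = attention_logits(q, k, qk_norm=True)

print(raw.max().item())                      # 720000.0, far above float16's max of 65504
print(torch.isinf(raw.half()).all().item())  # True: the logits overflow in float16
print(normed.max().item())                   # 10.0 up to rounding: bounded by the scale
```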
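The two masking behaviours and the row protection can be reproduced on a toy score matrix. A minimal sketch with hypothetical shapes (one head, two queries, three keys); only the final zeroing step mirrors the `attend.py` logic quoted in this page.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.1, 0.2]])
mask = torch.tensor([[True, True, False],     # row 0: two valid keys
                     [False, False, False]])  # row 1: every key masked
values = torch.randn(3, 4)

# -inf masking: row 1 becomes softmax([-inf, -inf, -inf]) = NaN
attn_inf = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
print(torch.isnan(attn_inf[1]).all().item())  # True

# Finite masking: no NaN, but row 1 averages uniformly over keys it should ignore
mask_value = -torch.finfo(scores.dtype).max
attn_finite = F.softmax(scores.masked_fill(~mask, mask_value), dim=-1)
out = attn_finite @ values
print(attn_finite[1])  # uniform weights of 1/3 each

# Row protection: zero the output of rows with no valid key
row_is_entirely_masked = ~mask.any(dim=-1)  # tensor([False, True])
out = out.masked_fill(row_is_entirely_masked[..., None], 0.0)
print(out[1])  # all zeros
```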
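The effect of weight decay on the two parameterizations can be isolated in a toy example. `ScaleOnlyNorm` and `gain_after_weight_decay` are hypothetical names; the loss is multiplied by zero so that AdamW's decoupled weight decay is the only force on gamma, which exaggerates what happens in real training, where task gradients push back.

```python
import torch
import torch.nn.functional as F

class ScaleOnlyNorm(torch.nn.Module):
    """Hypothetical LayerNorm with a learned gain and an optional unit offset."""

    def __init__(self, dim: int, unit_offset: bool):
        super().__init__()
        self.unit_offset = float(unit_offset)
        # Mirrors the quoted init: gamma starts at 0 with the offset, at 1 without it
        self.gamma = torch.nn.Parameter(torch.full((dim,), 1.0 - self.unit_offset))

    def effective_gain(self) -> torch.Tensor:
        return self.gamma + self.unit_offset

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, x.shape[-1:]) * self.effective_gain()


def gain_after_weight_decay(unit_offset: bool, steps: int = 1000) -> torch.Tensor:
    torch.manual_seed(0)
    norm = ScaleOnlyNorm(4, unit_offset)
    opt = torch.optim.AdamW(norm.parameters(), lr=0.1, weight_decay=0.1)
    for _ in range(steps):
        opt.zero_grad()
        # Multiply by 0 so the task gradient vanishes and only weight decay acts
        loss = (norm(torch.randn(2, 4)) * 0.0).sum()
        loss.backward()
        opt.step()
    return norm.effective_gain().detach()


no_offset = gain_after_weight_decay(unit_offset=False)
with_offset = gain_after_weight_decay(unit_offset=True)
print(no_offset)    # decays toward 0: normalized activations shrink
print(with_offset)  # stays at 1: decay pulls toward the identity
```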
Code Evidence
Softmax precision control from `attend.py:233-234`:
```python
softmax_fn = partial(F.softmax, dim = -1)
self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
```
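The precision switch can be exercised outside the library by rebuilding the same two `partial` calls. A minimal sketch; `make_attn_fn` is a hypothetical wrapper, and bfloat16 stands in for any low-precision training dtype.

```python
import torch
import torch.nn.functional as F
from functools import partial

def make_attn_fn(qk_norm: bool):
    # Same selection as the attend.py lines above: force float32 unless QK norm is on
    softmax_fn = partial(F.softmax, dim=-1)
    return partial(softmax_fn, dtype=torch.float32) if not qk_norm else softmax_fn

logits = torch.randn(2, 8).to(torch.bfloat16)  # low-precision logits

probs_fp32 = make_attn_fn(qk_norm=False)(logits)
probs_native = make_attn_fn(qk_norm=True)(logits)

print(probs_fp32.dtype)    # torch.float32
print(probs_native.dtype)  # torch.bfloat16
```

Passing `dtype` to `F.softmax` casts the input before the operation, so the exponentials and the normalizing sum are computed in float32 even when the logits arrive in a 16-bit format.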
Masked row NaN protection from `attend.py:404-410, 462-465`:
```python
# protect against an entire row being masked out
row_is_entirely_masked = None
if exists(mask):
    row_is_entirely_masked = ~mask.any(dim = -1)

# ... later in forward:
if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
    out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
```
Attention bias masking with halved mask value from `attend.py:420-427`:
```python
mask_value = -torch.finfo(q.dtype).max

if exists(mask):
    attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
```
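The page does not say why the mask value is halved. One plausible reason, sketched below for float16, is headroom: if the bias is later combined with another large negative term, the full `-finfo.max` overflows to `-inf`, and a fully masked row of `-inf` values turns the softmax into NaN, while the halved value stays finite. This illustrates the arithmetic only; it is not a statement of the library's intent.

```python
import torch

dtype = torch.float16
mask_value = -torch.finfo(dtype).max  # -65504.0, the most negative finite float16
full = torch.tensor([mask_value], dtype=dtype)
halved = torch.tensor([mask_value // 2], dtype=dtype)  # -32752.0

# Add a second large negative term, as when two masks or biases are combined
print(full + full)      # -inf: overflowed past the float16 range
print(halved + halved)  # -65504: still finite

# Why it matters: a row whose entries are all -inf has no valid softmax
row_full = (full + full).float().repeat(3)
row_halved = (halved + halved).float().repeat(3)
print(torch.softmax(row_full, dim=-1))    # all NaN
print(torch.softmax(row_halved, dim=-1))  # uniform 1/3 each, finite
```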
Gate initialization for stable pass-through from `x_transformers.py:1605-1606, 1613-1614`:
```python
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 10)
```
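The pass-through behaviour at initialization can be checked directly. This sketch assumes the gate multiplies the attention output elementwise by `sigmoid(to_v_gate(x))`; the `x` and `values` tensors are hypothetical placeholders.

```python
import torch

torch.manual_seed(0)
dim = 8
to_v_gate = torch.nn.Linear(dim, dim)
torch.nn.init.constant_(to_v_gate.weight, 0)  # gate ignores its input at initialization
torch.nn.init.constant_(to_v_gate.bias, 10)   # every gate logit starts at +10

x = torch.randn(4, dim)       # hypothetical layer input that drives the gate
values = torch.randn(4, dim)  # hypothetical attention output being gated
gate = to_v_gate(x).sigmoid() # sigmoid(10) is about 0.9999546 everywhere
gated = values * gate

print(gate.min().item(), gate.max().item())  # both about 0.9999546
print((gated - values).abs().max().item())   # tiny: values pass through almost unchanged
```

Because the weights start at zero, the gate is input-independent at first and the layer behaves like an ungated one; it only begins to modulate values once training moves the weights away from zero.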
QK clipping for Muon training from `x_transformers.py:1792-1794`:
```python
tau = 100 # this hyperparameter controls how large the attention logits can be
""" proposed by the Moonshot AI team as a solution for Muon training instability """
```
Unit offset norm initialization from `x_transformers.py:861`:
```python
nn.init.constant_(self.gamma, 1. - float(unit_offset))
```
Z-loss documentation from `x_transformers.py:212-214`:
```python
# the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
# in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
# also used in PaLM as one of the measures
```
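The z-loss can be sketched as the squared log-partition term from ST-MoE (arXiv 2202.08906), `mean(logsumexp(logits) ** 2)`. Applying it to attention logits of shape (batch, heads, queries, keys) is the illustrative assumption here, and in practice the term would be added to the main loss with a small weight. Because softmax is shift-invariant, logits can drift upward without changing the attention pattern; z-loss penalizes exactly that drift.

```python
import torch

def z_loss(logits: torch.Tensor) -> torch.Tensor:
    """Mean over rows of (logsumexp over the key axis)^2."""
    return torch.logsumexp(logits, dim=-1).pow(2).mean()

torch.manual_seed(0)
small = torch.randn(2, 4, 8, 8)  # (batch, heads, queries, keys)
large = small + 30.0             # same attention pattern, shifted logits

print(torch.allclose(small.softmax(-1), large.softmax(-1), atol=1e-5))  # True: softmax ignores the shift
print(z_loss(small).item() < z_loss(large).item())                      # True: z-loss penalizes it
```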
Related Pages
- Implementation:Lucidrains_X_transformers_TransformerWrapper_Decoder_Init
- Implementation:Lucidrains_X_transformers_TransformerWrapper_Encoder_Init
- Implementation:Lucidrains_X_transformers_AutoregressiveWrapper_Forward
- Principle:Lucidrains_X_transformers_Causal_Decoder_Configuration
- Principle:Lucidrains_X_transformers_Autoregressive_Training_Loss