Heuristic:Lucidrains X transformers Numerical Stability Techniques

From Leeroopedia




Knowledge Sources
Domains Deep_Learning, Optimization
Last Updated 2026-02-08 18:00 GMT

Overview

Collection of numerical stability techniques implemented in x-transformers to prevent NaN gradients, attention overflow, and training instability.

Description

x-transformers implements multiple numerical stability measures across its attention and normalization modules. These techniques address common failure modes in transformer training: attention logit overflow, NaN from fully-masked rows, precision loss in mixed-precision training, and gradient saturation. Understanding when and why each technique is applied helps avoid destabilizing model configurations.

Usage

Use this heuristic when encountering NaN losses, training instability, or overflow errors during training, or when designing a custom configuration that may bypass default safety measures.

The Insight (Rule of Thumb)

  • QK Normalization: Set `qk_norm=True` to L2-normalize queries and keys, which bounds their dot products and prevents attention-logit overflow. Validated at 3B scale (SwinV2) and 8B scale (Persimmon). When enabled, softmax runs in the native precision; when disabled, softmax is forced to float32.
  • Z-Loss: Apply an auxiliary z-loss, which penalizes the squared log-partition function (logsumexp) of the logits, to the attention logits for stabilization. The same loss is applied to MoE router logits in ST-MoE and was also used in PaLM.
  • Zero KV Token: Set `add_zero_kv=True` to prepend an all-zero key/value pair. This gives each query the option to attend to nothing and controls attention outliers (Evan Miller's "attention is off by one").
  • Logit Softclamp: Set `softclamp_logits=True` with `logit_softclamp_value=50` to bound logits smoothly via `value * tanh(logits / value)`. Not compatible with flash attention.
  • QK Clipping: Use `qk_clip` to counter the attention-logit growth that destabilizes training with the Muon optimizer; the default threshold is tau=100.
  • Masked Row Protection: Automatic detection and zeroing of outputs for fully-masked attention rows to prevent NaN gradients.
  • Unit Offset Norm: Set `norm_add_unit_offset=True` to allow safe weight decay on LayerNorm gamma parameters.
  • Gate Initialization: Gate weights are initialized to 0 and gate biases to +10, so every gate starts at sigmoid(10) ≈ 0.99995. This gives a stable, near-identity pass-through at the start of training.
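
The logit softclamp and zero-KV bullets both act on a single row of attention logits, so one sketch can cover both. The code below is plain Python, not x-transformers code. `softclamp` follows the `value * tanh(logits / value)` form. `attend_with_zero_kv` is a hypothetical helper that prepends an all-zero key and value; it assumes scalar values for brevity.

```python
import math

def softmax(logits):
    # numerically stable softmax: subtract the row max before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softclamp(logits, value=50.0):
    # smooth, differentiable clamp: near-identity for |x| << value, saturates at +/- value
    return [value * math.tanh(x / value) for x in logits]

def attend_with_zero_kv(logits, values):
    """Hypothetical helper: one query row over scalar values with an all-zero
    key/value pair prepended. The zero key's logit is q . 0 = 0, so the weights
    on the real keys become exp(x_i) / (1 + sum_j exp(x_j))."""
    weights = softmax([0.0] + logits)
    out = sum(w * v for w, v in zip(weights, [0.0] + values))
    return out, weights[1:]  # weights on the real keys may now sum to less than 1

clamped = softclamp([1.0, 1000.0, -1000.0])

# every real key scores poorly: plain softmax must still average them ...
baseline = sum(w * v for w, v in zip(softmax([-10.0, -10.0]), [5.0, 5.0]))
# ... while the zero key/value lets the query attend to (almost) nothing
out, real_weights = attend_with_zero_kv([-10.0, -10.0], [5.0, 5.0])
```

The zero key turns the softmax into the softmax_1 from Evan Miller's post. A query whose keys all score poorly can therefore output roughly zero instead of being forced to average them.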

Reasoning

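Transformer attention computes `softmax(QK^T / sqrt(d))`. Float16's largest finite value is 65504. Once activations grow large, the logits `QK^T / sqrt(d)` can overflow to inf in float16, and a single inf logit turns the whole softmax row into NaN. QK normalization bounds each query-key dot product to the cosine-similarity range [-1, 1] before any scale factor is applied. The logits therefore no longer grow with activation magnitude, which removes this failure mode. The forced float32 softmax when QK norm is disabled is a fallback safety measure for the unbounded case.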
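
A minimal pure-Python sketch of the bound for one query-key pair follows. When `qk_norm` is on, the hypothetical `attention_logit` helper multiplies the cosine similarity by a fixed `scale` of 10. That value is an assumption for illustration; the actual scale handling in x-transformers may differ.

```python
import math

FP16_MAX = 65504.0  # largest finite float16 value

def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attention_logit(q, k, qk_norm=False, scale=10.0):
    """One pre-softmax logit. With qk_norm, q and k are L2-normalized, so their dot
    product is a cosine similarity in [-1, 1], then multiplied by `scale` (assumed).
    Without it, the standard 1/sqrt(d) scaling is used and nothing bounds the result."""
    if qk_norm:
        return dot(l2_normalize(q), l2_normalize(k)) * scale
    return dot(q, k) / math.sqrt(len(q))

d = 64
q = [300.0] * d  # hypothetical large activations
k = [300.0] * d

raw = attention_logit(q, k)                   # 300 * 300 * 64 / 8 = 720000 > FP16_MAX
normed = attention_logit(q, k, qk_norm=True)  # cosine 1.0 times scale 10 = 10
```

The unnormalized logit would not fit in float16. With QK norm, the logit stays bounded by `scale` in magnitude for any `q` and `k`.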

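A fully masked row has no valid key to attend to. If masked logits are filled with -inf, the row computes `softmax([-inf, -inf, ...])`, which evaluates to NaN, and the NaN propagates through the entire backward pass. If they are filled with a large finite negative value instead, the softmax silently degenerates into a uniform average over the masked positions. The row-masking protection detects such rows from the mask before attention is computed and zeros their output afterward.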
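
A minimal pure-Python sketch of one query row shows all three behaviours described above. It uses a hypothetical `attend_row` helper and assumes scalar values. A -inf mask yields NaN, a finite mask value yields a silent uniform average, and the protection zeros the row.

```python
import math
import sys

def softmax(logits):
    # stable softmax as used by frameworks: subtract the row max first
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def attend_row(scores, mask, values, mask_value=float("-inf"), protect=True):
    """One query row over scalar values. mask[j] is True where key j may be attended."""
    row_is_entirely_masked = not any(mask)
    logits = [s if keep else mask_value for s, keep in zip(scores, mask)]
    weights = softmax(logits)
    out = sum(w * v for w, v in zip(weights, values))
    if protect and row_is_entirely_masked:
        out = 0.0  # zero the output of a row that has nothing to attend to
    return out

scores, values = [0.3, -1.2], [1.0, 5.0]
all_masked = [False, False]

# -inf mask: max subtraction computes -inf - (-inf) = nan, so the output is NaN
nan_out = attend_row(scores, all_masked, values, protect=False)
# finite mask: all logits equal, so the row silently averages the masked values
uniform_out = attend_row(scores, all_masked, values,
                         mask_value=-sys.float_info.max, protect=False)
# protected: the fully masked row is zeroed
safe_out = attend_row(scores, all_masked, values)
# a partially masked row is unaffected by the protection
partial_out = attend_row(scores, [True, False], values)
```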

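The unit offset trick initializes the stored LayerNorm gamma to 0 and adds +1 in the forward pass, so the effective scale still starts at 1. Weight decay then pulls the stored gamma toward 0, which corresponds to an effective scale of 1 (identity). In the standard formulation, weight decay would instead shrink the effective scale toward 0.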
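
The following sketch compares where weight decay drives each parameterization. It assumes a scalar gamma and decoupled (AdamW-style) weight decay with a zero gradient. `decay_only`, the learning rate, and the decay coefficient are hypothetical illustration values.

```python
import math

def layer_norm(x, gamma, unit_offset=False, eps=1e-5):
    """Scalar-gamma LayerNorm; with unit_offset the stored gamma is shifted by +1."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    scale = gamma + 1.0 if unit_offset else gamma
    return [(v - mean) / math.sqrt(var + eps) * scale for v in x]

def decay_only(param, lr=1e-3, weight_decay=0.1, steps=1000):
    # decoupled weight decay with a zero gradient: p <- p - lr * wd * p
    for _ in range(steps):
        param -= lr * weight_decay * param
    return param

x = [1.0, 2.0, 3.0, 4.0]
standard_gamma = decay_only(1.0)  # standard init: gamma = 1, decays toward 0
offset_gamma = decay_only(0.0)    # unit offset init: gamma = 0, already at the decay target

standard_out = layer_norm(x, standard_gamma)
offset_out = layer_norm(x, offset_gamma, unit_offset=True)
reference = layer_norm(x, 1.0)    # identity scaling
```

After 1000 decay-only steps, the standard gamma has shrunk to about 0.905 and scales every output down. The offset parameterization still produces exactly the identity-scaled output.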

Code Evidence

Softmax precision control from `attend.py:233-234`:

```python
# without qk_norm the logits are unbounded, so compute the softmax in float32
softmax_fn = partial(F.softmax, dim = -1)
self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
```
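
One way to see why the float32 cast matters is to simulate a deliberately naive half-precision softmax, in which the running denominator is rounded to float16 after every addition. This is an assumption-laden model, not a reproduction of `F.softmax`: PyTorch's kernels may accumulate in higher precision internally. `to_fp16` and `naive_softmax` are hypothetical helpers built on the standard `struct` module's half-precision `e` format.

```python
import math
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def naive_softmax(logits, fp16=False):
    """Max-subtracted softmax; with fp16=True every intermediate, including the
    running denominator, is rounded to half precision."""
    rnd = to_fp16 if fp16 else (lambda v: v)
    m = max(logits)
    exps = [rnd(math.exp(rnd(x - m))) for x in logits]
    total = 0.0
    for e in exps:
        # from 2048 up, float16 values are spaced 2 apart, so 2048 + 1 rounds back to 2048
        total = rnd(total + e)
    return [rnd(e / total) for e in exps], total

n = 4096  # uniform logits over 4096 keys: each weight should be 1/4096
w16, denom16 = naive_softmax([0.0] * n, fp16=True)
w64, denom64 = naive_softmax([0.0] * n)
```

With 4096 uniform logits, the float16 denominator stalls at 2048 and the resulting weights sum to 2 instead of 1. The float64 run gives the expected 4096 and 1.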

Masked row NaN protection from `attend.py:404-410, 462-465`:

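```python
# protect against an entire row being masked out
# (mask is boolean, True = may attend; a row with no True entry has nothing to attend to)
row_is_entirely_masked = None
if exists(mask):
    row_is_entirely_masked = ~mask.any(dim = -1)

# ... later in forward:
# zero the output of those rows so no NaN or meaningless average reaches the loss
if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
    out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
```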

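Attention bias masking with a halved mask value, from `attend.py:420-427`:

```python
# most negative finite value representable in the query dtype
mask_value = -torch.finfo(q.dtype).max
if exists(mask):
    # masked positions get half of it (presumably headroom before -inf when other terms are added)
    attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
```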
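
The source does not say why the mask value is halved. One plausible reading is headroom, sketched below with Python floats standing in for `torch.finfo(...).max`. Adding another large negative term to a logit already at the most negative finite value overflows to -inf, which can then produce NaN in the softmax. Two halved values, by contrast, still sum to a finite number.

```python
import math
import sys

mask_value = -sys.float_info.max  # Python-float analogue of -torch.finfo(dtype).max
half_mask = mask_value // 2       # the halved value used for the attention bias

# two large negative terms landing on the same logit overflow at full magnitude ...
full_sum = mask_value + mask_value
# ... but stay finite when each is only half of the representable range
half_sum = half_mask + half_mask

print(math.isinf(full_sum), math.isfinite(half_sum))
```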

Gate initialization for stable pass-through from `x_transformers.py:1605-1606, 1613-1614`:

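```python
# zero weights make the gate ignore its input; bias 10 gives sigmoid(10) ≈ 0.99995 (gate almost fully open)
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 10)
```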
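
Here is a minimal sketch of what this initialization does. It assumes the gate has the form `sigmoid(W x + b)` and is applied multiplicatively. `gated` is a hypothetical scalar helper, not the x-transformers module.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gated(v, x, weight, bias):
    """Hypothetical scalar gate: scale the value v by sigmoid(w . x + b)."""
    gate = sigmoid(sum(w * xi for w, xi in zip(weight, x)) + bias)
    return v * gate, gate

weight = [0.0, 0.0, 0.0, 0.0]  # mirrors nn.init.constant_(weight, 0)
bias = 10.0                    # mirrors nn.init.constant_(bias, 10)

inputs = [[1.0, -2.0, 3.0, 0.5], [100.0, 100.0, -100.0, 0.0]]
results = [gated(2.0, x, weight, bias) for x in inputs]
# zero weights make the gate ignore x: every gate equals sigmoid(10) ≈ 0.99995,
# so each gated value stays within 0.005% of the ungated value 2.0
```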

QK clipping for Muon training from `x_transformers.py:1792-1794`:

```python
tau = 100 # this hyperparameter controls how large the attention logits can be
""" proposed by the Moonshot AI team as a solution for Muon training instability """
```

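The excerpt only records tau and the attribution to Moonshot AI. In Moonshot AI's Kimi K2 report, QK-Clip is described as follows: for each attention head whose maximum pre-softmax logit S_max exceeds tau, the query and key projection weights are each rescaled by sqrt(tau / S_max), bringing that head's maximum logit back to tau. The pure-Python sketch below shows this idea for a single head. `qk_clip_` and the toy matrices are hypothetical, and the real method is applied per head after optimizer updates.

```python
import math

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def max_attention_logit(W_q, W_k, xs):
    """Largest pre-softmax logit q_i . k_j / sqrt(d) over all token pairs in xs."""
    qs = [matvec(W_q, x) for x in xs]
    ks = [matvec(W_k, x) for x in xs]
    d = len(qs[0])
    return max(sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for q in qs for k in ks)

def qk_clip_(W_q, W_k, xs, tau=100.0):
    """Hypothetical single-head QK-Clip: if the max logit exceeds tau, scale W_q and
    W_k in place by sqrt(tau / max_logit) each. Logits are bilinear in (W_q, W_k),
    so the max logit becomes tau. Returns the max logit measured before clipping."""
    s_max = max_attention_logit(W_q, W_k, xs)
    if s_max > tau:
        factor = math.sqrt(tau / s_max)
        for W in (W_q, W_k):
            for row in W:
                for j in range(len(row)):
                    row[j] *= factor
    return s_max

W_q = [[10.0, 0.0], [0.0, 10.0]]
W_k = [[10.0, 0.0], [0.0, 10.0]]
xs = [[3.0, 0.0], [0.0, 3.0]]

before = qk_clip_(W_q, W_k, xs)            # 900 / sqrt(2), roughly 636.4
after = max_attention_logit(W_q, W_k, xs)  # clipped to tau = 100 (up to rounding)
```

Splitting the correction evenly between W_q and W_k, rather than scaling only one of them, follows the report's description.
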
Unit offset norm initialization from `x_transformers.py:861`:

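```python
# unit_offset=True: gamma starts at 0 (the forward pass adds 1); otherwise gamma starts at 1
nn.init.constant_(self.gamma, 1. - float(unit_offset))
```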

Z-loss documentation from `x_transformers.py:212-214`:

```python
# the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
# in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
# also used in PaLM as one of the measures
```
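
Here is a minimal sketch of the z-loss on one set of attention-logit rows. It computes the squared log-partition function `logsumexp(logits)`, averages it over rows, and scales it by a coefficient. The `weight=1e-4` default mirrors the coefficient PaLM reports for its output-logit z-loss and is only an assumption here. How x-transformers aggregates across heads and layers is not shown.

```python
import math

def logsumexp(logits):
    # stable log(sum(exp(x))): factor out the max
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))

def z_loss(logit_rows, weight=1e-4):
    """Mean over rows of the squared log-partition value, scaled by an assumed weight."""
    return weight * sum(logsumexp(row) ** 2 for row in logit_rows) / len(logit_rows)

centered = [[0.0, 0.0], [1.0, -1.0]]
shifted = [[x + 30.0 for x in row] for row in centered]  # same softmax, larger log Z
```

Shifting every logit in a row by a constant leaves the softmax unchanged but raises the loss. The penalty therefore keeps the log-partition near zero and stops logits from drifting upward without bound.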
