Heuristic:Axolotl ai cloud Axolotl Attention Mechanism Selection
| Knowledge Sources | |
|---|---|
| Domains | Optimization, LLMs, GPU_Computing |
| Last Updated | 2026-02-06 22:33 GMT |
Overview
Decision framework for selecting the correct attention mechanism and cross-entropy optimization in Axolotl, with mutual exclusivity rules and feature compatibility constraints.
Description
Axolotl supports five attention mechanisms (Flash Attention, SDP Attention, Flex Attention, Xformers, S2 Attention) and four cross-entropy optimizations (cut_cross_entropy, chunked_cross_entropy, liger_cross_entropy, liger_fused_linear_cross_entropy). Each category is mutually exclusive: at most one option per category can be active at a time. Certain features also have hard dependencies on specific attention mechanisms, notably batch flattening and sample packing.
Usage
Apply this decision framework when choosing attention and cross-entropy settings in the training configuration. Key decision point: if you use `sample_packing`, you MUST choose an optimized attention mechanism (Flash, SDP, Flex, or Xformers).
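As a rough sketch of this decision in code, the hypothetical helper below maps hardware and install facts to a single attention key. `choose_attention` and its parameters are illustrative assumptions, not part of Axolotl. The returned dict mirrors the config keys used in this page, and every branch picks a mechanism that qualifies for sample packing.

```python
from importlib.util import find_spec


def choose_attention(
    nvidia_gpu: bool,
    flash_attn_installed: bool,
    need_custom_pattern: bool = False,
) -> dict:
    """Hypothetical helper: return exactly one attention key set to True.

    Returning a single key means the result can never violate the
    "only one attention mechanism" rule.
    """
    if need_custom_pattern:
        # Flex Attention is the option aimed at custom attention patterns.
        return {"flex_attention": True}
    if nvidia_gpu and flash_attn_installed:
        # Default recommendation when its requirements are met.
        return {"flash_attention": True}
    # SDP ships with PyTorch, so it needs no extra install.
    return {"sdp_attention": True}


if __name__ == "__main__":
    cfg = {"sample_packing": True, "micro_batch_size": 4}
    # find_spec returns None when the flash-attn package (module `flash_attn`) is absent.
    has_flash = find_spec("flash_attn") is not None
    # Assumption: set nvidia_gpu from your actual hardware.
    cfg.update(choose_attention(nvidia_gpu=True, flash_attn_installed=has_flash))
    print(cfg)
```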
The Insight (Rule of Thumb)
- Rule 1 - Mutual Exclusivity (Attention): Only ONE of `flash_attention`, `sdp_attention`, `flex_attention`, `xformers_attention`, `s2_attention` can be enabled. Setting more than one raises a ValueError.
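- Rule 2 - Mutual Exclusivity (Cross-Entropy): Only ONE of `cut_cross_entropy`, `chunked_cross_entropy`, `liger_cross_entropy`, `liger_fused_linear_cross_entropy` can be enabled. Setting more than one raises a ValueError.
- Rule 3 - Sample Packing Requires Attention: `sample_packing: true` requires one of Flash, SDP, Flex, or Xformers attention. S2 Attention is not on this list: it does not handle cross-sample decontamination (keeping packed examples from attending to each other).
- Rule 4 - Batch Flattening Requires Flash: `batch_flattening` requires `flash_attention: true` specifically (no other attention mechanism qualifies) and cannot be combined with `sample_packing`. Setting `batch_flattening: auto` skips both checks.
- Rule 5 - Batch Flattening Size: `batch_flattening` has no effect when `micro_batch_size == 1`; Axolotl logs a warning in that case unless `batch_flattening` is `auto`.
- Rule 6 - Flash Attention Patching: For Llama models, Axolotl's Flash Attention monkeypatch replaces the forward pass of `LlamaAttention`. The patched forward does not support `output_attentions`: it emits a warning and returns `None` instead of the attention weights.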
- Action - Default Recommendation: Use `flash_attention: true` when available (requires the `flash-attn` pip package and an NVIDIA GPU). It provides the best combination of speed, memory efficiency, and feature compatibility.
- Trade-off: Flash Attention is the fastest but requires an NVIDIA GPU and a separate pip install. SDP is built into PyTorch (no extra install) but slightly slower. Flex Attention is newer and supports custom attention patterns.
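The rules above can be restated as a single plain-Python check. This is a hypothetical sketch (`check_config` is not an Axolotl function) that mirrors the validators quoted under Code Evidence on a plain config dict. It is stricter than those excerpts in one place: Rule 3 is reported as an error here, although the excerpts do not show how Axolotl itself enforces it.

```python
ATTENTION_FIELDS = (
    "flash_attention",
    "sdp_attention",
    "flex_attention",
    "xformers_attention",
    "s2_attention",
)
CROSS_ENTROPY_FIELDS = (
    "cut_cross_entropy",
    "chunked_cross_entropy",
    "liger_cross_entropy",
    "liger_fused_linear_cross_entropy",
)
# Rule 3: mechanisms that keep packed samples apart (S2 is deliberately excluded).
PACKING_SAFE = {"flash_attention", "sdp_attention", "flex_attention", "xformers_attention"}


def check_config(cfg: dict) -> tuple[list[str], list[str]]:
    """Hypothetical restatement of Rules 1-5. Returns (errors, warnings)."""
    errors: list[str] = []
    warns: list[str] = []

    # Rule 1: any truthy value counts as enabled.
    enabled_attn = [f for f in ATTENTION_FIELDS if cfg.get(f)]
    if len(enabled_attn) > 1:
        errors.append(f"multiple attention mechanisms: {enabled_attn}")

    # Rule 2: same idea for cross-entropy optimizations.
    enabled_ce = [f for f in CROSS_ENTROPY_FIELDS if cfg.get(f)]
    if len(enabled_ce) > 1:
        errors.append(f"multiple cross-entropy options: {enabled_ce}")

    # Rule 3: packing needs a packing-safe attention mechanism.
    if cfg.get("sample_packing") and not PACKING_SAFE.intersection(enabled_attn):
        errors.append("sample_packing needs flash, sdp, flex, or xformers attention")

    # Rules 4 and 5: batch_flattening="auto" skips these checks.
    bf = cfg.get("batch_flattening")
    if bf and bf != "auto":
        if not cfg.get("flash_attention"):
            errors.append("batch_flattening requires flash_attention")
        if cfg.get("sample_packing"):
            errors.append("batch_flattening is not compatible with sample_packing")
        if cfg.get("micro_batch_size") == 1:
            warns.append("batch_flattening has no effect with micro_batch_size == 1")

    return errors, warns
```

For example, `check_config({"sample_packing": True, "s2_attention": True})` reports the Rule 3 error, while `batch_flattening: "auto"` combined with `sample_packing` and Flash passes cleanly, matching the `auto` escape hatch in the batch-flattening excerpt.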
Reasoning
Each attention mechanism computes self-attention with a different implementation, so they cannot be composed. Flash Attention uses Tri Dao's CUDA kernels, which fuse the softmax with the surrounding matrix multiplications and never materialize the full attention matrix. SDP uses PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which dispatches to the best available backend (such as flash, memory-efficient, or plain math). Xformers uses Meta's memory-efficient attention implementation. Flex Attention uses PyTorch's `torch.nn.attention.flex_attention` API to express custom attention patterns. S2 Attention is the shifted sparse attention introduced by LongLoRA for extending context length.
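The "no full attention matrix" idea behind Flash Attention rests on the online-softmax trick: keys are processed block by block while a running maximum and running sum are maintained, and earlier partial results are rescaled whenever the maximum grows. The pure-Python sketch below shows this for a single query. It illustrates the algorithm only; it is not Tri Dao's kernel, which applies the same rescaling per tile in fast on-chip GPU memory.

```python
import math


def naive_attention(q, keys, values):
    """Reference: materialize every score, softmax them, then take the weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def online_attention(q, keys, values, block=2):
    """Same result, but only one block of scores exists at any time."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim  # un-normalized weighted sum of values
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        # (exp(-inf) == 0.0 on the first block, so nothing is carried over).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [x * correction for x in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [x + w * vj for x, vj in zip(acc, v)]
        running_max = new_max
    # Normalize once at the end.
    return [x / running_sum for x in acc]
```

Any block size gives the same output as the naive version up to floating-point rounding; the block size only changes how much of the score row is held in memory at once.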
The sample packing decontamination requirement exists because a packed sequence concatenates multiple training examples into one row. Without proper attention masking (which, in Axolotl, only the optimized attention mechanisms provide), tokens from one example would attend to tokens from another, corrupting the training signal.
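What "proper attention masking" means for a packed row can be sketched as a block-diagonal causal mask: a token may attend only to earlier tokens of its own example. The helpers below are hypothetical illustrations. Variable-length kernels such as flash-attn's `flash_attn_varlen_func` take sample boundaries in the cumulative `cu_seqlens` form computed here instead of a dense mask; how Axolotl derives those boundaries is not shown in this page.

```python
from itertools import accumulate


def cu_seqlens(lengths):
    """Cumulative boundaries of packed samples, e.g. [3, 2] -> [0, 3, 5]."""
    return [0, *accumulate(lengths)]


def packed_causal_mask(lengths):
    """mask[i][j] is True when token i may attend to token j.

    Two conditions: causal (j <= i) and same packed sample (block-diagonal).
    """
    total = cu_seqlens(lengths)[-1]
    # sample_of[t] is the index of the example that token t belongs to.
    sample_of = []
    for idx, n in enumerate(lengths):
        sample_of.extend([idx] * n)
    return [
        [j <= i and sample_of[i] == sample_of[j] for j in range(total)]
        for i in range(total)
    ]


if __name__ == "__main__":
    # Two examples of lengths 3 and 2 packed into one row of 5 tokens.
    for row in packed_causal_mask([3, 2]):
        print("".join("x" if allowed else "." for allowed in row))
```

Dropping the `sample_of` condition gives a plain causal mask, under which the first token of the second example would attend to all three tokens of the first: exactly the contamination described above.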
Code Evidence
Attention mutual exclusivity from `src/axolotl/utils/schemas/validation.py:163-177`:
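```python
# Excerpt: a method on Axolotl's pydantic config model (uses pydantic's `model_validator`).
@model_validator(mode="before")
@classmethod
def check_attention_fields(cls, data):
    fields = (
        "xformers_attention",
        "sdp_attention",
        "s2_attention",
        "flash_attention",
        "flex_attention",
    )
    # Any truthy value counts as "enabled".
    non_empty_count = sum(1 for field in fields if data.get(field))
    if non_empty_count > 1:
        raise ValueError(f"Only one of {', '.join(fields)} must be set")
    # A mode="before" validator must hand the raw data back.
    return data
```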
Cross-entropy mutual exclusivity from `src/axolotl/utils/schemas/validation.py:827-855`:
```python
# Excerpt: a method on Axolotl's pydantic config model.
@model_validator(mode="before")
@classmethod
def check_cross_entropy_conflicts(cls, data):
    ce_options = {
        "cut_cross_entropy": data.get("cut_cross_entropy"),
        "chunked_cross_entropy": data.get("chunked_cross_entropy"),
        "liger_cross_entropy": data.get("liger_cross_entropy"),
        "liger_fused_linear_cross_entropy": data.get("liger_fused_linear_cross_entropy"),
    }
    # Collect every option set to a truthy value.
    enabled_options = [k for k, v in ce_options.items() if v]
    if len(enabled_options) > 1:
        raise ValueError(
            f"Only one cross entropy optimization can be enabled at a time. "
            f"Found {len(enabled_options)} enabled: {', '.join(enabled_options)}."
        )
    return data
```
Batch flattening requirements from `src/axolotl/utils/schemas/validation.py:794-814`:
```python
# Excerpt from inside a model validator: `data` is the raw config dict, `LOG` the module logger.
if data.get("batch_flattening"):
    # batch_flattening: auto bypasses all three checks below.
    batch_flattening_auto = data.get("batch_flattening") == "auto"
    if not data.get("flash_attention") and not batch_flattening_auto:
        raise ValueError("batch_flattening requires flash attention")
    if data.get("sample_packing") and not batch_flattening_auto:
        raise ValueError("batch_flattening not compatible with sample_packing")
    if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
        LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
```
Output attentions limitation from `src/axolotl/monkeypatch/llama_attn_hijack_flash.py:171-175`:
```python
import warnings

# Excerpt from the patched `LlamaAttention` forward pass.
if output_attentions:
    warnings.warn(
        "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.",
        stacklevel=2,
    )
```
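To see why Rule 6 applies to every model instance at once, here is a toy sketch of the monkeypatching pattern. `ToyAttention` and `patched_forward` are hypothetical stand-ins, not Axolotl or Hugging Face classes. Assigning a new function to the class attribute replaces `forward` for all instances, and the patched version mirrors the excerpt above: it warns and returns `None` for the attention weights, since a fused kernel does not materialize them.

```python
import warnings


class ToyAttention:
    """Hypothetical stand-in for an attention module."""

    def forward(self, hidden, output_attentions=False):
        attn_weights = [[1.0]]  # placeholder for a materialized attention matrix
        return hidden, (attn_weights if output_attentions else None)


def patched_forward(self, hidden, output_attentions=False):
    # A fused kernel never builds the attention matrix, so there is nothing to return.
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `ToyAttention`, returning `None` instead.",
            stacklevel=2,
        )
    return hidden, None


# Monkeypatch: replace the method on the class, affecting every instance.
ToyAttention.forward = patched_forward
```

After the patch, `ToyAttention().forward(h, output_attentions=True)` returns `(h, None)` and emits one warning, so code that inspects attention weights needs to handle `None` whenever Flash Attention is enabled.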