
Heuristic: Axolotl Attention Mechanism Selection (axolotl-ai-cloud)

From Leeroopedia

Knowledge Sources
Domains: Optimization, LLMs, GPU_Computing
Last Updated: 2026-02-06 22:33 GMT

Overview

Decision framework for selecting the correct attention mechanism and cross-entropy optimization in Axolotl, with mutual exclusivity rules and feature compatibility constraints.

Description

Axolotl supports five attention mechanisms (Flash Attention, SDP Attention, Flex Attention, Xformers, S2 Attention) and four cross-entropy optimizations (cut_cross_entropy, chunked_cross_entropy, liger_cross_entropy, liger_fused_linear_cross_entropy). Both categories enforce mutual exclusivity: only one can be active at a time. Additionally, certain features like batch flattening and sample packing have hard dependencies on specific attention mechanisms.

Usage

Apply this decision framework when choosing attention and cross-entropy settings in the training configuration. Key decision point: if using `sample_packing`, you MUST choose an optimized attention mechanism.
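As an illustration (a hypothetical config fragment, not taken from the Axolotl docs), a configuration that satisfies this rule pairs `sample_packing` with exactly one optimized attention mechanism and at most one cross-entropy optimization. Shown here as the Python dict the validators receive; the YAML keys are identical:

```python
# Hypothetical training-config fragment. Axolotl's pydantic validators
# receive the parsed YAML as a dict shaped like this.
config = {
    "sample_packing": True,       # requires an optimized attention mechanism
    "flash_attention": True,      # exactly one attention flag enabled
    "liger_cross_entropy": True,  # at most one cross-entropy optimization
}

# Enabling a second attention flag (e.g. "sdp_attention": True) would make
# Axolotl's validator raise a ValueError at config-load time.
```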

The Insight (Rule of Thumb)

  • Rule 1 - Mutual Exclusivity (Attention): Only ONE of `flash_attention`, `sdp_attention`, `flex_attention`, `xformers_attention`, `s2_attention` can be enabled. Setting more than one raises a ValueError.
  • Rule 2 - Mutual Exclusivity (Cross-Entropy): Only ONE of `cut_cross_entropy`, `chunked_cross_entropy`, `liger_cross_entropy`, `liger_fused_linear_cross_entropy` can be enabled.
  • Rule 3 - Sample Packing Requires Attention: `sample_packing: true` requires one of Flash, SDP, Flex, or Xformers attention. Note: S2 Attention does NOT qualify for sample packing.
  • Rule 4 - Batch Flattening Requires Flash: `batch_flattening` requires `flash_attention: true` specifically (not other attention mechanisms).
  • Rule 5 - Batch Flattening Size: `batch_flattening` has no effect when `micro_batch_size == 1`.
  • Rule 6 - Flash Attention Patching: Flash Attention monkeypatches the model's attention forward pass. This disables `output_attentions` (returns None instead).
  • Action - Default Recommendation: Use `flash_attention: true` when available (requires `flash-attn` pip package and NVIDIA GPU). It provides the best combination of speed, memory efficiency, and feature compatibility.
  • Trade-off: Flash Attention is fastest but requires NVIDIA GPU and separate pip install. SDP is built into PyTorch (no extra install) but slightly slower. Flex Attention is newer and supports custom attention patterns.
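The default recommendation and its trade-offs can be sketched as a small selection helper. The function name and inputs below are our own illustration, not part of Axolotl's API:

```python
def pick_attention(has_nvidia_gpu: bool, flash_attn_installed: bool,
                   needs_custom_patterns: bool = False) -> str:
    """Sketch of the default recommendation: prefer Flash Attention when
    available, fall back to PyTorch's built-in SDPA, and use Flex
    Attention when custom attention patterns are needed."""
    if needs_custom_patterns:
        return "flex_attention"       # supports custom attention patterns
    if has_nvidia_gpu and flash_attn_installed:
        return "flash_attention"      # fastest; needs the flash-attn pip package
    return "sdp_attention"            # built into PyTorch, no extra install
```

For example, `pick_attention(True, True)` yields `"flash_attention"`, while a machine without the `flash-attn` package falls back to `"sdp_attention"`.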

Reasoning

Each attention mechanism uses a different implementation for computing self-attention, and they cannot be composed. Flash Attention uses Tri Dao's CUDA kernel that fuses the softmax computation with the matmul, avoiding materializing the full attention matrix. SDP uses PyTorch's built-in `scaled_dot_product_attention` which auto-selects the best backend. Xformers uses Meta's memory-efficient attention implementation. Flex Attention uses PyTorch's `torch.nn.attention.flex_attention` for custom attention patterns.

The sample packing requirement exists because packed sequences contain multiple training examples concatenated together. Without proper attention masking (which only the optimized attention mechanisms provide), tokens from one example would attend to tokens from another, corrupting the training signal.
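To see why, consider a packed sequence built from two examples of lengths 3 and 2. A plain causal mask lets tokens of the second example attend to the first; the block-diagonal causal mask that optimized attention kernels apply keeps attention within each example. A minimal stdlib-only illustration:

```python
def causal_mask(n):
    # True where query token i may attend to key token j (standard causal mask)
    return [[j <= i for j in range(n)] for i in range(n)]

def packed_mask(lengths):
    # Block-diagonal causal mask: attention stays within each packed example
    example_id = []
    for eid, length in enumerate(lengths):
        example_id += [eid] * length
    n = len(example_id)
    return [[j <= i and example_id[i] == example_id[j] for j in range(n)]
            for i in range(n)]

naive = causal_mask(5)        # one 5-token sequence, no example boundaries
packed = packed_mask([3, 2])  # tokens 0-2 are example 1, tokens 3-4 example 2

# Token 3 (first token of example 2) attending to token 0 (example 1):
# the naive causal mask leaks attention across examples; the packed mask doesn't.
```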

Code Evidence

Attention mutual exclusivity from `src/axolotl/utils/schemas/validation.py:163-177`:

@model_validator(mode="before")
@classmethod
def check_attention_fields(cls, data):
    fields = (
        "xformers_attention",
        "sdp_attention",
        "s2_attention",
        "flash_attention",
        "flex_attention",
    )
    non_empty_count = sum(1 for field in fields if data.get(field))
    if non_empty_count > 1:
        raise ValueError(f"Only one of {', '.join(fields)} must be set")
    return data
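Note that the check counts truthy values, so flags explicitly set to `false` do not conflict. Stripped of the pydantic plumbing, the same counting logic can be exercised directly (a standalone sketch, not Axolotl's API):

```python
ATTENTION_FIELDS = ("xformers_attention", "sdp_attention", "s2_attention",
                    "flash_attention", "flex_attention")

def check_attention_fields(data: dict) -> dict:
    # Same counting logic as the validator quoted above, without pydantic.
    # Only truthy values count, so explicit False/None entries never conflict.
    non_empty_count = sum(1 for field in ATTENTION_FIELDS if data.get(field))
    if non_empty_count > 1:
        raise ValueError(f"Only one of {', '.join(ATTENTION_FIELDS)} must be set")
    return data

check_attention_fields({"flash_attention": True, "sdp_attention": False})  # ok
```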

Cross-entropy mutual exclusivity from `src/axolotl/utils/schemas/validation.py:827-855`:

@model_validator(mode="before")
@classmethod
def check_cross_entropy_conflicts(cls, data):
    ce_options = {
        "cut_cross_entropy": data.get("cut_cross_entropy"),
        "chunked_cross_entropy": data.get("chunked_cross_entropy"),
        "liger_cross_entropy": data.get("liger_cross_entropy"),
        "liger_fused_linear_cross_entropy": data.get("liger_fused_linear_cross_entropy"),
    }
    enabled_options = [k for k, v in ce_options.items() if v]
    if len(enabled_options) > 1:
        raise ValueError(
            f"Only one cross entropy optimization can be enabled at a time. "
            f"Found {len(enabled_options)} enabled: {', '.join(enabled_options)}."
        )
    return data

Batch flattening requirements from `src/axolotl/utils/schemas/validation.py:794-814`:

if data.get("batch_flattening"):
    batch_flattening_auto = data.get("batch_flattening") == "auto"
    if not data.get("flash_attention") and not batch_flattening_auto:
        raise ValueError("batch_flattening requires flash attention")
    if data.get("sample_packing") and not batch_flattening_auto:
        raise ValueError("batch_flattening not compatible with sample_packing")
    if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
        LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
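The `"auto"` value is what makes the snippet above subtle: every hard error is skipped when `batch_flattening` is `"auto"`, so an incompatible configuration silently degrades to a no-op instead of failing. A standalone restatement of that gating (our own helper, not Axolotl code):

```python
def batch_flattening_effective(data: dict) -> bool:
    """Return whether batch flattening will actually take effect,
    mirroring the validator above: explicit True raises on conflicts,
    while "auto" silently falls back to disabled."""
    value = data.get("batch_flattening")
    if not value:
        return False
    auto = value == "auto"
    if not data.get("flash_attention"):
        if not auto:
            raise ValueError("batch_flattening requires flash attention")
        return False
    if data.get("sample_packing"):
        if not auto:
            raise ValueError("batch_flattening not compatible with sample_packing")
        return False
    if data.get("micro_batch_size") == 1:
        return False  # Rule 5: no effect when micro_batch_size == 1
    return True
```

For example, `{"batch_flattening": "auto"}` on a config without flash attention simply returns False, whereas `{"batch_flattening": True}` raises.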

Output attentions limitation from `src/axolotl/monkeypatch/llama_attn_hijack_flash.py:171-175`:

if output_attentions:
    warnings.warn(
        "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.",
        stacklevel=2,
    )
