Principle:Hiyouga LLaMA Factory Shifted Sparse Attention
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Long-Context Modeling |
| Last Updated | 2026-02-06 19:00 GMT |
Overview
Shifted Sparse Attention (S2-Attn) is a training-time approximation to full attention that splits the sequence into groups and applies a cyclic shift to half the attention heads, enabling efficient fine-tuning on long contexts while preserving information flow across the entire sequence.
Description
Fine-tuning large language models on long contexts is computationally expensive because self-attention scales quadratically with sequence length ($O(N^2)$). LongLoRA proposes Shifted Sparse Attention as a drop-in replacement during training. It reduces this cost by restricting attention to local groups, while shifted patterns maintain global information exchange.
The mechanism works as follows:
- The sequence of length $N$ is divided into $N/G$ groups, each of size $G$.
- The first half of the attention heads attend normally within each group.
- The second half of the attention heads have their query, key, and value sequences cyclically shifted by half a group size ($G/2$ tokens) before grouping.
- After attention computation, the outputs of the shifted heads are rolled back to their original positions.
This shifting pattern ensures that tokens at group boundaries can attend to tokens in adjacent groups through the shifted heads, creating an overlapping attention pattern that approximates full attention while operating on smaller local windows.
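The grouping can be illustrated without any tensor library. The hypothetical helper s2_windows below is not taken from LLaMA-Factory. It lists which original token positions fall into each attention window for the unshifted and shifted head halves. It assumes the sequence length is divisible by an even group size.

```python
def s2_windows(seq_len: int, group_size: int) -> tuple[list[list[int]], list[list[int]]]:
    """Return (unshifted, shifted) attention windows as lists of original token positions.

    Hypothetical illustration: unshifted heads see contiguous groups, while shifted
    heads see groups offset by half a group that wrap cyclically at the sequence end.
    """
    if seq_len % group_size != 0 or group_size % 2 != 0:
        raise ValueError("seq_len must be divisible by an even group_size")
    half = group_size // 2
    unshifted = [list(range(start, start + group_size))
                 for start in range(0, seq_len, group_size)]
    # A cyclic shift by half a group: window k starts at k * G + G/2 (mod N)
    shifted = [[(start + half + i) % seq_len for i in range(group_size)]
               for start in range(0, seq_len, group_size)]
    return unshifted, shifted


plain, shifted = s2_windows(seq_len=16, group_size=4)
print(plain)    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(shifted)  # [[2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13], [14, 15, 0, 1]]
```

Each shifted window straddles two consecutive unshifted groups; for example, [2, 3, 4, 5] spans groups 0 and 1. The last shifted window wraps around to the start of the sequence because the shift is cyclic.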
In LLaMA-Factory, S2-Attn is implemented by monkey-patching the attention forward methods of LLaMA-family models. It supports all three attention backends: standard attention, FlashAttention-2, and SDPA (Scaled Dot-Product Attention). The group size ratio is fixed at 0.25, so each group is 1/4 of the sequence length. S2-Attn is only active during training; inference uses standard full attention.
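Because the ratio is fixed rather than the group size, the group size follows from each batch's actual sequence length. The hypothetical helper below makes that arithmetic concrete. Its name and its error check are assumptions, not LLaMA-Factory's code. The divisibility requirement simply reflects that the sequence must reshape into equal-sized groups.

```python
def s2_group_params(q_len: int, group_size_ratio: float = 0.25) -> tuple[int, int]:
    """Hypothetical helper: derive (group size, number of groups) from the ratio."""
    groupsz = int(q_len * group_size_ratio)
    if groupsz == 0 or q_len % groupsz != 0:
        # Reshaping (bsz, q_len, ...) into (bsz * num_groups, groupsz, ...) needs equal groups
        raise ValueError(f"q_len {q_len} is not divisible by group size {groupsz}")
    return groupsz, q_len // groupsz


print(s2_group_params(4096))   # (1024, 4)
print(s2_group_params(32768))  # (8192, 4)
```

For example, a 32K-token batch is split into four 8K-token groups, while a 4K-token batch uses four 1K-token groups.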
Usage
Shifted Sparse Attention should be used when:
- Fine-tuning a model to handle longer contexts than its pretrained window (e.g., 4K to 32K tokens).
- Full-attention fine-tuning is too memory-intensive for the available hardware.
- Combined with LoRA adapters for parameter-efficient long-context fine-tuning (the recommended LongLoRA configuration).
- The model architecture is LLaMA-based (the implementation patches LLaMA attention classes specifically). Supported architectures are listed in SUPPORTED_CLASS_FOR_S2ATTN.
Enable via the shift_attn flag in model arguments.
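As a rough illustration, the training arguments for a LongLoRA-style run might look like the dict below. Only shift_attn comes from the text above. The other keys mirror commonly used LLaMA-Factory training arguments, and their values are placeholders. Both should be checked against the installed LLaMA-Factory version before use.

```python
# Hypothetical LLaMA-Factory training arguments for LongLoRA-style fine-tuning.
# Only `shift_attn` is taken from this page; other keys and all values are illustrative.
longlora_args = {
    "model_name_or_path": "path/to/llama-model",  # placeholder; must be a LLaMA-family model
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",  # LoRA adapters, the recommended LongLoRA pairing
    "shift_attn": True,         # enables S2-Attn (group_size_ratio = 0.25)
    "cutoff_len": 32768,        # target context length, e.g. 32K tokens
}
```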
Theoretical Basis
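Standard self-attention for a sequence of length $N$ with $H$ heads has complexity:
$$O(H \cdot N^2 \cdot d)$$
where $d$ is the head dimension. S2-Attn reduces this by partitioning the sequence into $N/G$ groups of size $G$:
$$O\left(\frac{N}{G} \cdot H \cdot G^2 \cdot d\right) = O(H \cdot N \cdot G \cdot d)$$
With the default group size ratio of 0.25, $G = N/4$, giving:
$$O\left(H \cdot N \cdot \frac{N}{4} \cdot d\right) = O\left(\frac{H \cdot N^2 \cdot d}{4}\right)$$
This is a 4x reduction in attention computation compared to full attention.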
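A quick numeric check of these formulas counts query-key score entries per layer. Multiplying the count by the head dimension $d$ gives the cost order above. The sizes N = 8192 and H = 32 are illustrative and not taken from the source.

```python
from typing import Optional


def attention_score_entries(seq_len: int, num_heads: int, group_size: Optional[int] = None) -> int:
    """Count query-key score entries per layer (cost order = this count times head_dim)."""
    if group_size is None:
        # Full attention: each head scores every query against every key (N x N)
        return num_heads * seq_len * seq_len
    # S2-Attn: N / G independent G x G blocks per head
    num_groups = seq_len // group_size
    return num_groups * num_heads * group_size * group_size


N, H = 8192, 32  # illustrative sizes
full = attention_score_entries(N, H)
s2 = attention_score_entries(N, H, group_size=N // 4)  # group_size_ratio = 0.25
print(full, s2, full // s2)  # 2147483648 536870912 4
```

The 4x factor applies only to the attention-score and weighted-sum terms. The Q/K/V/output projections and MLP layers still process all $N$ tokens, so end-to-end speedups are smaller than 4x.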
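The shifting mechanism ensures information flow across group boundaries. For a sequence with $H$ attention heads, the shift helper looks like this (bsz, num_groups, groupsz, num_heads, and head_dim come from the enclosing attention forward):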
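```python
import torch

def shift(state: torch.Tensor) -> torch.Tensor:
    # (bsz, num_heads, seq_len, head_dim) -> (bsz, seq_len, num_heads, head_dim)
    state = state.transpose(1, 2)
    state = torch.cat(
        (
            state[:, :, : num_heads // 2],  # first half of heads: no shift
            state[:, :, num_heads // 2 :].roll(-groupsz // 2, dims=1),  # second half: roll left by G/2 along seq_len
        ),
        dim=2,
    )
    # Split seq_len into groups and move heads back: (bsz * num_groups, num_heads, groupsz, head_dim)
    return state.reshape(bsz * num_groups, groupsz, num_heads, head_dim).transpose(1, 2)
```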
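To see the layout that shift() produces, the sketch below tracks token positions instead of feature vectors for a single sequence. It uses a hand-written roll helper that mimics torch.roll along one dimension. Both roll and shift_tokens are hypothetical illustrations, not library functions. The result is indexed [group][head][position_in_group], which mirrors the (bsz * num_groups, num_heads, groupsz, head_dim) output without the batch and head_dim axes.

```python
def roll(seq: list, shift: int) -> list:
    """Mimic torch.roll along one dimension: element i moves to (i + shift) % len(seq)."""
    n = len(seq)
    return [seq[(i - shift) % n] for i in range(n)]


def shift_tokens(num_tokens: int, num_heads: int, groupsz: int) -> list[list[list[int]]]:
    """Emulate shift() on token positions; returns [group][head][position_in_group]."""
    positions = list(range(num_tokens))
    # Per head: first half keeps the original order, second half is rolled left by G/2
    per_head = [positions if h < num_heads // 2 else roll(positions, -(groupsz // 2))
                for h in range(num_heads)]
    num_groups = num_tokens // groupsz
    # Cut each head's sequence into groups, then gather heads per group
    return [[per_head[h][g * groupsz:(g + 1) * groupsz] for h in range(num_heads)]
            for g in range(num_groups)]


groups = shift_tokens(num_tokens=8, num_heads=4, groupsz=4)
print(groups[0])  # [[0, 1, 2, 3], [0, 1, 2, 3], [2, 3, 4, 5], [2, 3, 4, 5]]
print(groups[1])  # [[4, 5, 6, 7], [4, 5, 6, 7], [6, 7, 0, 1], [6, 7, 0, 1]]
```

Within the same group index, heads 0 and 1 see a contiguous block, while heads 2 and 3 see a block that crosses the group boundary.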
The first $H/2$ heads process tokens within their local group normally. The second $H/2$ heads have their sequences shifted by $G/2$ positions before grouping, so each group in the shifted view overlaps with two groups from the original view. After attention, the shift is reversed:
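```python
# After attention: (bsz * num_groups, num_heads, groupsz, head_dim)
# -> (bsz, seq_len, num_heads, head_dim) so the roll acts on the full sequence
attn_output = attn_output.transpose(1, 2).reshape(bsz, num_groups * groupsz, num_heads, head_dim)
attn_output = torch.cat(
    (
        attn_output[:, :, : num_heads // 2],  # unshifted heads: unchanged
        attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),  # roll right by G/2 to undo the shift
    ),
    dim=2,
)
```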
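Rolling by +G/2 exactly undoes the earlier roll by -G/2. As a result, each shifted head's output returns to the position of the token that produced it. The minimal list-based check below demonstrates this round trip. Its roll helper mimics torch.roll along one dimension and is written here for illustration only.

```python
def roll(seq: list, shift: int) -> list:
    """Mimic torch.roll along one dimension: element i moves to (i + shift) % len(seq)."""
    n = len(seq)
    return [seq[(i - shift) % n] for i in range(n)]


groupsz = 4
original = [f"tok{i}" for i in range(8)]   # outputs labelled by the token that produced them
shifted = roll(original, -(groupsz // 2))  # layout seen by the shifted heads
restored = roll(shifted, groupsz // 2)     # the shift-back step
print(shifted)               # ['tok2', 'tok3', 'tok4', 'tok5', 'tok6', 'tok7', 'tok0', 'tok1']
print(restored == original)  # True
```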
This creates an overlapping window pattern. Within a single layer, the shifted heads let tokens near a group boundary attend across it. Stacking layers then propagates information further across groups, approximating full-attention connectivity. The attention mask is adjusted accordingly:
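```python
# Standard/SDPA attention: 4D mask (bsz, 1, seq_len, seq_len).
# Keep the first groupsz x groupsz block and tile it once per group.
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

# FlashAttention-2: 2D padding mask (bsz, seq_len); keep the first groupsz columns per group.
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
```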
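The standard/SDPA line can be emulated with plain lists for a batch of one. For readability, the sketch below uses a boolean causal mask (True = may attend) instead of the additive float mask the model actually receives. The helper names are hypothetical.

```python
def causal_mask(n: int) -> list[list[bool]]:
    """Boolean causal mask: True where query i may attend to key j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


def tile_group_mask(full_mask: list[list[bool]], groupsz: int, num_groups: int) -> list[list[list[bool]]]:
    """Emulate mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) for bsz = 1."""
    block = [row[:groupsz] for row in full_mask[:groupsz]]  # top-left G x G block
    return [[row[:] for row in block] for _ in range(num_groups)]  # one copy per group


seq_len, groupsz = 8, 4
tiled = tile_group_mask(causal_mask(seq_len), groupsz, seq_len // groupsz)
print(len(tiled))                        # 2
print(tiled[1] == causal_mask(groupsz))  # True
```

Reusing the top-left block works because, for a purely causal mask, every diagonal G x G block is identical. However, the slice carries no information about padding positions beyond the first group.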
The configuration sets a group_size_ratio attribute on the model config, which is checked at runtime to activate or bypass the shifting logic:
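```python
def configure_longlora(config, model_args, is_trainable):
    if model_args.shift_attn:
        # Read at runtime by the patched attention forwards: G = 0.25 * seq_len
        setattr(config, "group_size_ratio", 0.25)
        # Monkey-patch the LLaMA attention forward methods with the S2-Attn versions
        _apply_llama_patch()
```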
A key property is that S2-Attn is training-only: the self.training flag gates the shifting behavior, so at inference time the model reverts to standard full attention, preserving compatibility with all standard inference pipelines.
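The runtime gate combines the two conditions described above: group_size_ratio must be set on the config, and the module must be in training mode. The sketch below models that check with a hypothetical helper name. A SimpleNamespace stands in for the model config.

```python
from types import SimpleNamespace


def uses_shifted_attention(config, training: bool) -> bool:
    """Hypothetical helper: S2-Attn applies only if group_size_ratio is set AND training."""
    return bool(getattr(config, "group_size_ratio", None)) and training


cfg = SimpleNamespace(group_size_ratio=0.25)  # as set when shift_attn is enabled
print(uses_shifted_attention(cfg, training=True))                # True
print(uses_shifted_attention(cfg, training=False))               # False: eval mode uses full attention
print(uses_shifted_attention(SimpleNamespace(), training=True))  # False: shift_attn not enabled
```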