OpenRLHF Flash Attention Environment
| Knowledge Sources | Details |
|---|---|
| Domains | Infrastructure, Optimization, Deep_Learning |
| Last Updated | 2026-02-07 10:00 GMT |
Overview
Flash Attention 2.8.3 with optional ring-flash-attn and liger-kernel for memory-efficient attention and cross-entropy computation.
Description
This environment provides the Flash Attention library and related optional acceleration packages used by OpenRLHF. Flash Attention 2 is the default attention implementation for all model loading, providing memory-efficient attention computation. The library also provides a Triton-based cross-entropy kernel used for efficient log-probability computation. Ring Flash Attention extends this to support sequence parallelism across GPUs for long-context training. Liger Kernel provides alternative fused operations.
Usage
Use this environment when training with packed samples (where Flash Attention is mandatory), when computing log probabilities for the loss (where the Triton cross-entropy kernel is optional but faster), or when training long-context sequences with ring attention. Flash Attention 2 is the default `--attn_implementation` for all model loading.
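As an illustration, a DPO training launch with packed samples might look like the following. The model and dataset paths are placeholders; only `--packing_samples` and `--attn_implementation` are the options discussed above, and the exact entry-point invocation may differ between OpenRLHF versions:

```shell
# Hypothetical launch; model/dataset arguments are placeholders.
deepspeed --module openrlhf.cli.train_dpo \
  --pretrain your-org/your-base-model \
  --dataset your-org/your-preference-dataset \
  --packing_samples \
  --attn_implementation flash_attention_2
```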
System Requirements
| Category | Requirement | Notes |
|---|---|---|
| GPU | NVIDIA GPU with compute capability >= 7.0 | Ampere (sm_80+) recommended |
| GPU Memory | Model-dependent | Flash Attention reduces peak memory by not materializing attention matrix |
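The compute-capability requirement can be checked before installing. This is a minimal sketch; the `flash_attn_supported` helper is illustrative, not part of OpenRLHF:

```python
def flash_attn_supported(capability):
    """Return True if a (major, minor) CUDA compute capability can run Flash Attention 2.

    Flash Attention 2 requires sm_70 or newer; Ampere (sm_80+) is recommended.
    """
    major, minor = capability
    return (major, minor) >= (7, 0)

# With PyTorch installed, the capability of GPU 0 is available as:
#   torch.cuda.get_device_capability(0)
print(flash_attn_supported((8, 0)))  # Ampere -> True
print(flash_attn_supported((6, 1)))  # Pascal, too old -> False
```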
Dependencies
Python Packages
- `flash-attn` == 2.8.3 (pinned in requirements.txt)
- `ring-flash-attn` (optional; install extra: `pip install openrlhf[ring]`)
- `liger-kernel` (optional; install extra: `pip install openrlhf[liger]`)
Credentials
No additional credentials required.
Quick Install
```bash
# Core Flash Attention
pip install flash-attn==2.8.3

# Ring Attention for long-context training (optional)
pip install ring-flash-attn

# Liger Kernel for fused operations (optional)
pip install liger-kernel
```
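Since the acceleration packages are optional, OpenRLHF guards their imports and falls back when they are missing. The same probe pattern can be sketched with `importlib`; the `have` helper below is illustrative:

```python
import importlib.util

def have(package):
    """Return True if `package` is importable in the current environment."""
    return importlib.util.find_spec(package) is not None

# Report which optional acceleration packages are installed.
for pkg in ("flash_attn", "ring_flash_attn", "liger_kernel"):
    print(pkg, "available" if have(pkg) else "missing")
```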
Code Evidence
Default attention implementation from `openrlhf/models/actor.py:41`:
```python
attn_implementation="flash_attention_2",
```
Flash Attention cross-entropy fallback from `openrlhf/models/utils.py:92-101`:
```python
try:
    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

    output = cross_entropy_loss(logits.reshape(-1, last_dim), labels.reshape(-1))
    log_probs_labels = -output[0].view(*batch_dim)
except ImportError:
    logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = _logsumexp_by_chunk(logits.reshape(-1, last_dim))
    logsumexp_values = logsumexp_values.view(*batch_dim)
    log_probs_labels = logits_labels - logsumexp_values
```
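The fallback path computes each token's log-probability as the gathered logit minus a logsumexp over the vocabulary. In scalar form (a minimal, torch-free sketch of the same math):

```python
import math

def log_prob_of_label(logits, label):
    # log p(label) = logits[label] - logsumexp(logits), computed stably
    # by subtracting the max before exponentiating.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[label] - lse

logits = [2.0, 0.5, -1.0]
# Exponentiating the log-probabilities of all labels recovers a
# distribution that sums to 1.
total = sum(math.exp(log_prob_of_label(logits, i)) for i in range(len(logits)))
print(round(total, 6))  # → 1.0
```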
Warning that packing requires flash attention, from `openrlhf/cli/train_dpo.py:314-316`:
```python
if args.packing_samples and "flash_attention" not in args.attn_implementation:
    print("[Warning] Please use --attn_implementation with flash_attention...")
    args.attn_implementation = "flash_attention_2"
```
Ring attention setup from `openrlhf/utils/deepspeed/deepspeed.py:125-128`:
```python
from ring_flash_attn import substitute_hf_flash_attn

self.ring_head_stride = getattr(self.args, "ring_head_stride", 1)
substitute_hf_flash_attn(self.ring_attn_group, self.ring_head_stride)
```
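Ring attention shards each sequence along the length dimension across the GPUs in the ring group; each rank holds one contiguous slice and exchanges key/value blocks with its neighbors so attention still covers the full sequence. The partitioning can be sketched as follows (the `shard_sequence` helper is illustrative, not OpenRLHF's implementation):

```python
def shard_sequence(seq_len, ring_size):
    """Split a sequence of `seq_len` tokens into contiguous per-rank (start, end) slices."""
    chunk = (seq_len + ring_size - 1) // ring_size  # ceiling division
    return [(r * chunk, min((r + 1) * chunk, seq_len)) for r in range(ring_size)]

print(shard_sequence(10, 4))  # → [(0, 3), (3, 6), (6, 9), (9, 10)]
```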
Liger Kernel integration from `openrlhf/models/actor.py:85-88`:
```python
if use_liger_kernel:
    from liger_kernel.transformers import AutoLigerKernelForCausalLM

    model_class = AutoLigerKernelForCausalLM
```
Deterministic flash attention from `openrlhf/utils/deepspeed/deepspeed.py:83-84`:
```python
# Use deterministic backward in flash attention as, by default, flash attention uses atomic adds
transformers.modeling_flash_attention_utils.deterministic_g = True
```
Common Errors
| Error Message | Cause | Solution |
|---|---|---|
| `ImportError: flash_attn` | Flash Attention not installed | `pip install flash-attn==2.8.3` |
| `[Warning] Please use --attn_implementation with flash_attention` | Packing samples without flash attention | Add `--attn_implementation flash_attention_2` |
| `ImportError: ring_flash_attn` | Ring attention requested but not installed | `pip install ring-flash-attn` |
| `ImportError: liger_kernel` | Liger kernel requested but not installed | `pip install liger-kernel` |
Compatibility Notes
- Fallback: Flash Attention's cross-entropy kernel is optional; code falls back to chunked logsumexp computation when not available.
- Packing Samples: Using `--packing_samples` automatically forces `flash_attention_2` as the attention implementation.
- Ring Attention: Requires both `flash-attn` and `ring-flash-attn` packages. Uses the sequence parallel mesh dimension.
- Determinism: Flash Attention uses atomic adds by default (non-deterministic); `--full_determinism` enables deterministic backward.