Heuristic: Hpcaitech ColossalAI Flash Attention Dtype Restriction
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Memory_Management |
| Last Updated | 2026-02-09 03:00 GMT |
Overview
Flash attention does not support fp32. The default tensor dtype must be set to fp16 or bf16 before calling `booster.boost()`, then restored to fp32 afterward.
Description
Flash attention kernels are implemented only for half-precision floating-point types (float16 and bfloat16). If the default PyTorch dtype is float32 when `booster.boost()` is called, flash attention will be silently disabled, falling back to standard attention with significantly higher memory usage. The ColossalAI training script temporarily sets the default dtype to the chosen mixed-precision type before boosting the model, then restores it to float32 afterward.
Usage
Apply this whenever using flash attention with ColossalAI's Booster. This is relevant for all large model training where flash attention is desired for memory and speed benefits.
The Insight (Rule of Thumb)
- Action: Call `torch.set_default_dtype(torch.float16)` or `torch.set_default_dtype(torch.bfloat16)` before `booster.boost()`.
- Restore: Call `torch.set_default_dtype(torch.float32)` after `booster.boost()` to avoid side effects on subsequent tensor operations.
- Condition: Only needed when flash attention is enabled; if using standard attention, fp32 is fine.
- Warning: If flash attention is silently disabled, the only symptom is higher-than-expected VRAM usage or OOM errors during the forward pass.
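The set-then-restore pattern above can be made exception-safe with a small context manager, so the fp32 default is restored even if boosting raises. This is a sketch: `default_dtype` is a hypothetical helper, not a ColossalAI API, and the `booster.boost()` call in the comment assumes the surrounding training script's objects.

```python
import contextlib

import torch


@contextlib.contextmanager
def default_dtype(dtype: torch.dtype):
    """Temporarily set torch's default dtype, restoring the previous one on exit."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


# Usage sketch (booster, model, etc. assumed from the training script):
# with default_dtype(torch.bfloat16):
#     model, optimizer, _, dataloader, lr_scheduler = booster.boost(
#         model=model, optimizer=optimizer,
#         lr_scheduler=lr_scheduler, dataloader=dataloader,
#     )
```

The `finally` block guarantees the restore, which matters because a float16 default leaking into later code (loss logging, evaluation, checkpoint metadata) is exactly the side effect the Restore bullet warns about.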
Reasoning
Flash attention (FlashAttention-2) operates on half-precision CUDA tensors for performance. The underlying Triton/CUDA kernels do not have fp32 implementations. When PyTorch's default dtype is float32, newly created tensors (including internal attention buffers) are float32, causing flash attention dispatch to fail silently and fall back to the standard O(N^2) memory attention. Temporarily setting the default dtype ensures all tensors created during `booster.boost()` model wrapping are in the correct precision.
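The effect of the default dtype on newly created tensors is easy to verify on CPU, without flash attention or a GPU; a minimal sketch:

```python
import torch

# Tensors and module parameters created while the default dtype is fp16
# inherit that dtype -- this is what makes the buffers created during
# booster.boost() come out in the precision flash attention requires.
torch.set_default_dtype(torch.float16)
layer = torch.nn.Linear(4, 4)   # weights are created as float16
x = torch.randn(2, 4)           # also float16

# Restore, as the training script does after boosting.
torch.set_default_dtype(torch.float32)

print(layer.weight.dtype)    # torch.float16
print(x.dtype)               # torch.float16
print(torch.randn(1).dtype)  # torch.float32 after restore
```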
Code Evidence
From `applications/Colossal-LLaMA/train.py:226-234`:
```python
# Flash attention will be disabled because it does NOT support fp32.
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    dataloader=dataloader,
)
```
After boosting, the default dtype is restored with `torch.set_default_dtype(torch.float32)`, so subsequent code operates with a float32 default.