
Heuristic:Hpcaitech ColossalAI Flash Attention Dtype Restriction

From Leeroopedia

Knowledge Sources
Domains Optimization, Memory_Management
Last Updated 2026-02-09 03:00 GMT

Overview

Flash attention does not support fp32. The default tensor dtype must be set to fp16 or bf16 before calling `booster.boost()`, then reset afterward.

Description

Flash attention kernels are implemented only for half-precision floating-point types (float16 and bfloat16). If the default PyTorch dtype is float32 when `booster.boost()` is called, flash attention will be silently disabled, falling back to standard attention with significantly higher memory usage. The ColossalAI training script temporarily sets the default dtype to the chosen mixed-precision type before boosting the model, then restores it to float32 afterward.

Usage

Apply this whenever using flash attention with ColossalAI's Booster. This is relevant for all large model training where flash attention is desired for memory and speed benefits.

The Insight (Rule of Thumb)

  • Action: Call `torch.set_default_dtype(torch.float16)` or `torch.set_default_dtype(torch.bfloat16)` before `booster.boost()`.
  • Restore: Call `torch.set_default_dtype(torch.float32)` after `booster.boost()` to avoid side effects on subsequent tensor operations.
  • Condition: Only needed when flash attention is enabled; if using standard attention, fp32 is fine.
  • Warning: If flash attention is silently disabled, the only symptom is higher-than-expected VRAM usage or OOM errors during the forward pass.
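The set-then-restore pattern above can be sketched as a small wrapper. This is a minimal sketch, not ColossalAI code: `boost_with_half_precision_default` and its `boost_fn` argument (standing in for the actual `booster.boost(...)` call) are hypothetical names for illustration.

```python
import torch

def boost_with_half_precision_default(boost_fn, mixed_precision="bf16"):
    """Run a boost callable with the default dtype temporarily set to half precision.

    boost_fn is a stand-in for booster.boost(...); hypothetical helper, not ColossalAI API.
    """
    # Flash attention kernels exist only for fp16/bf16, so pick the matching dtype.
    default_dtype = torch.float16 if mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    try:
        return boost_fn()
    finally:
        # Restore fp32 even if boosting raises, so later tensor creation is unaffected.
        torch.set_default_dtype(torch.float32)
```

Using `try/finally` (rather than two bare calls as in the training script) guarantees the fp32 default is restored even when boosting fails partway through.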

Reasoning

Flash attention (FlashAttention-2) operates on half-precision CUDA tensors for performance. The underlying Triton/CUDA kernels do not have fp32 implementations. When PyTorch's default dtype is float32, newly created tensors (including internal attention buffers) are float32, causing flash attention dispatch to fail silently and fall back to the standard O(N^2) memory attention. Temporarily setting the default dtype ensures all tensors created during `booster.boost()` model wrapping are in the correct precision.
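The dependence on the default dtype can be seen directly: PyTorch factory functions called without an explicit `dtype` argument follow the global default, which is why tensors created inside `booster.boost()` inherit whatever was set beforehand.

```python
import torch

# Factory calls without an explicit dtype follow the global default dtype.
torch.set_default_dtype(torch.bfloat16)
x = torch.empty(4, 4)   # bf16: eligible for flash attention kernels
torch.set_default_dtype(torch.float32)
y = torch.empty(4, 4)   # fp32: would force the standard-attention fallback

assert x.dtype == torch.bfloat16
assert y.dtype == torch.float32
```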

Code Evidence

From `applications/Colossal-LLaMA/train.py:226-234`:

# Flash attention will be disabled because it does NOT support fp32.
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    dataloader=dataloader,
)

Immediately after boosting, the script restores the default dtype with `torch.set_default_dtype(torch.float32)`, so subsequent code operates with the float32 default.
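A reusable way to scope the dtype switch is a context manager. This is a sketch for illustration; `default_dtype` is a hypothetical helper, not part of ColossalAI or PyTorch.

```python
import contextlib
import torch

@contextlib.contextmanager
def default_dtype(dtype):
    # Hypothetical helper: temporarily change the global default dtype,
    # restoring the previous value even if the body raises.
    prev = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(prev)

# Tensors created inside the block are bf16; the fp32 default returns afterward.
with default_dtype(torch.bfloat16):
    z = torch.empty(2, 2)
```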
