Implementation:NVIDIA TransformerEngine Attention Backends
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Attention |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements the concrete attention backend strategies (unfused PyTorch, Flash Attention v2/v3, and cuDNN Fused Attention) with FP8 support for the TransformerEngine dot-product attention dispatcher.
Description
This module provides three backend classes that DotProductAttention dispatches to at runtime. UnfusedDotProductAttention is a pure PyTorch implementation built from a standard matmul, softmax, and second matmul. FlashAttention wraps the flash-attn library (v2 and v3), preparing the QKV layout via _PrepareQKVForFA and converting between the BSHD, SBHD, and THD layouts. FusedAttention calls the cuDNN-backed fused_attn_fwd/fused_attn_bwd C++ extensions through FusedAttnFunc, a custom autograd Function, for maximum performance. FP8EmulationFunc provides FP8 emulation for testing. Across these classes, the module also handles Flash Attention version detection, context parallelism integration, CPU offloading hooks, and FP8 quantization of the Q/K/V/O tensors.
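The matmul, softmax, matmul sequence attributed to the unfused backend can be illustrated with a minimal single-head reference in plain Python. This is a sketch for intuition only, not the TransformerEngine code: the function name `unfused_attention` and the nested-list inputs are assumptions made here, and the real backend operates on batched multi-head PyTorch tensors. The mask names follow the `attn_mask_type` values listed on this page.

```python
import math


def unfused_attention(q, k, v, attn_mask_type="no_mask"):
    """Single-head reference attention on nested lists (hypothetical helper).

    q: [s_q][d], k: [s_kv][d], v: [s_kv][d_v]. Returns [s_q][d_v].
    For "causal_bottom_right" this sketch assumes s_q <= s_kv, so that
    every query row can see at least one key.
    """
    if attn_mask_type not in ("no_mask", "causal", "causal_bottom_right"):
        raise ValueError(f"unsupported mask type in this sketch: {attn_mask_type}")
    s_q, s_kv, d = len(q), len(k), len(q[0])
    scale = 1.0 / math.sqrt(d)
    # Bottom-right alignment shifts the causal diagonal by (s_kv - s_q).
    offset = s_kv - s_q if attn_mask_type == "causal_bottom_right" else 0
    out = []
    for i in range(s_q):
        # First matmul: scaled scores of query i against every key.
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(s_kv)]
        if attn_mask_type != "no_mask":
            # Hide keys that lie to the right of the causal diagonal.
            scores = [s if j <= i + offset else -math.inf for j, s in enumerate(scores)]
        # Numerically stable softmax over the key dimension.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Second matmul: probability-weighted sum of the value rows.
        out.append([sum(p * v[j][c] for j, p in enumerate(probs)) for c in range(len(v[0]))])
    return out
```

With one query and two keys, `"causal"` lets the query see only the first key, while `"causal_bottom_right"` lets it see both, which is the practical difference between the two mask types when the query and key sequence lengths differ.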
Usage
Used internally by DotProductAttention to select the optimal attention kernel. Backend selection directly impacts training throughput and memory efficiency, especially for FP8 attention and long-sequence workloads.
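As a rough illustration of runtime dispatch, the sketch below picks the first enabled and supported backend from a fixed preference list. Everything in it is an assumption made for illustration: `select_backend`, `PREFERENCE`, and the `supported` mapping are hypothetical, the preference order is arbitrary, and the real dispatcher weighs many more factors (GPU architecture, dtype, layout, mask type, and so on). The environment variable names are assumed from the upstream project and should be verified against the documentation for your installed version.

```python
import os

# Hypothetical preference order for this sketch only.
PREFERENCE = ("FusedAttention", "FlashAttention", "UnfusedDotProductAttention")

# Assumed environment toggles ("0" disables a backend); verify the names
# against your installed TransformerEngine version.
ENV_TOGGLES = {
    "FusedAttention": "NVTE_FUSED_ATTN",
    "FlashAttention": "NVTE_FLASH_ATTN",
    "UnfusedDotProductAttention": "NVTE_UNFUSED_ATTN",
}


def select_backend(supported, env=None):
    """Return the first backend that is both enabled and supported.

    supported: dict mapping backend name -> bool, standing in for the
    capability checks a real dispatcher would perform.
    env: optional dict of environment variables (defaults to os.environ).
    """
    env = os.environ if env is None else env
    for name in PREFERENCE:
        enabled = env.get(ENV_TOGGLES[name], "1") != "0"
        if enabled and supported.get(name, False):
            return name
    raise RuntimeError("no attention backend is both enabled and supported")
```

Passing `env` explicitly keeps the function easy to test; for example, `select_backend({"FusedAttention": True, "FlashAttention": True}, env={"NVTE_FUSED_ATTN": "0"})` falls through to `"FlashAttention"`.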
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/attention/dot_product_attention/backends.py
- Lines: 1-2020
Signature
class FP8EmulationFunc(torch.autograd.Function):
    def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): ...
    def backward(ctx, grad1, grad2, grad3): ...

class UnfusedDotProductAttention(torch.nn.Module):
    def __init__(self, ...): ...
    def forward(self, ...): ...

class FlashAttention(torch.nn.Module):
    def __init__(self, ...): ...
    def forward(self, ...): ...

class FusedAttnFunc(torch.autograd.Function):
    def forward(ctx, ...): ...
    def backward(ctx, d_out, *_args): ...

class FusedAttention(torch.nn.Module):
    def __init__(self, ...): ...
    def forward(self, ...): ...
Import
from transformer_engine.pytorch.attention.dot_product_attention.backends import (
    UnfusedDotProductAttention,
    FlashAttention,
    FusedAttention,
    FP8EmulationFunc,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| query_layer | torch.Tensor | Yes | Query tensor in BSHD, SBHD, or THD format |
| key_layer | torch.Tensor | Yes | Key tensor in matching layout |
| value_layer | torch.Tensor | Yes | Value tensor in matching layout |
| qkv_layout | str | Yes | QKV memory layout (e.g., "bshd_bshd_bshd", "sbhd_sbhd_sbhd", "t3hd") |
| attn_mask_type | str | Yes | Attention mask type (no_mask, padding, causal, causal_bottom_right) |
| attention_dropout | float | No | Dropout probability for attention weights |
| fp8_quantizer | Quantizer | No | FP8 quantizer for Q/K/V/O tensors |
| core_attention_bias | torch.Tensor | No | Optional attention bias tensor |
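The `qkv_layout` strings in the table encode both the dimension order and how Q, K, and V are packed. A minimal sketch of how such a string could be decoded follows; `parse_qkv_layout` is a hypothetical helper written for this page, not a TransformerEngine function, and it assumes that underscores separate tensor groups and that a digit inside a group gives the number of tensors interleaved in that group (so "t3hd" is one buffer holding Q, K, and V).

```python
def parse_qkv_layout(qkv_layout):
    """Split a layout string into (qkv_format, tensors_per_group).

    Hypothetical helper for illustration. The format is taken from the
    first group with its digits removed, e.g. "t3hd" -> "thd".
    """
    groups = qkv_layout.split("_")
    # Dimension order without the packing digit.
    qkv_format = "".join(ch for ch in groups[0] if ch.isalpha())
    # A digit marks how many of Q/K/V share the group; no digit means one.
    tensors_per_group = [
        int(next((ch for ch in group if ch.isdigit()), "1")) for group in groups
    ]
    if sum(tensors_per_group) != 3:
        raise ValueError(f"layout does not describe exactly Q, K, V: {qkv_layout!r}")
    return qkv_format, tensors_per_group
```

For the layouts named in the table, `parse_qkv_layout("bshd_bshd_bshd")` yields `("bshd", [1, 1, 1])` (three separate tensors) and `parse_qkv_layout("t3hd")` yields `("thd", [3])` (one packed tensor).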
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Attention output tensor in matching layout |
Usage Examples
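# The backends are normally used internally by DotProductAttention.
# Direct use of the FlashAttention backend, assuming q, k, and v are existing
# CUDA tensors in BSHD layout ([batch, seq, heads, head_dim]):
from transformer_engine.pytorch.attention.dot_product_attention.backends import FlashAttention

# Constructor arguments are elided in the signature above. Assumption: a
# softmax scale is expected here, as in the upstream source; verify against
# your installed version.
flash_attn = FlashAttention(softmax_scale=q.shape[-1] ** -0.5)
output = flash_attn(
    query_layer=q,
    key_layer=k,
    value_layer=v,
    qkv_layout="bshd_bshd_bshd",
    attn_mask_type="causal",
)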