Implementation:NVIDIA TransformerEngine Attention Backends

From Leeroopedia
Revision as of 15:57, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/NVIDIA_TransformerEngine_Attention_Backends.md)


| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Attention |
| Last Updated | 2026-02-07 14:00 GMT |

Overview

Implements the concrete attention backends (unfused PyTorch, Flash Attention v2/v3, and cuDNN Fused Attention), including FP8 support, that the TransformerEngine dot-product attention dispatcher selects among.

Description

This module provides three backend classes that DotProductAttention dispatches to at runtime, plus supporting autograd functions.

UnfusedDotProductAttention is a pure PyTorch implementation built from standard matmul + softmax + matmul. FlashAttention wraps the flash-attn library (v2 and v3), preparing QKV layouts via _PrepareQKVForFA and converting between the BSHD, SBHD, and THD formats. FusedAttention calls the cuDNN-backed fused_attn_fwd/fused_attn_bwd C++ extensions through FusedAttnFunc, a custom autograd Function. FP8EmulationFunc provides FP8 emulation for testing.

The module also covers flash-attn version detection, context parallelism integration, CPU offloading hooks, and FP8 quantization of the Q/K/V/O tensors.
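The matmul + softmax + matmul structure of the unfused path can be sketched in a few lines of plain PyTorch. This is a minimal illustration, not the code in backends.py: the function name `unfused_attention`, the BSHD-only input, and the simple top-left causal mask are assumptions made for the example.

```python
import math
import torch


def unfused_attention(q, k, v, causal=False):
    """Reference attention for tensors shaped [batch, seq, heads, head_dim] (BSHD).

    Hypothetical helper for illustration; it is not part of TransformerEngine.
    """
    # Move heads ahead of the sequence axis so matmul batches over (batch, heads).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # [b, h, s, d]
    scale = 1.0 / math.sqrt(q.shape[-1])

    # First matmul: scaled similarity scores, shape [b, h, sq, sk].
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if causal:
        sq, sk = scores.shape[-2:]
        # True above the diagonal marks future positions to be masked out.
        mask = torch.ones(sq, sk, dtype=torch.bool, device=scores.device).triu(diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))

    # Softmax over the key axis, then the second matmul against the values.
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)  # [b, h, sq, d]
    return out.transpose(1, 2).contiguous()  # back to BSHD


torch.manual_seed(0)
q = torch.randn(2, 6, 4, 8)
k = torch.randn(2, 6, 4, 8)
v = torch.randn(2, 6, 4, 8)
out = unfused_attention(q, k, v, causal=True)
```

With the causal mask, the first query position can only attend to the first key, so its output row equals the first value row. The fused and flash backends compute the same result without materializing the full score matrix.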

Usage

These classes are used internally: DotProductAttention selects one of them for each call and dispatches to it. The selected backend directly affects training throughput and memory use, especially for FP8 attention and long-sequence workloads.
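The dispatch idea can be illustrated with a small selection function. This is only a sketch under stated assumptions: `choose_backend`, the availability dictionary, and the default preference order are hypothetical, and the real selection logic in TransformerEngine weighs many more conditions (hardware, dtype, layout, mask type, and so on).

```python
# Hypothetical preference order, used only for this illustration.
DEFAULT_PREFERENCE = (
    "FusedAttention",
    "FlashAttention",
    "UnfusedDotProductAttention",
)


def choose_backend(available, preference=DEFAULT_PREFERENCE):
    """Return the first preferred backend name that is marked available.

    `available` maps backend class names to booleans. Both the mapping and
    this function are assumptions, not TransformerEngine APIs.
    """
    for name in preference:
        if available.get(name, False):
            return name
    raise RuntimeError("no attention backend is available for this configuration")


# Example: the fused kernel is unavailable, so the flash backend is picked.
selected = choose_backend(
    {
        "FusedAttention": False,
        "FlashAttention": True,
        "UnfusedDotProductAttention": True,
    }
)
print(selected)
```

Keeping the unfused implementation last in the order reflects its role as a fallback that needs nothing beyond PyTorch.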

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/attention/dot_product_attention/backends.py
Lines: 1-2020

Signature

class FP8EmulationFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): ...
    @staticmethod
    def backward(ctx, grad1, grad2, grad3): ...

class UnfusedDotProductAttention(torch.nn.Module):
    def __init__(self, ...): ...
    def forward(self, ...): ...

class FlashAttention(torch.nn.Module):
    def __init__(self, ...): ...
    def forward(self, ...): ...

class FusedAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ...): ...
    @staticmethod
    def backward(ctx, d_out, *_args): ...

class FusedAttention(torch.nn.Module):
    def __init__(self, ...): ...
    def forward(self, ...): ...

Import

from transformer_engine.pytorch.attention.dot_product_attention.backends import (
    UnfusedDotProductAttention,
    FlashAttention,
    FusedAttention,
    FP8EmulationFunc,
)

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| query_layer | torch.Tensor | Yes | Query tensor in BSHD, SBHD, or THD format |
| key_layer | torch.Tensor | Yes | Key tensor in matching layout |
| value_layer | torch.Tensor | Yes | Value tensor in matching layout |
| qkv_layout | str | Yes | QKV memory layout (e.g., "bshd_bshd_bshd", "sbhd_sbhd_sbhd", "t3hd") |
| attn_mask_type | str | Yes | Attention mask type (e.g., no_mask, padding, causal, causal_bottom_right) |
| attention_dropout | float | No | Dropout probability for attention weights |
| fp8_quantizer | Quantizer | No | FP8 quantizer for Q/K/V/O tensors |
| core_attention_bias | torch.Tensor | No | Optional attention bias tensor |
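The layout names in the table refer to axis order: B is batch, S is sequence, H is heads, D is head dimension, and T is the total token count of a packed batch. The sketch below shows how the three formats relate. The helper names `bshd_to_sbhd` and `bshd_to_thd` and the use of cumulative sequence lengths to mark boundaries are assumptions for illustration, not functions from backends.py.

```python
import torch


def bshd_to_sbhd(x):
    """Swap the batch and sequence axes: [b, s, h, d] -> [s, b, h, d]."""
    return x.transpose(0, 1).contiguous()


def bshd_to_thd(x, seq_lens):
    """Pack a padded [b, s, h, d] batch into [t, h, d], dropping padding tokens.

    Returns the packed tensor and cumulative sequence lengths, where entry i
    is the offset of sequence i in the packed token axis.
    """
    chunks = [x[i, :n] for i, n in enumerate(seq_lens)]
    packed = torch.cat(chunks, dim=0)
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(seq_lens), dim=0)
    return packed, cu_seqlens


# Two sequences padded to length 5, with true lengths 3 and 5.
x = torch.randn(2, 5, 3, 4)
x_sbhd = bshd_to_sbhd(x)
x_thd, cu_seqlens = bshd_to_thd(x, [3, 5])
```

Packing removes padding entirely, so a THD tensor needs the sequence offsets alongside it to recover where each sequence starts and ends.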

Outputs

| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Attention output tensor in matching layout |
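The difference between the causal and causal_bottom_right mask types only appears when the query and key sequence lengths differ: the first aligns the causal diagonal with the top-left corner of the score matrix, the second with the bottom-right corner. A minimal sketch, assuming a boolean mask where True means "masked out" and a hypothetical helper named `causal_mask`:

```python
import torch


def causal_mask(seq_q, seq_k, bottom_right=False):
    """Boolean [seq_q, seq_k] mask; True marks positions excluded from attention.

    Hypothetical helper for illustration. With bottom_right=True the diagonal
    is shifted so the last query can see every key.
    """
    offset = seq_k - seq_q if bottom_right else 0
    rows = torch.arange(seq_q).unsqueeze(1)  # query positions, column vector
    cols = torch.arange(seq_k).unsqueeze(0)  # key positions, row vector
    return cols > rows + offset


top_left = causal_mask(2, 4)
bottom_right = causal_mask(2, 4, bottom_right=True)
```

For two queries over four keys, the top-left variant lets the first query see only the first key, while the bottom-right variant lets it see the first three, which suits decoding with a cached prefix.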

Usage Examples

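# The backends are normally driven by DotProductAttention. Direct use is shown
# only for illustration and needs a CUDA GPU with flash-attn installed.
import torch
from transformer_engine.pytorch.attention.dot_product_attention.backends import FlashAttention

# Example inputs in BSHD layout: [batch, seq_len, heads, head_dim] (illustrative sizes).
q = torch.randn(2, 128, 16, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Assumption: the constructor takes a softmax scale. Constructor and forward
# arguments vary across TransformerEngine releases, so check your installed version.
flash_attention = FlashAttention(softmax_scale=q.shape[-1] ** -0.5)
output = flash_attention(
    query_layer=q,
    key_layer=k,
    value_layer=v,
    qkv_layout="bshd_bshd_bshd",
    attn_mask_type="causal",
)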
