Implementation:NVIDIA TransformerEngine Context Parallel
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Attention, Distributed |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements context parallelism (CP) for attention, enabling long sequences to be split across multiple GPUs with three communication strategies: P2P ring, all-gather, and all-to-all.
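To make the P2P ring strategy concrete, the following is a minimal pure-Python sketch of which KV chunk each rank holds at each step of a ring exchange. The helper `ring_schedule` is hypothetical, and the ring direction (each rank sends to `rank + 1` and receives from `rank - 1`) is an assumption for illustration, not something this page confirms about the library.

```python
def ring_schedule(cp_size):
    """For each step, list which rank's KV chunk every rank currently holds.

    Assumes each rank forwards its current KV chunk to rank + 1 and receives
    from rank - 1 (illustrative direction, not taken from the library).
    """
    schedule = []
    for step in range(cp_size):
        # After `step` hops, rank r holds the chunk that started on rank r - step.
        schedule.append([(rank - step) % cp_size for rank in range(cp_size)])
    return schedule


schedule = ring_schedule(4)
for step, holders in enumerate(schedule):
    print(f"step {step}: {holders}")
```

After `cp_size` steps every rank has attended to every KV chunk exactly once, which is why the ring needs no global gather.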
Description
Provides three custom autograd Functions implementing different CP communication strategies. AttnFuncWithCPAndKVP2P uses point-to-point ring communication to pass KV chunks between ranks. AttnFuncWithCPAndKVAllGather uses all-gather to collect full KV tensors. AttnFuncWithCPAndQKVOA2A uses all-to-all to redistribute Q, K, V, and O across ranks. JIT-fused helper functions (flash_attn_fwd_out_correction, flash_attn_fwd_softmax_lse_correction, etc.) handle the online softmax correction required to merge partial attention outputs from different sequence chunks. Supports both causal and non-causal masks, various QKV layouts, and FP8 quantized attention within CP.
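The online softmax correction mentioned above can be illustrated with a small pure-Python sketch. It is not the library's JIT-fused code: `attention_chunk` and `merge_partials` are hypothetical helpers, a single query vector is used, and the usual `1/sqrt(d)` scaling is omitted for brevity. The sketch shows that two partial outputs, each with its own log-sum-exp (LSE), can be combined into the result that attention over the full KV sequence would have produced.

```python
import math


def attention_chunk(q, keys, values):
    """Attention of one query vector over one KV chunk.

    Returns (out, lse): the softmax-weighted average of the values and the
    log-sum-exp of the raw scores for this chunk.
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    dim = len(values[0])
    out = [sum(e * v[d] for e, v in zip(exps, values)) / denom for d in range(dim)]
    return out, m + math.log(denom)


def merge_partials(out_a, lse_a, out_b, lse_b):
    """Combine two partial results as if softmax had run over both chunks."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Each partial output is rescaled by its share of the total softmax mass.
    w_a = math.exp(lse_a - lse)
    w_b = math.exp(lse_b - lse)
    return [w_a * a + w_b * b for a, b in zip(out_a, out_b)], lse


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full_out, full_lse = attention_chunk(q, keys, values)
out_a, lse_a = attention_chunk(q, keys[:2], values[:2])
out_b, lse_b = attention_chunk(q, keys[2:], values[2:])
merged_out, merged_lse = merge_partials(out_a, lse_a, out_b, lse_b)
```

Because the merge only needs each chunk's output and LSE, a rank can fold in KV chunks one at a time as they arrive, without ever materializing the full score matrix.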
Usage
Used when training with very long sequences that exceed single-GPU memory. Context parallelism distributes the sequence dimension across GPUs, enabling context lengths in the hundreds of thousands or millions of tokens.
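As an illustration of distributing the sequence dimension, the sketch below splits a sequence into `2 * cp_size` chunks and gives each rank one chunk from the front and the mirrored chunk from the back. This is a common load-balancing layout for causal masks; whether it matches the layout expected by every communication type in this module is an assumption, and `shard_sequence` is a hypothetical helper.

```python
def shard_sequence(tokens, cp_size, cp_rank):
    """Return the tokens held by `cp_rank` under a 2 * cp_size chunk split."""
    num_chunks = 2 * cp_size
    if len(tokens) % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * cp_size")
    chunk = len(tokens) // num_chunks
    first = cp_rank                    # chunk taken from the front
    second = num_chunks - 1 - cp_rank  # mirrored chunk taken from the back
    return (tokens[first * chunk:(first + 1) * chunk]
            + tokens[second * chunk:(second + 1) * chunk])


tokens = list(range(16))
shards = [shard_sequence(tokens, cp_size=4, cp_rank=r) for r in range(4)]
for rank, shard in enumerate(shards):
    print(f"rank {rank}: {shard}")
```

Pairing an early chunk with a late one gives every rank a similar amount of unmasked work under a causal mask, where late tokens attend to many more keys than early ones.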
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
- Lines: 1-4321
Signature
def flash_attn_p2p_communicate(rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm): ...
def get_cu_seqlens_on_cp_rank(cu_seqlens, cp_rank, cp_size): ...
def attn_forward_func_with_cp(...): ...

class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
    def forward(ctx, ...): ...
    def backward(ctx, ...): ...

class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    def forward(ctx, ...): ...
    def backward(ctx, ...): ...

class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    def forward(ctx, ...): ...
    def backward(ctx, ...): ...
Import
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
    attn_forward_func_with_cp,
    AttnFuncWithCPAndKVP2P,
    AttnFuncWithCPAndKVAllGather,
    AttnFuncWithCPAndQKVOA2A,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| query_layer | torch.Tensor | Yes | Query tensor for the local CP rank |
| key_layer | torch.Tensor | Yes | Key tensor for the local CP rank |
| value_layer | torch.Tensor | Yes | Value tensor for the local CP rank |
| cp_group | dist_group_type | Yes | Process group for context parallelism communication |
| cp_size | int | Yes | Number of CP ranks |
| cp_stream | torch.cuda.Stream | Yes | CUDA stream for overlapping communication |
| cp_comm_type | str | Yes | Communication type: "p2p", "all_gather", or "a2a" |
| fused_attn_backend | FusedAttnBackend | No | Backend for fused attention within CP |
| cu_seqlens_q | torch.Tensor | No | Cumulative sequence lengths for THD format |
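For readers unfamiliar with cumulative sequence lengths, the sketch below shows how such an array describes packed variable-length sequences: entry `i` is the offset at which sequence `i` starts, and the last entry is the total token count. It uses plain Python lists for clarity; in practice `cu_seqlens_q` is a `torch.Tensor`, and `build_cu_seqlens` is a hypothetical helper, not part of the library.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Prefix-sum the per-sequence lengths, starting from offset 0."""
    return [0] + list(accumulate(seq_lens))


seq_lens = [5, 3, 8]  # three packed sequences of different lengths
cu_seqlens = build_cu_seqlens(seq_lens)

# Recover the [start, end) token range of each packed sequence.
ranges = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
print(cu_seqlens, ranges)
```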
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Combined attention output from all CP ranks |
| softmax_lse | torch.Tensor | Log-sum-exp of softmax for backward pass correction |
Usage Examples
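from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
    attn_forward_func_with_cp,
)

# Typically invoked internally by attention backends when CP is enabled.
# q, k, v, cp_group, and cp_stream are assumed to be created by the caller:
# q/k/v hold only this rank's slice of the sequence, cp_group is the CP
# process group, and cp_stream is a torch.cuda.Stream used for overlap.
# Other arguments accepted by the function are omitted here.
output = attn_forward_func_with_cp(
    query_layer=q,
    key_layer=k,
    value_layer=v,
    cp_group=cp_group,
    cp_size=4,  # should match the number of ranks in cp_group
    cp_stream=cp_stream,
    cp_comm_type="p2p",  # or "all_gather" / "a2a"
)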