Implementation:NVIDIA TransformerEngine Context Parallel

From Leeroopedia


Field          Value
Sources        TransformerEngine
Domains        Deep_Learning, PyTorch, Attention, Distributed
Last Updated   2026-02-07 14:00 GMT

Overview

Implements context parallelism (CP) for attention: the sequence dimension of long inputs is split across multiple GPUs, and ranks exchange the data they need using one of three communication strategies: P2P ring, all-gather, or all-to-all.
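For the all-to-all strategy, the data movement can be pictured as a switch from sequence sharding to head sharding. The pure-Python sketch below simulates only that layout change on nested lists, assuming contiguous sequence shards and a head count divisible by the CP size. all_to_all_seq_to_head is a hypothetical helper, not a TransformerEngine function; a real implementation would move tensors with a collective such as torch.distributed.all_to_all_single.

```python
def all_to_all_seq_to_head(shards, cp_size):
    """Simulate a sequence-to-head all-to-all on nested lists.

    shards[r][s][h] is the element rank r holds for local sequence position s
    and head h. Before: each rank has seq_len // cp_size positions and all heads.
    After: each rank has all seq_len positions but only num_heads // cp_size heads.
    """
    num_heads = len(shards[0][0])
    assert num_heads % cp_size == 0, "heads must divide evenly across CP ranks"
    heads_per_rank = num_heads // cp_size
    result = []
    for dst in range(cp_size):
        my_heads = range(dst * heads_per_rank, (dst + 1) * heads_per_rank)
        # dst receives each source rank's sequence chunk restricted to dst's heads;
        # concatenating in source-rank order rebuilds the full sequence.
        full_seq = [[row[h] for h in my_heads] for src in range(cp_size) for row in shards[src]]
        result.append(full_seq)
    return result


# Toy example: element (s, h) marks global sequence position s and head h.
seq_len, num_heads, cp_size = 4, 4, 2
global_x = [[(s, h) for h in range(num_heads)] for s in range(seq_len)]
chunk = seq_len // cp_size
shards = [global_x[r * chunk:(r + 1) * chunk] for r in range(cp_size)]

head_sharded = all_to_all_seq_to_head(shards, cp_size)
print(head_sharded[1][3])  # [(3, 2), (3, 3)]
```

After the swap each rank can run ordinary full-sequence attention for its subset of heads; a reverse all-to-all on the output restores sequence sharding.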

Description

Provides three custom autograd Functions, one per CP communication strategy. AttnFuncWithCPAndKVP2P passes KV chunks around a ring of ranks with point-to-point communication, so each rank attends its local queries to one KV chunk per step. AttnFuncWithCPAndKVAllGather all-gathers K and V so each rank can attend its local queries to the full sequence. AttnFuncWithCPAndQKVOA2A uses all-to-all to redistribute Q, K, and V from sequence sharding to head sharding before attention, and applies the reverse all-to-all to the output O. When attention is computed over separate sequence chunks, the resulting partial outputs must be merged; JIT-fused helper functions (flash_attn_fwd_out_correction, flash_attn_fwd_softmax_lse_correction, etc.) apply the online-softmax correction that does this. The implementation supports causal and non-causal masks, several QKV layouts, and FP8-quantized attention within CP.
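The online-softmax correction can be checked on a single query with scalar values. In this sketch (plain Python with hypothetical helper names, not TE's JIT-fused kernels), each KV chunk yields a partial output and its log-sum-exp (LSE). The merge rescales each partial output by exp(lse_i - lse), where lse is the LSE over the union of both chunks, which reproduces full attention exactly.

```python
import math


def attention_row(scores, values):
    """Reference softmax attention for one query: returns (output, lse)."""
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    lse = m + math.log(total)
    out = sum(w * v for w, v in zip(weights, values)) / total
    return out, lse


def merge_partials(out1, lse1, out2, lse2):
    """Merge two partial attention results computed over disjoint KV chunks."""
    # LSE over the union of both chunks, computed stably.
    m = max(lse1, lse2)
    lse = m + math.log(math.exp(lse1 - m) + math.exp(lse2 - m))
    # Each partial output is weighted by its chunk's share of the softmax mass.
    out = out1 * math.exp(lse1 - lse) + out2 * math.exp(lse2 - lse)
    return out, lse


scores = [0.5, 2.0, -1.0, 3.0]
values = [1.0, 2.0, 3.0, 4.0]
full_out, full_lse = attention_row(scores, values)

# Pretend the first two keys live on one CP rank and the last two on another.
o1, l1 = attention_row(scores[:2], values[:2])
o2, l2 = attention_row(scores[2:], values[2:])
merged_out, merged_lse = merge_partials(o1, l1, o2, l2)
print(math.isclose(merged_out, full_out), math.isclose(merged_lse, full_lse))  # True True
```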

Usage

Used when training on sequences too long to fit in a single GPU's memory. Context parallelism shards the sequence dimension across GPUs, which, given enough GPUs, makes context lengths in the hundreds of thousands or millions of tokens practical.
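With a causal mask, a contiguous split gives later ranks far more work, because late queries attend to more keys. A common remedy, and the input layout we assume causal CP expects, is to cut the sequence into 2 * cp_size chunks and give rank r chunks r and 2 * cp_size - 1 - r. The sketch below, using hypothetical helper names, counts the query-key pairs each rank must score under both layouts.

```python
def load_balanced_chunks(seq_len, cp_size):
    """Token ranges owned by each CP rank when the sequence is cut into 2*cp_size chunks."""
    num_chunks = 2 * cp_size
    assert seq_len % num_chunks == 0, "seq_len must be divisible by 2 * cp_size"
    c = seq_len // num_chunks
    layout = []
    for rank in range(cp_size):
        # Pair an early chunk with its mirror-image late chunk.
        pair = (rank, num_chunks - 1 - rank)
        layout.append([range(k * c, (k + 1) * c) for k in pair])
    return layout


def contiguous_chunks(seq_len, cp_size):
    """Naive split: rank r owns one contiguous block of seq_len // cp_size tokens."""
    c = seq_len // cp_size
    return [[range(r * c, (r + 1) * c)] for r in range(cp_size)]


def causal_work(ranges):
    """Query-key pairs a rank must score under a causal mask (query q sees keys 0..q)."""
    return sum(q + 1 for r in ranges for q in r)


seq_len, cp_size = 16, 2
print([causal_work(r) for r in contiguous_chunks(seq_len, cp_size)])     # [36, 100]
print([causal_work(r) for r in load_balanced_chunks(seq_len, cp_size)])  # [68, 68]
```

Pairing chunk i with chunk 2 * cp_size - 1 - i always gives every rank the same causal workload, because the two chunks' offsets sum to a constant.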

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
Lines: 1-4321

Signature

def flash_attn_p2p_communicate(rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm): ...
def get_cu_seqlens_on_cp_rank(cu_seqlens, cp_rank, cp_size): ...
def attn_forward_func_with_cp(...): ...

# torch.autograd.Function subclasses define forward/backward as static methods.
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ...): ...
    @staticmethod
    def backward(ctx, ...): ...

class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ...): ...
    @staticmethod
    def backward(ctx, ...): ...

class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ...): ...
    @staticmethod
    def backward(ctx, ...): ...
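The P2P strategy relies on a ring: at each step a rank forwards the KV chunk it holds and receives a new one, the kind of exchange that the send_dst and recv_src arguments of flash_attn_p2p_communicate describe. The sketch below simulates only which rank's KV chunk each rank holds at each step, assuming sends go to rank + 1 and receives come from rank - 1 (mod cp_size). ring_schedule is a hypothetical name.

```python
def ring_schedule(cp_size):
    """For each ring step, the source rank of the KV chunk each rank holds.

    Assumes rank r sends to (r + 1) % cp_size and receives from (r - 1) % cp_size.
    """
    held = list(range(cp_size))  # step 0: every rank holds its own KV chunk
    steps = [held[:]]
    for _ in range(cp_size - 1):
        # After one send/recv round, rank r holds what rank r - 1 held before.
        held = [held[(r - 1) % cp_size] for r in range(cp_size)]
        steps.append(held[:])
    return steps


for step, held in enumerate(ring_schedule(4)):
    print(step, held)
# 0 [0, 1, 2, 3]
# 1 [3, 0, 1, 2]
# 2 [2, 3, 0, 1]
# 3 [1, 2, 3, 0]
```

Each rank sees every KV chunk exactly once in cp_size steps, so the exchange for step i + 1 can proceed on a separate stream while step i's attention is computed; that overlap is what the cp_stream input is for.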

Import

from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
    attn_forward_func_with_cp,
    AttnFuncWithCPAndKVP2P,
    AttnFuncWithCPAndKVAllGather,
    AttnFuncWithCPAndQKVOA2A,
)

I/O Contract

Inputs

Name                Type                Required  Description
query_layer         torch.Tensor        Yes       Query tensor for the local CP rank
key_layer           torch.Tensor        Yes       Key tensor for the local CP rank
value_layer         torch.Tensor        Yes       Value tensor for the local CP rank
cp_group            dist_group_type     Yes       Process group for context-parallel communication
cp_size             int                 Yes       Number of CP ranks
cp_stream           torch.cuda.Stream   Yes       CUDA stream used to overlap communication with computation
cp_comm_type        str                 Yes       Communication type: "p2p", "all_gather", or "a2a"
fused_attn_backend  FusedAttnBackend    No        Backend for fused attention within CP
cu_seqlens_q        torch.Tensor        No        Cumulative query sequence lengths (THD format)
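In the THD format, cu_seqlens_q holds the cumulative boundaries of packed sequences. Assuming each sequence is split individually with the same 2 * cp_size load-balanced layout, and each length is divisible by 2 * cp_size, a rank's local boundaries follow from dividing every length by cp_size. The helpers below are hypothetical illustrations, not TE's get_cu_seqlens_on_cp_rank, and may not match its exact behavior.

```python
def local_cu_seqlens(cu_seqlens, cp_size):
    """Per-rank cumulative lengths when every packed sequence is split across cp_size ranks."""
    seqlens = [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    assert all(n % (2 * cp_size) == 0 for n in seqlens), "lengths must divide by 2 * cp_size"
    out = [0]
    for n in seqlens:
        out.append(out[-1] + n // cp_size)  # each rank keeps 2 of 2*cp_size chunks
    return out


def local_token_indices(cu_seqlens, cp_rank, cp_size):
    """Global token indices a rank owns: chunks cp_rank and 2*cp_size-1-cp_rank of each sequence."""
    idx = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        c = (end - start) // (2 * cp_size)
        for k in (cp_rank, 2 * cp_size - 1 - cp_rank):
            idx.extend(range(start + k * c, start + (k + 1) * c))
    return idx


cu = [0, 8, 24, 32]  # three packed sequences of lengths 8, 16, 8
print(local_cu_seqlens(cu, cp_size=2))           # [0, 4, 12, 16]
print(local_token_indices(cu, cp_rank=0, cp_size=2)[:4])  # [0, 1, 6, 7]
```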

Outputs

Name          Type          Description
output        torch.Tensor  Attention output for the local queries, combining contributions from the KV of all CP ranks
softmax_lse   torch.Tensor  Log-sum-exp of the softmax normalizer, used for correction in the backward pass

Usage Examples

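import torch
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
    attn_forward_func_with_cp,
)

# Typically invoked internally by attention backends when CP is enabled,
# rather than called directly.
# q, k, v hold this rank's shard of the sequence. cp_group and cp_stream are
# assumed to exist already, e.g.
#   cp_group = torch.distributed.new_group(ranks=cp_ranks)
#   cp_stream = torch.cuda.Stream()
# NOTE: this call is abbreviated; argument names and the full set of required
# arguments (sequence lengths, mask type, etc.) depend on the TE version, so
# check the signature in your installed context_parallel.py.
output = attn_forward_func_with_cp(
    query_layer=q,
    key_layer=k,
    value_layer=v,
    cp_group=cp_group,
    cp_size=4,
    cp_stream=cp_stream,
    cp_comm_type="p2p",
)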
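With cp_comm_type="p2p", a causal mask, and the load-balanced layout, each ring step needs only part of the Q-by-KV block. The sketch below derives this purely from chunk indices: each rank holds two Q chunks and receives two KV chunks per step, and a block is skipped when the KV chunk lies entirely in the future of the Q chunk. block_pattern and summarize are hypothetical names; the sketch reasons about indices only and does not reproduce TE's kernels.

```python
def block_pattern(q_rank, kv_rank, cp_size):
    """Classify each (Q half, KV half) block under a causal mask and load-balanced chunks."""
    n = 2 * cp_size
    q_chunks = (q_rank, n - 1 - q_rank)     # global chunk ids of the two local Q halves
    kv_chunks = (kv_rank, n - 1 - kv_rank)  # global chunk ids of the received KV halves
    pattern = {}
    for qi, qc in enumerate(q_chunks):
        for ki, kc in enumerate(kv_chunks):
            if kc < qc:
                pattern[(qi, ki)] = "full"      # every key precedes every query
            elif kc == qc:
                pattern[(qi, ki)] = "diagonal"  # needs the usual causal mask
            else:
                pattern[(qi, ki)] = "skip"      # every key is in the future
    return pattern


def summarize(q_rank, kv_rank, cp_size):
    """Name the work pattern for one (local Q, received KV) pairing."""
    p = block_pattern(q_rank, kv_rank, cp_size)
    needed = sorted(blk for blk, kind in p.items() if kind != "skip")
    if kv_rank == q_rank:
        return "causal over local Q and KV"
    if needed == [(0, 0), (1, 0)]:
        return "all of Q against first half of KV"
    if needed == [(1, 0), (1, 1)]:
        return "second half of Q against all of KV"
    return "unexpected"


print(summarize(2, 0, 4))  # all of Q against first half of KV
print(summarize(0, 3, 4))  # second half of Q against all of KV
```

The three cases cover every pairing: the rank's own KV needs an ordinary causal mask, KV from a lower rank needs all of Q against the first half of KV, and KV from a higher rank needs the second half of Q against all of KV. Either off-diagonal case computes two of the four blocks.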
