

Principle:Huggingface Transformers Context Parallelism

From Leeroopedia
Domains Distributed_Computing, Training, Attention
Last Updated 2026-02-13 00:00 GMT

Overview

Context parallelism splits the sequence dimension across multiple devices, enabling training on sequences longer than what a single device can handle by distributing the attention computation.

Description

Context Parallelism (CP) addresses the memory and compute bottleneck of self-attention in transformer models, which scales quadratically with sequence length. By partitioning the sequence dimension across multiple GPUs, each device processes only a fraction of the full sequence (i.e., seq_len / cp_size tokens), dramatically reducing per-device memory requirements for the attention computation.

The key mechanism is that input tensors (input_ids, labels, position_ids) are sharded along the sequence dimension before entering the model. During the attention computation, each device computes attention over its local sequence chunk but uses a ring-based communication pattern to exchange key-value pairs with other CP ranks, ensuring that every token can attend to every other token in the full sequence. This is conceptually similar to Ring Attention (Liu et al., 2023).

In PyTorch, context parallelism is provided through torch.distributed.tensor.experimental.context_parallel, which acts as a context manager that:

  1. Shards specified input buffers along specified sequence dimensions across the CP mesh.
  2. Wraps the attention computation to perform ring communication of KV blocks.
  3. Collects outputs back to the expected shape for the loss computation.
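The real API is the experimental `torch.distributed.tensor.experimental.context_parallel` context manager, which requires a multi-GPU process group. As a single-process illustration of steps 1 and 3 only, here is a toy stand-in (names and list-based buffers are hypothetical) that shards buffers along a sequence dimension on entry and restores them on exit; the ring KV communication of step 2 is not simulated here:

```python
from contextlib import contextmanager

@contextmanager
def toy_context_parallel(buffers, buffer_seq_dims, rank, cp_size):
    """Toy single-process sketch of the shard/restore bookkeeping that
    context_parallel performs (the real API additionally wraps attention
    with ring KV communication). Buffers are plain lists; only sequence
    dimension 0 is supported in this sketch."""
    originals = [list(buf) for buf in buffers]
    try:
        for buf, dim in zip(buffers, buffer_seq_dims):
            assert dim == 0, "toy version shards dim 0 only"
            chunk = len(buf) // cp_size       # seq_len / cp_size tokens
            buf[:] = buf[rank * chunk:(rank + 1) * chunk]
        yield
    finally:
        for buf, orig in zip(buffers, originals):
            buf[:] = orig                     # restore full buffers on exit

input_ids = [10, 11, 12, 13, 14, 15, 16, 17]  # seq_len = 8
with toy_context_parallel([input_ids], [0], rank=1, cp_size=4):
    local = list(input_ids)                   # rank 1 sees its local chunk
print(local)       # [12, 13]
print(input_ids)   # restored: [10, 11, 12, 13, 14, 15, 16, 17]
```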

CP is orthogonal to tensor parallelism (which shards along the feature/head dimension) and data parallelism (which shards along the batch dimension). Together, they form the three axes of 3D parallelism:

  • DP: Shards batch dimension.
  • TP: Shards feature/head dimension.
  • CP: Shards sequence dimension.
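The per-device shard shapes under the three axes can be sketched with simple shape arithmetic (the sizes and parallelism degrees below are illustrative, not recommendations):

```python
# Per-device activation shard shapes under 3D parallelism.
batch, seq_len, d_model = 32, 8192, 4096
dp, tp, cp = 4, 2, 8           # hypothetical degrees; 4 * 2 * 8 = 64 GPUs

# Each parallelism axis shards exactly one dimension:
local_batch = batch // dp      # DP: batch dimension
local_seq   = seq_len // cp    # CP: sequence dimension
local_dim   = d_model // tp    # TP: feature/head dimension

local_shape = (local_batch, local_seq, local_dim)
print(local_shape)             # (8, 1024, 2048)
```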

Usage

Use context parallelism when:

  • Training with very long sequences that exceed single-GPU memory for attention computation.
  • The model uses standard self-attention (SDPA) and supports ring attention patterns.
  • You need to scale sequence length beyond what tensor parallelism alone can handle.
  • All sequences in a batch have the same length, and that length is divisible by cp_size (required for even partitioning); sequence packing naturally satisfies this.
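The even-partitioning requirement is commonly met by padding a packed sequence up to a multiple of cp_size. A minimal sketch (the helper name and `pad_id=0` are assumptions for illustration):

```python
def pad_to_multiple(tokens, cp_size, pad_id=0):
    """Right-pad a packed token list so its length divides evenly by
    cp_size, giving every CP rank an equal-sized chunk."""
    remainder = len(tokens) % cp_size
    if remainder:
        tokens = tokens + [pad_id] * (cp_size - remainder)
    return tokens

packed = list(range(10))                 # packed length 10
padded = pad_to_multiple(packed, cp_size=4)
print(len(padded))                       # 12, i.e. 3 tokens per rank
```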

CP is typically activated only when cp_size > 1. When cp_size == 1, a nullcontext() is substituted instead, so the CP machinery is bypassed entirely and adds no overhead.
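This gating pattern can be sketched as follows (the factory argument stands in for whatever builds the real CP context; the function name is hypothetical):

```python
from contextlib import nullcontext

def maybe_cp_context(cp_size, cp_context_factory):
    """Return the CP context manager only when cp_size > 1; otherwise a
    no-op nullcontext, so the single-device path pays no CP overhead."""
    return cp_context_factory() if cp_size > 1 else nullcontext()

entered = []
with maybe_cp_context(1, lambda: entered.append("cp") or nullcontext()):
    pass
print(entered)   # [] -- the CP factory is never even called when cp_size == 1
```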

Theoretical Basis

Context parallelism is grounded in Ring Attention (Liu et al., 2023), which extends the standard self-attention computation to a distributed setting:

  1. Each CP rank holds a contiguous chunk of the query, key, and value tensors.
  2. Attention is computed in rounds. In each round, each rank computes partial attention scores between its local queries and the keys/values it currently holds.
  3. After each round, key-value blocks are rotated to the next rank in a ring topology.
  4. After cp_size rounds, every query has attended to every key-value pair.
  5. The partial softmax results are combined using the online softmax trick (numerically stable incremental softmax).
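The five steps above can be exercised end to end in a toy single-process simulation. The sketch below uses scalar queries/keys/values and one logical "rank" holding all queries (real ring attention also shards the queries and runs the rounds on separate devices); it rotates KV chunks through cp_size rounds and merges partials with the online softmax, then checks the result against a direct full-attention reference:

```python
import math

def attention_direct(q, k, v):
    """Reference: full softmax attention for scalar q/k/v lists."""
    outs = []
    for qi in q:
        scores = [qi * kj for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        outs.append(sum(wi * vi for wi, vi in zip(w, v)) / sum(w))
    return outs

def attention_ring(q, k, v, cp_size):
    """Toy ring attention: KV chunks rotate for cp_size rounds, and the
    per-round partials are merged with the online-softmax trick."""
    n = len(k) // cp_size
    kv_chunks = [(k[i*n:(i+1)*n], v[i*n:(i+1)*n]) for i in range(cp_size)]
    # Running state per query: (max score m, denominator l, numerator acc).
    state = [(-math.inf, 0.0, 0.0) for _ in q]
    for _ in range(cp_size):                       # one round per KV chunk
        kc, vc = kv_chunks[0]
        for i, qi in enumerate(q):
            m, l, acc = state[i]
            scores = [qi * kj for kj in kc]
            new_m = max(m, max(scores))
            scale = math.exp(m - new_m)            # rescale old partials
            w = [math.exp(s - new_m) for s in scores]
            l = l * scale + sum(w)
            acc = acc * scale + sum(wi * vi for wi, vi in zip(w, vc))
            state[i] = (new_m, l, acc)
        kv_chunks = kv_chunks[1:] + kv_chunks[:1]  # rotate KV around the ring
    return [acc / l for (_, l, acc) in state]

q = [0.5, -1.0, 2.0]
k = [0.1, 0.7, -0.3, 1.2]
v = [1.0, 2.0, 3.0, 4.0]
ref = attention_direct(q, k, v)
ring = attention_ring(q, k, v, cp_size=2)
assert all(abs(a - b) < 1e-9 for a, b in zip(ref, ring))
```

The assertion at the end confirms that after cp_size rounds every query has attended to every key-value pair, with the incremental softmax matching the monolithic one.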

The communication cost of CP is O(cp_size * seq_len/cp_size * d_model) = O(seq_len * d_model) per attention layer, which is the same total data volume as a single all-gather of the KV tensors, but pipelined across rounds so that communication can overlap with computation.

The memory benefit is direct: attention activations scale as O((seq_len/cp_size)^2) per device instead of O(seq_len^2), a reduction of cp_size^2 in the quadratic attention memory.
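The cp_size^2 reduction can be checked with quick arithmetic (illustrative sizes; this counts attention-score-matrix elements only, ignoring other activations):

```python
# Quadratic attention-score memory per device, before and after CP.
seq_len, cp_size = 32768, 8
full_scores  = seq_len ** 2                    # single device: seq_len^2
local_scores = (seq_len // cp_size) ** 2       # per CP rank: (seq_len/cp)^2
print(full_scores // local_scores)             # 64 == cp_size ** 2
```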

CP requires a compatible SDPA (Scaled Dot Product Attention) backend. In the 3D parallel example, SDPBackend.FLASH_ATTENTION is used for efficiency.
