

Principle:Alibaba ROLL Context Parallel Communication

From Leeroopedia


Knowledge Sources
Domains Distributed_Computing, Context_Parallelism
Last Updated 2026-02-07 20:00 GMT

Overview

A differentiable communication primitive that reassembles a full sequence from context-parallel shards with correct ordering, enabling gradient flow through the gather operation.

Description

Context parallelism distributes the sequence dimension across multiple GPUs, allowing each GPU to process only a fraction of the full sequence during attention computation. However, certain operations (such as computing log-probabilities over the full vocabulary for policy gradient methods) require access to the complete sequence on each rank.

The challenge is that simply splitting a sequence into C contiguous chunks and assigning each to a different GPU creates severe load imbalance in causal attention: early chunks attend to few tokens while later chunks attend to many. The standard solution is to split the sequence into 2C chunks and assign each GPU a pair of chunks from opposite ends of the sequence, balancing the attention workload.
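The balancing effect is easy to verify with a small worked example. The sketch below is a standalone illustration using plain Python lists in place of tensors (the names `workload`, `per_gpu`, and `naive` are ours, not from ROLL); it counts, for S = 8 tokens and C = 2 GPUs, how many key tokens each GPU's queries attend to under a causal mask.

```python
# Causal-attention workload per GPU under the zigzag split, S = 8, C = 2.
S, C = 8, 2
size = S // (2 * C)  # each of the 2C = 4 chunks holds 2 tokens
chunks = [list(range(i * size, (i + 1) * size)) for i in range(2 * C)]

def workload(tokens):
    # Under a causal mask, token i attends to i + 1 tokens (positions 0..i).
    return sum(t + 1 for t in tokens)

# GPU k holds chunks[k] and chunks[2C - k - 1] (one chunk from each end).
per_gpu = [workload(chunks[k] + chunks[2 * C - k - 1]) for k in range(C)]
print(per_gpu)  # [18, 18] -- both GPUs attend to the same number of tokens

# A naive contiguous split (4 consecutive tokens per GPU) is badly skewed:
naive = [workload(list(range(k * 4, (k + 1) * 4))) for k in range(C)]
print(naive)  # [10, 26]
```

With the zigzag pairing, the cheap early chunk and the expensive late chunk on each GPU sum to the same total, while the contiguous split leaves the last GPU with more than twice the work of the first.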

This means the data on each GPU is not contiguous in the original sequence ordering. Reassembling the full sequence requires:

  1. All-gather: Collect all local chunks from every GPU in the context-parallel group.
  2. Reorder: Each gathered tensor contains two sub-chunks. The first sub-chunks are ordered by rank (0, 1, ..., C-1) and the second sub-chunks are ordered in reverse (C-1, ..., 1, 0). These must be concatenated in the correct interleaved order.
  3. Backward pass: The gradient of the gather operation is a scatter that extracts the appropriate sub-chunks for each rank, preserving the paired ordering.

This operation is implemented as a custom autograd function to ensure correct gradient propagation through the communication boundary.

Usage

Use this principle when:

  • You need to compute a function over the full sequence (e.g., vocabulary log-probabilities) but the sequence is distributed across context-parallel ranks using the zigzag partition scheme.
  • The gather operation must be differentiable so that gradients from the full-sequence computation can flow back to the individual context-parallel shards.

Theoretical Basis

Context-parallel zigzag partitioning:

For a sequence of length S split across C GPUs:

chunks = split(sequence, 2*C)  # 2C chunks of size S/(2C)
GPU_k receives: [chunks[k], chunks[2*C - k - 1]]

Forward (gather):

gathered = all_gather(local_tensor, group=cp_group)  # C tensors
FOR r, t in enumerate(gathered):
    first_half[r], second_half[r] = split(t, 2, dim=parallel_dim)
ordered = [first_half[0], ..., first_half[C-1],
           second_half[C-1], ..., second_half[0]]
result = concatenate(ordered, dim=parallel_dim)

This produces the full sequence in the original order.
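The reorder step can be sketched concretely (a standalone illustration where a list of plain Python lists stands in for the output of `all_gather`; `gather_reorder` is our name for the helper):

```python
def gather_reorder(gathered, C):
    """Reassemble the full sequence from all-gathered zigzag shards.

    `gathered[k]` simulates rank k's local shard: chunks k and 2C-k-1
    concatenated along the sequence dimension.
    """
    half = len(gathered[0]) // 2
    first = [t[:half] for t in gathered]   # chunks 0 .. C-1, in rank order
    second = [t[half:] for t in gathered]  # chunks 2C-1 .. C, so reversed
    ordered = first + second[::-1]
    return [x for part in ordered for x in part]

# Zigzag shards for S = 8, C = 2: rank 0 holds chunks 0 and 3, rank 1 holds 1 and 2.
gathered = [[0, 1, 6, 7], [2, 3, 4, 5]]
print(gather_reorder(gathered, C=2))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

The reversal of the second halves is the key step: rank 0's second chunk is the last chunk of the sequence, so the second halves must be concatenated in reverse rank order.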

Backward (scatter):

Given the gradient grad_output over the full sequence:

grad_chunks = split(grad_output, 2*C, dim=parallel_dim)
grad_local = concatenate([grad_chunks[rank], grad_chunks[2*C - rank - 1]],
                         dim=parallel_dim)

Each rank extracts its two paired chunks from the gradient, which is the adjoint of the gather operation.
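The backward step can be sketched the same way (a standalone illustration on plain Python lists; `scatter_grad` is our name for the helper):

```python
def scatter_grad(grad_output, rank, C):
    """Adjoint of the gather: rank k keeps gradient chunks k and 2C-k-1."""
    n = 2 * C
    size = len(grad_output) // n
    chunks = [grad_output[i * size:(i + 1) * size] for i in range(n)]
    return chunks[rank] + chunks[n - rank - 1]

grad = list(range(8))  # stand-in for a full-sequence gradient, S = 8
print([scatter_grad(grad, r, C=2) for r in range(2)])
# [[0, 1, 6, 7], [2, 3, 4, 5]] -- each rank recovers exactly its own shard
```

Each rank's output has the same shape and chunk pairing as its local input, which is exactly what the autograd engine requires of the backward of the gather.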

Correctness property:

The gather-scatter pair satisfies:

scatter(gather(x_k)) = x_k   for all k in {0, ..., C-1}

ensuring that the backward pass correctly routes gradients to the rank that produced each portion of the output.
