
Implementation:Alibaba ROLL ContextParallelGather

From Leeroopedia


Knowledge Sources
Domains Distributed_Computing, Context_Parallelism
Last Updated 2026-02-07 20:00 GMT

Overview

A differentiable gather operation for context parallelism that reassembles distributed sequence chunks from all ranks in the context parallel group.

Description

This module implements a custom PyTorch autograd function (_ContextParallelGather) and its public wrapper (context_parallel_gather) for reassembling sequence data that has been distributed across context parallel ranks. Context parallelism splits long sequences across multiple GPUs; this gather operation reconstructs the full sequence from all shards.

Forward pass: The function performs an all_gather across the context parallel group to collect all sequence chunks. It then reorders the gathered tensors using a specific interleaving pattern: the first halves of each rank's chunk are concatenated in rank order, followed by the second halves in reverse rank order. This reordering scheme is consistent with Megatron-Core's context parallelism chunking strategy, which splits sequences into alternating halves for causal attention compatibility.
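The reorder step can be sketched in plain PyTorch without any collective communication. This is a single-process illustration, not the module's actual code; `reorder_gathered_chunks` is a hypothetical helper, and it assumes (per the Megatron-Core chunking scheme) that rank r's local shard is `[chunk_r | chunk_{2W-1-r}]` of a sequence split into 2W chunks:

```python
import torch

def reorder_gathered_chunks(gathered: list[torch.Tensor], parallel_dim: int = -1) -> torch.Tensor:
    """Reassemble a full sequence from per-rank shards (hypothetical sketch).

    Assumes rank r's shard holds chunks r and 2W-1-r of the original sequence,
    where W = len(gathered) is the context parallel world size.
    """
    firsts, seconds = [], []
    for shard in gathered:
        # Split each rank's shard back into its two Megatron-style half-chunks.
        first, second = shard.chunk(2, dim=parallel_dim)
        firsts.append(first)
        seconds.append(second)
    # First halves in rank order, then second halves in reverse rank order.
    return torch.cat(firsts + seconds[::-1], dim=parallel_dim)

# Toy check with W = 2: the full sequence [0..7] is split into 4 chunks;
# rank 0 holds chunks 0 and 3, rank 1 holds chunks 1 and 2.
full = torch.arange(8.0)
c = list(full.chunk(4))
gathered = [torch.cat([c[0], c[3]]), torch.cat([c[1], c[2]])]
print(reorder_gathered_chunks(gathered))  # tensor([0., 1., 2., 3., 4., 5., 6., 7.])
```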

Backward pass: Since the forward gather produces identical outputs across all ranks (given context-parallel-consistent loss computation), the backward pass simply extracts the gradient slice corresponding to the current rank. It splits the incoming gradient into 2 × world_size chunks along the parallel dimension, takes the chunk at the rank's own index and the chunk at the mirrored position (2 × world_size − 1 − rank), and concatenates them to reconstruct the local gradient.

Usage

Use context_parallel_gather when you need to reconstruct a full sequence tensor from context-parallel shards, typically before operations that require the full sequence (e.g., log probability computation, reward calculation). The backward pass requires that the subsequent loss computation produces identical gradients across all context parallel ranks.

Code Reference

Source Location

Signature

class _ContextParallelGather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        context_parallel_input: torch.Tensor,
        parallel_dim: int = -1,
    ) -> torch.Tensor: ...

    @staticmethod
    def backward(
        ctx,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None]: ...

def context_parallel_gather(
    context_parallel_input: torch.Tensor,
    parallel_dim: int = -1,
) -> torch.Tensor: ...

Import

from mcore_adapter.parallel_functions.context_parallel import context_parallel_gather

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| context_parallel_input | torch.Tensor | Yes | The local sequence shard held by this rank in the context parallel group |
| parallel_dim | int | No | The dimension along which the sequence is distributed (default -1, i.e., the last dimension) |

Outputs

| Name | Type | Description |
|---|---|---|
| gathered_tensor | torch.Tensor | The full sequence tensor reconstructed from all context parallel ranks, with size multiplied by world_size along parallel_dim |

Usage Examples

from mcore_adapter.parallel_functions.context_parallel import context_parallel_gather

# Each rank holds a shard of the sequence along the sequence dimension (dim 1)
# e.g., local_logits shape: [batch_size, seq_len // cp_world_size, vocab_size]
local_logits = model(input_ids_shard)

# Gather full sequence logits across context parallel group
# Result shape: [batch_size, seq_len, vocab_size]
full_logits = context_parallel_gather(local_logits, parallel_dim=1)

# Now compute loss over the full sequence
loss = compute_loss(full_logits, labels)
