Implementation: Alibaba ROLL ContextParallelGather
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Context_Parallelism |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
A differentiable gather operation for context parallelism that reassembles distributed sequence chunks from all ranks in the context parallel group.
Description
This module implements a custom PyTorch autograd function (_ContextParallelGather) and its public wrapper (context_parallel_gather) for reassembling sequence data that has been distributed across context parallel ranks. Context parallelism splits long sequences across multiple GPUs; this gather operation reconstructs the full sequence from all shards.
Forward pass: The function performs an all_gather across the context parallel group to collect all sequence chunks. It then reorders the gathered tensors using a specific interleaving pattern: the first halves of each rank's chunk are concatenated in rank order, followed by the second halves in reverse rank order. This reordering scheme is consistent with Megatron-Core's context parallelism chunking strategy, which splits sequences into alternating halves for causal attention compatibility.
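The reordering after the all-gather can be illustrated with a single-process sketch (no torch.distributed; reorder_gathered and the toy sharding below are illustrative names, not ROLL's code). Each rank's local shard is assumed to hold chunks r and 2*cp - 1 - r of the Megatron-style 2*cp split, concatenated along the parallel dimension:

```python
import torch

def reorder_gathered(chunks, dim=-1):
    """chunks[r] is rank r's local shard; returns the full sequence."""
    halves = [c.chunk(2, dim=dim) for c in chunks]
    first = [h[0] for h in halves]             # first halves, rank order
    second = [h[1] for h in reversed(halves)]  # second halves, reverse rank order
    return torch.cat(first + second, dim=dim)

# Toy check with cp_world_size = 2 and a sequence 0..7 along dim 0.
full = torch.arange(8)
pieces = full.chunk(4)  # Megatron-style split into 2 * cp chunks
shards = [torch.cat([pieces[0], pieces[3]]),  # rank 0 holds chunks 0 and 3
          torch.cat([pieces[1], pieces[2]])]  # rank 1 holds chunks 1 and 2
restored = reorder_gathered(shards, dim=0)
# restored.tolist() -> [0, 1, 2, 3, 4, 5, 6, 7]
```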
Backward pass: Since the forward gather produces identical outputs across all ranks (given context-parallel-consistent loss computation), the backward pass simply extracts the gradient slice corresponding to the current rank. It splits the incoming gradient into 2 * cp_world_size chunks along the parallel dimension, takes the chunk at index rank and the chunk at index 2 * cp_world_size - 1 - rank, and concatenates them to reconstruct the local gradient.
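A minimal sketch of that gradient slicing, again as a single-process simulation (local_grad is an illustrative name, not ROLL's API):

```python
import torch

def local_grad(grad_full, rank, cp_world_size, dim=-1):
    """Extract the gradient slice matching rank's forward shard layout."""
    chunks = grad_full.chunk(2 * cp_world_size, dim=dim)
    # Keep chunk `rank` and its mirror chunk `2*cp - 1 - rank`.
    return torch.cat([chunks[rank], chunks[2 * cp_world_size - 1 - rank]], dim=dim)

g = torch.arange(8.0)
# With cp_world_size = 2, rank 0 recovers chunks 0 and 3:
# local_grad(g, 0, 2, dim=0).tolist() -> [0.0, 1.0, 6.0, 7.0]
# and rank 1 recovers chunks 1 and 2:
# local_grad(g, 1, 2, dim=0).tolist() -> [2.0, 3.0, 4.0, 5.0]
```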
Usage
Use context_parallel_gather when you need to reconstruct a full sequence tensor from context-parallel shards, typically before operations that require the full sequence (e.g., log probability computation, reward calculation). The backward pass requires that the subsequent loss computation produces identical gradients across all context parallel ranks.
Code Reference
Source Location
- Repository: Alibaba_ROLL
- File: mcore_adapter/src/mcore_adapter/parallel_functions/context_parallel.py
- Lines: 1-35
Signature
class _ContextParallelGather(torch.autograd.Function):
@staticmethod
def forward(
ctx,
context_parallel_input: torch.Tensor,
parallel_dim: int = -1,
) -> torch.Tensor: ...
@staticmethod
def backward(
ctx,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, None]: ...
def context_parallel_gather(
context_parallel_input: torch.Tensor,
parallel_dim: int = -1,
) -> torch.Tensor: ...
Import
from mcore_adapter.parallel_functions.context_parallel import context_parallel_gather
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| context_parallel_input | torch.Tensor | Yes | The local sequence shard tensor from this rank in the context parallel group |
| parallel_dim | int | No | The dimension along which the sequence is distributed (default -1, i.e., last dimension) |
Outputs
| Name | Type | Description |
|---|---|---|
| gathered_tensor | torch.Tensor | The full sequence tensor reconstructed from all context parallel ranks, with size multiplied by the context parallel world size along parallel_dim |
Usage Examples
from mcore_adapter.parallel_functions.context_parallel import context_parallel_gather
# Each rank holds a shard of the sequence along the sequence dimension (dim 1)
# e.g., local_logits shape: [batch_size, seq_len // cp_world_size, vocab_size]
local_logits = model(input_ids_shard)
# Gather full sequence logits across context parallel group
# Result shape: [batch_size, seq_len, vocab_size]
full_logits = context_parallel_gather(local_logits, parallel_dim=1)
# Now compute loss over the full sequence
loss = compute_loss(full_logits, labels)
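The shard/gather round trip in the example can be sanity-checked in a single process; shard_like_megatron and gather_like below are illustrative stand-ins for Megatron's 2*cp split and the gather operation, not ROLL APIs:

```python
import torch

def shard_like_megatron(full, rank, cp, dim):
    """Give rank its two mirrored chunks of the 2*cp split."""
    c = full.chunk(2 * cp, dim=dim)
    return torch.cat([c[rank], c[2 * cp - 1 - rank]], dim=dim)

def gather_like(shards, dim):
    """Reassemble: first halves in rank order, second halves reversed."""
    halves = [s.chunk(2, dim=dim) for s in shards]
    return torch.cat([h[0] for h in halves]
                     + [h[1] for h in reversed(halves)], dim=dim)

cp = 2
logits = torch.randn(3, 8, 16)  # [batch_size, seq_len, vocab_size]
shards = [shard_like_megatron(logits, r, cp, dim=1) for r in range(cp)]
assert shards[0].shape == (3, 4, 16)   # seq_len // cp_world_size along dim 1
restored = gather_like(shards, dim=1)
assert torch.equal(restored, logits)   # full sequence order is recovered
```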