
Implementation:Huggingface Transformers Context Parallel Training Loop

From Leeroopedia
Domains Distributed_Computing, Training, Attention
Last Updated 2026-02-13 00:00 GMT

Overview

A concrete pattern for executing a training step with context parallelism (CP), which shards the sequence dimension across CP ranks, as used in the Hugging Face Transformers 3D parallel training example.

Description

This pattern implements the forward-backward training step within a context-parallel context manager. The key steps are:

  1. Position ID construction: Before entering the CP context, position_ids are constructed for the full sequence and broadcast to all batch elements.
  2. CP context activation: The context_parallel context manager is entered with the CP mesh, the input buffers (input_ids, labels, position_ids), and the sequence dimensions for each buffer. This shards each buffer along its sequence dimension across CP ranks.
  3. Forward pass: Labels are popped from the batch before the model forward pass. The model receives sharded input_ids and position_ids (each of shape [batch_size, seq_len/cp_size]). The attention mechanism internally performs ring communication to compute full attention.
  4. Loss computation: The loss is computed using the model's loss_function with the sharded logits and labels.
  5. Backward pass: loss.backward() computes gradients through the CP context, which handles the reverse ring communication for gradient flow.

The context parallel context is conditionally applied: when cp_size == 1, a nullcontext() is used, adding zero overhead. The SDPA backend is set to FLASH_ATTENTION for compatibility with the ring attention pattern.
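The sequence sharding in step 2 can be illustrated with a minimal pure-Python sketch. Plain lists stand in for tensors, and the contiguous-chunk layout shown here is an illustrative assumption about how a buffer's sequence dimension is divided, not necessarily the exact schedule context_parallel uses internally:

```python
# Illustrative sketch: shard a sequence of token positions across CP ranks.
# Assumes seq_len divides evenly by cp_size (hence the fixed-length
# packing requirement discussed in the Usage section).

def shard_sequence(tokens, cp_size):
    """Split `tokens` into cp_size contiguous chunks, one per CP rank."""
    seq_len = len(tokens)
    assert seq_len % cp_size == 0, "sequence must partition evenly across CP ranks"
    chunk = seq_len // cp_size
    return [tokens[r * chunk:(r + 1) * chunk] for r in range(cp_size)]

position_ids = list(range(8))               # full-sequence position ids
shards = shard_sequence(position_ids, cp_size=4)
print(shards)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Each rank then runs the forward pass on its own shard, while the attention layer exchanges key/value shards with the other ranks via ring communication.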

Usage

Use this pattern for every training step when context parallelism is enabled (cp_size > 1). It must be combined with fixed-length packed sequences to ensure even partitioning across CP ranks. The pattern wraps both the forward and backward passes, so gradient synchronization (all-reduce) should be performed after exiting the CP context.
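The even-partitioning requirement can be sketched in a few lines: if sequences are packed to a fixed length, that length must be a multiple of cp_size. The helper below is hypothetical (not part of the example script) and just shows the rounding arithmetic:

```python
def packed_length(target_len, cp_size):
    """Round target_len up to the nearest multiple of cp_size so that
    every CP rank receives an equally sized sequence shard."""
    return ((target_len + cp_size - 1) // cp_size) * cp_size

print(packed_length(1000, 4))  # 1000 (already divisible)
print(packed_length(1001, 4))  # 1004 (padded up to the next multiple)
```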

Code Reference

Source Location

  • Repository: transformers
  • File: examples/3D_parallel.py
  • Lines: 270-294

Signature

context_parallel(
    cp_mesh,
    buffers=[batch["input_ids"], batch["labels"], batch["position_ids"]],
    buffer_seq_dims=[1, 1, 1],
)

Import

from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from contextlib import nullcontext

I/O Contract

Inputs

  • cp_mesh (DeviceMesh, required) — The CP sub-mesh extracted from the world mesh via world_mesh["cp"].
  • buffers (list[torch.Tensor], required) — Tensors to shard along the sequence dimension: [input_ids, labels, position_ids].
  • buffer_seq_dims (list[int], required) — The sequence-dimension index for each buffer. All are 1 because each tensor has shape (batch, seq_len).

Outputs

  • loss (torch.Tensor) — Scalar loss computed on the local (sharded) sequence portion.
  • logits (torch.Tensor) — Model output logits of shape (batch_size, seq_len / cp_size, vocab_size).
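The output shapes follow directly from the shard arithmetic: only the sequence dimension is divided by cp_size. A quick pure-Python check (function name is illustrative):

```python
def local_logits_shape(batch_size, seq_len, cp_size, vocab_size):
    """Shape of the logits each CP rank produces: the sequence dimension
    is divided by cp_size; batch and vocab dimensions are unchanged."""
    assert seq_len % cp_size == 0, "sequence must partition evenly"
    return (batch_size, seq_len // cp_size, vocab_size)

print(local_logits_shape(2, 4096, 4, 32000))  # (2, 1024, 32000)
```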

Usage Examples

Basic Usage

import torch
from contextlib import nullcontext
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel

# Construct position_ids for the full sequence
batch_size = batch["input_ids"].shape[0]
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
batch["position_ids"] = position_ids

# Conditionally enable context parallelism
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    cp_context = (
        nullcontext()
        if cp_mesh.size() == 1
        else context_parallel(
            cp_mesh,
            buffers=[
                batch["input_ids"],
                batch["labels"],
                batch["position_ids"],
            ],
            buffer_seq_dims=[1, 1, 1],
        )
    )
    with cp_context:
        labels = batch.pop("labels")
        outputs = model(**batch)
        loss = model.loss_function(
            logits=outputs.logits, labels=None,
            shift_labels=labels, vocab_size=model.config.vocab_size,
        )
        loss.backward()
