Implementation: Hugging Face Transformers Context Parallel Training Loop
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training, Attention |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete pattern for executing a training step with context parallelism that shards the sequence dimension across CP ranks, as used in the Hugging Face Transformers 3D parallel training example.
Description
This pattern implements the forward-backward training step inside a context-parallel context manager. The key steps are:
- Position ID construction: before entering the CP context, position_ids are constructed for the full sequence and broadcast to all batch elements.
- CP context activation: the context_parallel context manager is entered with the CP mesh, the input buffers (input_ids, labels, position_ids), and the sequence dimension of each buffer. This shards each buffer along its sequence dimension across CP ranks.
- Forward pass: labels are popped from the batch before the model forward pass. The model receives sharded input_ids and position_ids (each of shape [batch_size, seq_len/cp_size]). The attention mechanism internally performs ring communication to compute full attention.
- Loss computation: the loss is computed with the model's loss_function on the sharded logits and labels.
- Backward pass: loss.backward() computes gradients through the CP context, which handles the reverse ring communication for gradient flow.
The context parallel context is conditionally applied: when cp_size == 1, a nullcontext() is used, adding zero overhead. The SDPA backend is set to FLASH_ATTENTION for compatibility with the ring attention pattern.
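The effect of this sharding can be illustrated outside any distributed setup with a plain torch.chunk along the sequence dimension. This is only a conceptual sketch with made-up sizes: the real context_parallel may use a load-balanced chunk layout for causal attention rather than contiguous chunks.

```python
import torch

# Hypothetical sizes for illustration only
batch_size, seq_len, cp_size = 2, 1024, 4

input_ids = torch.randint(0, 50_000, (batch_size, seq_len))

# The CP context shards each buffer along its sequence dimension (dim 1 here),
# so conceptually CP rank r sees one contiguous chunk of the sequence.
shards = torch.chunk(input_ids, cp_size, dim=1)

# Each rank's local view has shape [batch_size, seq_len / cp_size]
assert len(shards) == cp_size
assert all(s.shape == (batch_size, seq_len // cp_size) for s in shards)
```

Inside the real CP context the model only ever sees its local shard; the ring communication in attention is what reassembles full-sequence context.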
Usage
Use this pattern for every training step when context parallelism is enabled (cp_size > 1). It must be combined with fixed-length packed sequences to ensure even partitioning across CP ranks. The pattern wraps both the forward and backward passes, so gradient synchronization (all-reduce) should be performed after exiting the CP context.
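The even-partitioning requirement can be enforced with a small padding helper applied before the CP context is entered. pad_to_multiple and the pad id below are illustrative names, not part of the example script:

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(input_ids: torch.Tensor, multiple: int, pad_id: int = 0) -> torch.Tensor:
    """Right-pad (batch, seq_len) token ids so that seq_len % multiple == 0."""
    seq_len = input_ids.shape[1]
    remainder = seq_len % multiple
    if remainder == 0:
        return input_ids
    # F.pad with a (left, right) tuple pads the last dimension only
    return F.pad(input_ids, (0, multiple - remainder), value=pad_id)

ids = torch.ones(2, 1003, dtype=torch.long)
padded = pad_to_multiple(ids, multiple=8)  # e.g. cp_size = 8
assert padded.shape == (2, 1008)
```

In practice the same padding must be applied consistently to labels and position_ids so all three buffers keep matching sequence lengths.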
Code Reference
Source Location
- Repository: transformers
- File:
examples/3D_parallel.py - Lines: 270-294
Signature
context_parallel(
cp_mesh,
buffers=[batch["input_ids"], batch["labels"], batch["position_ids"]],
buffer_seq_dims=[1, 1, 1],
)
Import
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from contextlib import nullcontext
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cp_mesh | DeviceMesh | Yes | The CP sub-mesh extracted from the world mesh via world_mesh["cp"]. |
| buffers | list[torch.Tensor] | Yes | List of tensors to shard along the sequence dimension: [input_ids, labels, position_ids]. |
| buffer_seq_dims | list[int] | Yes | The sequence dimension index for each buffer. All are 1 since the tensors have shape (batch, seq_len). |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Scalar loss computed on the local (sharded) sequence portion. |
| logits | torch.Tensor | Model output logits of shape (batch_size, seq_len/cp_size, vocab_size). |
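For orientation, the local loss corresponds to a cross-entropy over the sharded logits against pre-shifted labels. This minimal sketch uses plain F.cross_entropy instead of the model's loss_function, with made-up sizes:

```python
import torch
import torch.nn.functional as F

batch_size, local_seq, vocab_size = 2, 8, 32  # local_seq = seq_len / cp_size

logits = torch.randn(batch_size, local_seq, vocab_size)
# shift_labels are assumed pre-shifted so that logits[:, t] predicts shift_labels[:, t]
shift_labels = torch.randint(0, vocab_size, (batch_size, local_seq))

# Flatten (batch, local_seq) into a single token axis for cross-entropy
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    shift_labels.reshape(-1),
)
assert loss.ndim == 0  # scalar loss over the local shard only
```

The actual example script pre-shifts labels before sharding, so each CP rank's labels stay aligned with its logits shard.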
Usage Examples
Basic Usage
import torch
from contextlib import nullcontext
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel

# Construct position_ids for the full sequence and broadcast to every batch element
batch_size = batch["input_ids"].shape[0]
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
batch["position_ids"] = position_ids

# Force the FLASH_ATTENTION SDPA backend for compatibility with ring attention
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    # Conditionally enable context parallelism: nullcontext() when cp_size == 1
    cp_context = (
        nullcontext()
        if cp_mesh.size() == 1
        else context_parallel(
            cp_mesh,
            buffers=[
                batch["input_ids"],
                batch["labels"],
                batch["position_ids"],
            ],
            buffer_seq_dims=[1, 1, 1],  # each buffer has shape (batch, seq_len)
        )
    )
    with cp_context:
        labels = batch.pop("labels")  # pop labels before the forward pass
        outputs = model(**batch)
        loss = model.loss_function(
            logits=outputs.logits, labels=None,
            shift_labels=labels, vocab_size=model.config.vocab_size,
        )
        loss.backward()