Implementation: Hugging Face Transformers DeviceMesh Construction
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
A concrete pattern for constructing PyTorch's multi-dimensional DeviceMesh with named parallelism axes.
Description
This wrapper constructs a torch.distributed.device_mesh.DeviceMesh by reshaping a flat range of GPU ranks into a 3D tensor with dimensions ("dp", "tp", "cp"). The mesh is then sliced along each named dimension to produce sub-meshes for tensor parallelism, data parallelism, and context parallelism. A flattened "dp_cp" composite mesh is also created for combined gradient synchronization.
The construction involves:
- Creating a rank tensor: `torch.arange(world_size).reshape(dp_size, tp_size, cp_size)`.
- Wrapping it in a `DeviceMesh` with named dimensions.
- Extracting sub-meshes via indexing: `world_mesh["tp"]`, `world_mesh["dp"]`, `world_mesh["cp"]`.
- Flattening the DP and CP dimensions for joint gradient all-reduce.
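The reshape above lays ranks out in row-major order, so `dp` is the slowest-varying axis and `cp` the fastest. A torch-free sketch (the helper `rank_to_coords` is hypothetical, for illustration only) shows how a flat rank decomposes into its mesh coordinates:

```python
# Sketch: map a flat rank to its (dp, tp, cp) coordinates under
# reshape(dp_size, tp_size, cp_size). Row-major (C) order means
# dp varies slowest and cp varies fastest.
def rank_to_coords(rank, dp_size, tp_size, cp_size):
    assert 0 <= rank < dp_size * tp_size * cp_size
    dp, rem = divmod(rank, tp_size * cp_size)
    tp, cp = divmod(rem, cp_size)
    return dp, tp, cp

# With dp_size = tp_size = cp_size = 2 (world_size = 8),
# rank 5 sits in DP group 1, TP group 0, CP group 1.
print(rank_to_coords(5, 2, 2, 2))  # -> (1, 0, 1)
```

Ranks that agree on two coordinates and differ only in the third form the process group for that named dimension.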
Usage
Use this pattern immediately after dist.init_process_group() and before loading the model or constructing the data loader. The resulting sub-meshes are passed to downstream components:
- `tp_mesh` to `AutoModelForCausalLM.from_pretrained(device_mesh=tp_mesh)`
- `dp_mesh` to `FSDP(model, device_mesh=dp_mesh)`
- `cp_mesh` to `context_parallel(cp_mesh, ...)`
Code Reference
Source Location
- Repository: transformers
- File: examples/3D_parallel.py
- Lines: 101-107
Signature
DeviceMesh(device_type="cuda", mesh=tensor, mesh_dim_names=("dp", "tp", "cp"))
Import
from torch.distributed.device_mesh import DeviceMesh
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| device_type | str | Yes | Device type for the mesh, typically "cuda". |
| mesh | torch.Tensor | Yes | An integer tensor of rank IDs reshaped to (dp_size, tp_size, cp_size). |
| mesh_dim_names | tuple[str, ...] | Yes | Names for each mesh dimension, e.g. ("dp", "tp", "cp"). |
Outputs
| Name | Type | Description |
|---|---|---|
| world_mesh | DeviceMesh | The full multi-dimensional device mesh with named axes. |
| tp_mesh | DeviceMesh | 1D sub-mesh for tensor parallelism, obtained via world_mesh["tp"]. |
| dp_mesh | DeviceMesh | 1D sub-mesh for data parallelism, obtained via world_mesh["dp"]. |
| cp_mesh | DeviceMesh | 1D sub-mesh for context parallelism, obtained via world_mesh["cp"]. |
Usage Examples
Basic Usage
```python
# Requires a distributed launch (e.g. via torchrun) so that
# torch.distributed is initialized before the mesh is built.
import torch
from torch.distributed.device_mesh import DeviceMesh

tp_size = 2
dp_size = 2
cp_size = 2
world_size = tp_size * dp_size * cp_size  # 8

mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))

tp_mesh = world_mesh["tp"]
dp_mesh = world_mesh["dp"]
cp_mesh = world_mesh["cp"]

# Create flattened dp_cp mesh for gradient all-reduce
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
```
Single-Axis Parallelism
```python
# TP-only: 4 GPUs for tensor parallelism, no DP or CP
mesh = torch.arange(4).reshape(1, 4, 1)
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
tp_mesh = world_mesh["tp"]  # Contains all 4 ranks
```