
Implementation:Huggingface Transformers DeviceMesh Construction

From Leeroopedia
Knowledge Sources
Domains Distributed_Computing, Training
Last Updated 2026-02-13 00:00 GMT

Overview

Concrete tool for constructing a multi-dimensional PyTorch DeviceMesh with named parallelism axes.

Description

This wrapper constructs a torch.distributed.device_mesh.DeviceMesh by reshaping a flat range of GPU ranks into a 3D tensor with dimensions ("dp", "tp", "cp"). The mesh is then sliced along each named dimension to produce sub-meshes for tensor parallelism, data parallelism, and context parallelism. A flattened "dp_cp" composite mesh is also created for combined gradient synchronization.

The construction involves:

  1. Creating a rank tensor: torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
  2. Wrapping it in a DeviceMesh with named dimensions.
  3. Extracting sub-meshes via indexing: world_mesh["tp"], world_mesh["dp"], world_mesh["cp"].
  4. Flattening the DP and CP dimensions for joint gradient all-reduce.
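The reshape in step 1 fixes each rank's coordinates along the three axes. The mapping can be sketched in plain Python, with no distributed runtime needed, using the same row-major layout that torch.arange(world_size).reshape(...) produces (coords_of is an illustrative helper, not part of the PyTorch API):

```python
# Row-major rank layout, matching torch.arange(world_size).reshape(dp, tp, cp).
dp_size, tp_size, cp_size = 2, 2, 2
world_size = dp_size * tp_size * cp_size  # 8

def coords_of(rank: int) -> tuple[int, int, int]:
    """Return the (dp, tp, cp) indices of a global rank."""
    dp, rem = divmod(rank, tp_size * cp_size)
    tp, cp = divmod(rem, cp_size)
    return dp, tp, cp

# Rank 5 sits at dp=1, tp=0, cp=1 in the 2x2x2 mesh.
print(coords_of(5))  # -> (1, 0, 1)
```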

Usage

Use this pattern immediately after dist.init_process_group() and before loading the model or constructing the data loader. The resulting sub-meshes are passed to downstream components:

  • tp_mesh to AutoModelForCausalLM.from_pretrained(device_mesh=tp_mesh)
  • dp_mesh to FSDP(model, device_mesh=dp_mesh)
  • cp_mesh to context_parallel(cp_mesh, ...)
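All of the above assumes a distributed launch with one process per GPU. For the source script, a typical single-node invocation might look like the following (the flag values are illustrative, matching the 8-rank example below, not prescribed by the script):

```shell
# Launch 8 processes, one per GPU, so world_size == dp * tp * cp == 8.
torchrun --nproc_per_node=8 examples/3D_parallel.py
```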

Code Reference

Source Location

  • Repository: transformers
  • File: examples/3D_parallel.py
  • Lines: 101-107

Signature

DeviceMesh(device_type="cuda", mesh=tensor, mesh_dim_names=("dp", "tp", "cp"))

Import

from torch.distributed.device_mesh import DeviceMesh

I/O Contract

Inputs

Name | Type | Required | Description
---|---|---|---
device_type | str | Yes | Device type for the mesh, typically "cuda".
mesh | torch.Tensor | Yes | An integer tensor of rank IDs reshaped to (dp_size, tp_size, cp_size).
mesh_dim_names | tuple[str, ...] | Yes | Names for each mesh dimension, e.g. ("dp", "tp", "cp").
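One constraint the inputs imply: the mesh tensor must contain exactly world_size rank IDs, so dp_size * tp_size * cp_size has to equal the number of launched processes. A quick sanity check in plain Python (an illustrative helper, not part of the DeviceMesh API):

```python
# Validate that the requested parallelism degrees tile the world exactly.
def check_mesh_sizes(world_size: int, dp_size: int, tp_size: int, cp_size: int) -> None:
    product = dp_size * tp_size * cp_size
    if product != world_size:
        raise ValueError(
            f"dp*tp*cp = {product} does not match world_size = {world_size}"
        )

check_mesh_sizes(8, 2, 2, 2)  # OK: 2*2*2 == 8
```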

Outputs

Name | Type | Description
---|---|---
world_mesh | DeviceMesh | The full multi-dimensional device mesh with named axes.
tp_mesh | DeviceMesh | 1D sub-mesh for tensor parallelism, obtained via world_mesh["tp"].
dp_mesh | DeviceMesh | 1D sub-mesh for data parallelism, obtained via world_mesh["dp"].
cp_mesh | DeviceMesh | 1D sub-mesh for context parallelism, obtained via world_mesh["cp"].
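Indexing by one dimension name groups together the ranks that share coordinates on the remaining dimensions. Which global ranks land in a given rank's TP sub-mesh can be sketched in plain Python (tp_group is an illustrative helper, not the DeviceMesh API):

```python
# For the row-major (dp, tp, cp) layout, a rank's TP group holds every rank
# with the same dp and cp indices; only the tp index varies.
dp_size, tp_size, cp_size = 2, 2, 2

def tp_group(rank: int) -> list[int]:
    """Return the global ranks in this rank's tensor-parallel group."""
    dp, rem = divmod(rank, tp_size * cp_size)
    cp = rem % cp_size
    return [dp * tp_size * cp_size + t * cp_size + cp for t in range(tp_size)]

print(tp_group(0))  # -> [0, 2]: ranks 0 and 2 share dp=0, cp=0
```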

Usage Examples

Basic Usage

import torch
from torch.distributed.device_mesh import DeviceMesh

# Assumes torch.distributed.init_process_group() has already been called
# (e.g. under torchrun) with world_size == 8.
tp_size = 2
dp_size = 2
cp_size = 2
world_size = tp_size * dp_size * cp_size  # 8

mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))

tp_mesh = world_mesh["tp"]
dp_mesh = world_mesh["dp"]
cp_mesh = world_mesh["cp"]

# Create a flattened dp_cp mesh for joint gradient all-reduce.
# Note: _flatten() is a private DeviceMesh method; it registers the
# flattened mesh under the name "dp_cp" for later lookup.
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")

Single-Axis Parallelism

# TP-only: 4 GPUs for tensor parallelism; dp and cp are degenerate size-1 axes
mesh = torch.arange(4).reshape(1, 4, 1)
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
tp_mesh = world_mesh["tp"]  # Contains all 4 ranks

Related Pages

Implements Principle

Requires Environment
