
Implementation:Huggingface Transformers AutoModelForCausalLM From Pretrained For TP

From Leeroopedia
Knowledge Sources
Domains Distributed_Computing, Training, Model_Loading
Last Updated 2026-02-13 00:00 GMT

Overview

Concrete Hugging Face Transformers API for loading a pretrained causal language model with tensor-parallel weight sharding.

Description

AutoModelForCausalLM.from_pretrained is extended with tensor parallelism support through the device_mesh and tp_plan parameters. When tp_plan="auto" is specified, the method uses the model's built-in _tp_plan to determine how each layer's weights should be sharded across the TP device mesh.

Internally, the initialize_tensor_parallelism function in src/transformers/integrations/tensor_parallel.py handles the setup:

  1. If a multi-dimensional device mesh is provided, it extracts the "tp" sub-mesh.
  2. It determines the TP size from the mesh.
  3. It maps each process to its local CUDA device using the LOCAL_RANK environment variable.
  4. During weight loading, shard_and_distribute_module uses the TP plan to select the appropriate sharding strategy (ColwiseParallel, RowwiseParallel, PackedColwiseParallel, etc.) for each parameter.
  5. Forward hooks are registered via add_tensor_parallel_hooks_to_module to insert the required communication operations (all-reduce, all-gather) into the forward pass.

Usage

Use this API when loading any supported causal language model for tensor-parallel training. The model must define a _tp_plan (most decoder-only LLMs in Transformers do). Pass the TP sub-mesh extracted from the world device mesh.
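
Because the setup reads torchrun-style environment variables (RANK, LOCAL_RANK, WORLD_SIZE), the script must be started with a distributed launcher. A hypothetical launch command, assuming a single node with 8 GPUs and the example script path from the Code Reference section:

```shell
# Hypothetical launch: 8 processes, one per GPU; torchrun sets RANK,
# LOCAL_RANK, and WORLD_SIZE for each process before the script runs.
torchrun --nproc_per_node=8 examples/3D_parallel.py
```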

Code Reference

Source Location

  • Repository: transformers
  • File: examples/3D_parallel.py (lines 141-146, usage)
  • File: src/transformers/integrations/tensor_parallel.py (lines 40-109, initialization logic)

Signature

AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device_mesh=tp_mesh,
    tp_plan="auto",
    dtype=torch.bfloat16,
)

Import

from transformers import AutoModelForCausalLM

Key Internal Function

# src/transformers/integrations/tensor_parallel.py
def initialize_tensor_parallelism(
    tp_plan: str | dict[str, str] | None,
    tp_size: int | None = None,
    device_mesh=None,
    device_map=None,
):
    # returns (device_map, device_mesh, tp_size)
    ...

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| pretrained_model_name_or_path | str | Yes | Model identifier (Hub ID or local path), e.g. "HuggingFaceTB/SmolLM2-1.7B". |
| device_mesh | DeviceMesh | No | The TP sub-mesh. If multi-dimensional, the "tp" dimension is extracted automatically. |
| tp_plan | str or dict | No | Set to "auto" to use the model's built-in TP plan, or provide a custom dict mapping layer names to sharding styles. |
| dtype | torch.dtype | No | Data type for model weights, e.g. torch.bfloat16. |
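
When tp_plan is a dict rather than "auto", its keys are layer-name patterns and its values are sharding-style names. The entries below are a hedged sketch based on a common decoder layout; consult your model's _tp_plan attribute for the authoritative patterns and verify the style names against your Transformers version:

```python
# Hypothetical custom tp_plan: keys are glob-style layer-name patterns,
# values are sharding-style names ("colwise" shards output features,
# "rowwise" shards input features). Entries are illustrative, not taken
# from any specific model's _tp_plan.
custom_tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    "model.layers.*.mlp.gate_proj": "colwise",
    "model.layers.*.mlp.up_proj": "colwise",
    "model.layers.*.mlp.down_proj": "rowwise",
}
```

The colwise/rowwise pairing is deliberate: column-wise projections feed row-wise ones so that a single all-reduce after the row-wise layer restores the full activation.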

Outputs

| Name | Type | Description |
|------|------|-------------|
| model | PreTrainedModel | The loaded model with weights sharded across TP ranks and communication hooks registered. |

Usage Examples

Basic Usage

import torch
from transformers import AutoModelForCausalLM

# tp_mesh is the TP sub-mesh from the world DeviceMesh
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-1.7B",
    device_mesh=tp_mesh,
    tp_plan="auto",
    dtype=torch.bfloat16,
)

With Full 3D Mesh

import torch
from torch.distributed.device_mesh import DeviceMesh
from transformers import AutoModelForCausalLM

# Construct the full mesh
world_mesh = DeviceMesh(
    device_type="cuda",
    mesh=torch.arange(8).reshape(2, 2, 2),
    mesh_dim_names=("dp", "tp", "cp"),
)
tp_mesh = world_mesh["tp"]

# Load model with TP sharding -- only the TP shard is loaded per rank
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-1.7B",
    device_mesh=tp_mesh,
    tp_plan="auto",
    dtype=torch.bfloat16,
)
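
As a back-of-the-envelope check of the memory benefit, column-wise sharding leaves each rank with 1/tp_size of a Linear layer's output rows. A minimal arithmetic sketch (no distributed runtime required; the 2048 x 2048 projection size is an arbitrary example, not a claim about any specific model):

```python
# Per-rank shard shape for a column-wise sharded Linear weight of shape
# (out_features, in_features). Assumes out_features divides evenly by
# tp_size, as required for even sharding.
def colwise_shard_shape(out_features: int, in_features: int,
                        tp_size: int) -> tuple:
    assert out_features % tp_size == 0, "out_features must divide evenly"
    return (out_features // tp_size, in_features)

# Example: a 2048 x 2048 projection on a 2-way TP mesh (matching the
# tp dimension of the world mesh above).
print(colwise_shard_shape(2048, 2048, tp_size=2))  # (1024, 2048)
```

Each rank therefore materializes only its own shard at load time, which is what makes loading models larger than a single GPU's memory feasible.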

Related Pages

Implements Principle

Requires Environment
