Implementation: Hugging Face Transformers AutoModelForCausalLM.from_pretrained for TP
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training, Model_Loading |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete API, provided by Hugging Face Transformers, for loading a pretrained causal language model with tensor-parallel weight sharding.
Description
AutoModelForCausalLM.from_pretrained is extended with tensor parallelism support through the device_mesh and tp_plan parameters. When tp_plan="auto" is specified, the method uses the model's built-in _tp_plan to determine how each layer's weights should be sharded across the TP device mesh.
Internally, the initialize_tensor_parallelism function in src/transformers/integrations/tensor_parallel.py handles the setup:
- If a multi-dimensional device mesh is provided, it extracts the "tp" sub-mesh.
- It determines the TP size from the mesh.
- It maps each process to the correct CUDA device using LOCAL_RANK.
- During weight loading, shard_and_distribute_module uses the TP plan to select the appropriate sharding strategy (ColwiseParallel, RowwiseParallel, PackedColwiseParallel, etc.) for each parameter.
- Forward hooks are registered via add_tensor_parallel_hooks_to_module to insert the required communication operations (all-reduce, all-gather) during training.
Usage
Use this API when loading any supported causal language model for tensor-parallel training. The model must define a _tp_plan (most decoder-only LLMs in Transformers do). Pass the TP sub-mesh extracted from the world device mesh.
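Since tp_plan="auto" relies on the model defining a _tp_plan, a pre-flight check can help fail fast. The helper below is a hedged illustration, not a transformers API; DummyModel and supports_auto_tp are hypothetical names introduced here.

```python
# Hypothetical pre-flight check -- not part of the transformers API.
# Verifies that a model class defines a non-empty _tp_plan before
# requesting tp_plan="auto".

class DummyModel:
    # Stand-in for a decoder-only LLM class that ships a TP plan
    _tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}

def supports_auto_tp(model_cls) -> bool:
    """True if the class carries a built-in TP plan."""
    return bool(getattr(model_cls, "_tp_plan", None))

print(supports_auto_tp(DummyModel))  # True
print(supports_auto_tp(object))      # False
```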
Code Reference
Source Location
- Repository: transformers
- File: examples/3D_parallel.py (lines 141-146, usage)
- File: src/transformers/integrations/tensor_parallel.py (lines 40-109, initialization logic)
Signature
AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
device_mesh=tp_mesh,
tp_plan="auto",
dtype=torch.bfloat16,
)
Import
from transformers import AutoModelForCausalLM
Key Internal Function
# src/transformers/integrations/tensor_parallel.py
def initialize_tensor_parallelism(
tp_plan: str | dict[str, str] | None,
tp_size: int | None = None,
device_mesh=None,
device_map=None,
) -> tuple[device_map, device_mesh, tp_size]:
...
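One step of this initialization, mapping each process to its CUDA device via LOCAL_RANK, can be sketched in isolation. This is a simplified illustration under the assumption that LOCAL_RANK is set by the launcher (e.g. torchrun); local_device is a hypothetical helper, not the transformers code.

```python
# Hedged sketch of per-rank device selection -- not the actual
# initialize_tensor_parallelism implementation. Each process reads
# LOCAL_RANK (set by the launcher) and binds to that CUDA device index.
import os

def local_device(device_type: str = "cuda") -> str:
    """Build the per-process device string from LOCAL_RANK."""
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return f"{device_type}:{local_rank}"

# Simulate what torchrun would export for the second process on a node
os.environ["LOCAL_RANK"] = "1"
print(local_device())  # cuda:1
```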
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| pretrained_model_name_or_path | str | Yes | Model identifier (Hub ID or local path), e.g. "HuggingFaceTB/SmolLM2-1.7B". |
| device_mesh | DeviceMesh | No | The TP sub-mesh. If multi-dimensional, the "tp" dimension is extracted automatically. |
| tp_plan | str or dict | No | Set to "auto" to use the model's built-in TP plan, or provide a custom dict mapping layer names to sharding styles. |
| dtype | torch.dtype | No | Data type for model weights, e.g. torch.bfloat16. |
Outputs
| Name | Type | Description |
|---|---|---|
| model | PreTrainedModel | The loaded model with weights sharded across TP ranks and communication hooks registered. |
Usage Examples
Basic Usage
import torch
from transformers import AutoModelForCausalLM
# tp_mesh is the TP sub-mesh from the world DeviceMesh
model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceTB/SmolLM2-1.7B",
device_mesh=tp_mesh,
tp_plan="auto",
dtype=torch.bfloat16,
)
With Full 3D Mesh
import torch
from torch.distributed.device_mesh import DeviceMesh
from transformers import AutoModelForCausalLM
# Construct the full mesh
world_mesh = DeviceMesh(
device_type="cuda",
mesh=torch.arange(8).reshape(2, 2, 2),
mesh_dim_names=("dp", "tp", "cp"),
)
tp_mesh = world_mesh["tp"]
# Load model with TP sharding -- only the TP shard is loaded per rank
model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceTB/SmolLM2-1.7B",
device_mesh=tp_mesh,
tp_plan="auto",
dtype=torch.bfloat16,
)