Implementation:Deepspeedai DeepSpeed Tp Model Init
Overview
A concrete tool, provided by the DeepSpeed library, for recording tensor parallel (TP) initialization arguments for AutoTP training.
Implementation Type
Function (top-level API entry point)
Detailed Description
deepspeed.tp_model_init() records the TP arguments (tp_size, dtype, tp_group) globally via record_tp_model_init_args(), so that deepspeed.initialize() can later validate them and merge them into the DeepSpeed config. It does not shard the model itself; sharding happens during engine initialization.
The function performs two actions:
- Records TP arguments: calls record_tp_model_init_args(tp_size, dtype, tp_group, dist_module), which stores the arguments in the global _TP_MODEL_INIT_ARGS dictionary. If the function is called multiple times, it validates that the new arguments match the previously recorded ones; a mismatch raises ValueError.
- Sets AutoTP training mode: calls set_autotp_mode(training=True), which sets DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.TRAINING. TP layer implementations check this global flag later to select training-compatible behavior (e.g., non-inplace bias addition, which autograd requires).
During deepspeed.initialize(), the companion function merge_tp_model_init_into_config() reads _TP_MODEL_INIT_ARGS and:
- Creates or populates the tensor_parallel section in the config dict.
- Validates that autotp_size, dtype, and tp_group do not conflict between the recorded args and the config.
- Auto-creates TP groups via _init_tp_mesh_device() if neither tp_group nor mpu is available.
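The merge step can likewise be sketched as a function over the config dict. This is a hypothetical simplification of merge_tp_model_init_into_config(): merge_sketch and make_default_group are illustrative names, only the tensor_parallel and autotp_size keys named on this page are used, the dtype check is omitted because this page does not say where dtype lives in the config, and the make_default_group callback merely stands in for _init_tp_mesh_device().

```python
import copy
from typing import Any, Callable, Optional, Tuple


def merge_sketch(
    config: dict,
    recorded: Optional[dict],
    mpu: Any = None,
    make_default_group: Callable[[int], Any] = lambda size: f"tp-group(size={size})",
) -> Tuple[dict, Any]:
    """Return (merged_config, tp_group). Hypothetical, simplified model."""
    merged = copy.deepcopy(config)  # leave the caller's dict untouched
    if recorded is None:
        # tp_model_init() was never called: nothing to merge.
        return merged, None
    tp_section = merged.setdefault("tensor_parallel", {})
    # Fill autotp_size from the recorded tp_size, or check that the two agree.
    existing = tp_section.get("autotp_size")
    if existing is not None and existing != recorded["tp_size"]:
        raise ValueError(
            f"config autotp_size={existing} conflicts with tp_size={recorded['tp_size']}"
        )
    tp_section["autotp_size"] = recorded["tp_size"]
    tp_group = recorded.get("tp_group")
    if tp_group is None and mpu is None:
        # Stand-in for _init_tp_mesh_device(): build a default TP group.
        tp_group = make_default_group(recorded["tp_size"])
    return merged, tp_group


base = {"train_batch_size": 8, "bf16": {"enabled": True}}
merged, group = merge_sketch(base, {"tp_size": 4, "dtype": "bfloat16", "tp_group": None})
print(merged["tensor_parallel"])  # {'autotp_size': 4}
print(group)  # tp-group(size=4)
```

Copying the input before mutating it keeps the sketch free of side effects on the user's config; whether the real function copies or mutates in place is not stated here.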
Code Reference
- Repository: https://github.com/deepspeedai/DeepSpeed
- File: deepspeed/__init__.py (L391-454)
- Signature: def tp_model_init(model, tp_size, dtype, config=None, **kwargs) -> torch.nn.Module
- Also: record_tp_model_init_args(tp_size, dtype, tp_group, dist_module) in deepspeed/runtime/tensor_parallel/init_utils.py (L36-60)
- Also: merge_tp_model_init_into_config(config_dict, mpu, mesh_param, dist_module) in deepspeed/runtime/tensor_parallel/init_utils.py (L79-147)
- Import: import deepspeed
Parameters
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| model | torch.nn.Module | Yes | — | The model to be initialized (returned unmodified) |
| tp_size | int | Yes | — | The tensor parallelism degree (number of GPUs for TP) |
| dtype | torch.dtype | Yes | — | The data type for the model (e.g., torch.bfloat16, torch.float16) |
| config | dict | No | None | Optional DeepSpeed configuration dictionary |
| **kwargs | keyword arguments | No | — | Additional arguments; tp_group is extracted if present |
I/O
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | model | torch.nn.Module | The pretrained model (e.g., from HuggingFace) |
| Input | tp_size | int | Tensor parallelism degree |
| Input | dtype | torch.dtype | Model precision (fp16 or bf16) |
| Input | config | dict | Optional DeepSpeed config |
| Output | model | torch.nn.Module | The same model, unmodified; TP init args recorded globally for deferred sharding |
Usage Example
```python
import deepspeed
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = deepspeed.tp_model_init(model, tp_size=4, dtype=torch.bfloat16)
# Model is NOT sharded yet; sharding happens in deepspeed.initialize()

engine, _, _, _ = deepspeed.initialize(
    model=model,
    config={"train_batch_size": 8, "bf16": {"enabled": True}},
)
# Now the model is TP-partitioned inside the engine
```
Knowledge Sources
Relationships
Principle:Deepspeedai_DeepSpeed_AutoTP_Model_Loading
Metadata
- Workflow: AutoTP_Training
- Type: Implementation
- Last Updated: 2026-02-09 00:00 GMT