Workflow:Huggingface Optimum Automatic Tensor Parallelization

From Leeroopedia
Knowledge Sources
Domains: Distributed_Computing, Model_Parallelism, LLMs
Last Updated: 2026-02-15 00:00 GMT

Overview

End-to-end process for automatically distributing a Transformer model across multiple GPUs using tensor parallelism via PyTorch FX graph analysis and Megatron-style layer replacement.

Description

This workflow describes the automatic tensor parallelization system provided by parallelize_model(). The system uses PyTorch FX to trace the model's computational graph, decomposes high-level operations into core ATen operators, and then performs a backtracking search to find a valid parallelization strategy. Once a solution is found, linear layers and embeddings are replaced with their parallel counterparts (ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding) and weights are distributed across processes. The parallelization is integrated with torch.compile for optimized execution.

Key aspects:

  • Fully automatic: no model-specific parallelization code required
  • Uses FX graph tracing and ATen operator decomposition for generality
  • Backtracking search finds valid parallel axis assignments respecting operator constraints
  • Megatron-style parallel layers: column-parallel, row-parallel linear, and vocab-parallel embedding
  • Integrates with torch.compile as a custom backend
  • Supports sharded cross-entropy loss for memory-efficient training

Usage

Use this workflow when a Transformer model is too large for a single GPU and you need to distribute it, for inference or training, across multiple GPUs within a single node using tensor parallelism. It applies to multi-GPU setups with fast inter-GPU communication (e.g., NVLink) where the goal is to reduce per-GPU memory usage while maintaining model fidelity.

Execution Steps

Step 1: Model Download and Configuration

Download the model weights and configuration from the Hugging Face Hub (or load from a local directory). The model's architecture class is resolved from the config's architectures field via the transformers library. A weight map is collected to enable later sharded weight loading.

Key considerations:

  • The model identifier can be a Hub ID or local directory path
  • Revision, cache directory, and local-only options control the download behavior
  • The weight map tracks which parameters are stored in which checkpoint files
  • The skip_load_weights option allows initializing without loading weights (useful for profiling)
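To make the weight map concrete, here is a minimal sketch that reads the `weight_map` section of a sharded-checkpoint index (the layout used by Hugging Face `*.index.json` files) and groups parameters by file. The file names, parameter names, and the helper functions `collect_weight_map` and `group_by_file` are hypothetical and only illustrate the idea; they are not the library's own code.

```python
import json
from collections import defaultdict

# Hypothetical index contents; real checkpoints list every parameter.
index_json = """
{
  "metadata": {"total_size": 1024},
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors"
  }
}
"""


def collect_weight_map(index_text):
    """Return the parameter-name -> checkpoint-file mapping."""
    return dict(json.loads(index_text)["weight_map"])


def group_by_file(weight_map):
    """Invert the map so that each checkpoint file needs opening only once."""
    files = defaultdict(list)
    for param_name, file_name in weight_map.items():
        files[file_name].append(param_name)
    return {name: sorted(params) for name, params in files.items()}


weight_map = collect_weight_map(index_json)
params_by_file = group_by_file(weight_map)
```

Keeping the map around means a later loading step can look up the one file that holds a given parameter instead of scanning every shard.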

Step 2: Meta-device Model Initialization

Instantiate the model on a meta device (no actual memory allocation) using the MetaAwareMethodsPatcher context manager. This creates the model's module hierarchy and parameter shapes without allocating GPU memory, enabling the parallelization passes to analyze the model structure before committing resources.

Key considerations:

  • The MetaAwareMethodsPatcher patches PyTorch methods to work with meta tensors
  • The model is placed in eval mode (tracing in training mode is not yet supported)
  • torch_dtype can be specified to control the default parameter dtype
  • The model is then moved to the current process's device

Step 3: Parameter Meta Initialization

Initialize metadata for each model parameter that tracks its parallelization state. The ParameterMeta dataclass stores information about each parameter's original shape, the mapping from local slices to global positions, and whether the parameter needs special handling during weight loading.

Key considerations:

  • Each parameter receives a ParameterMeta annotation
  • The metadata tracks the relationship between local (per-device) and global (full-model) parameter shapes
  • This metadata is used later during weight loading to extract the correct shard
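The local-to-global relationship can be pictured with a small dataclass. This is a sketch under stated assumptions: the class `ParamMetaSketch`, its fields, and the even-split rule are illustrative stand-ins, not the actual definition of ParameterMeta.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass
class ParamMetaSketch:
    """Illustrative stand-in for per-parameter parallelization metadata."""

    global_shape: Tuple[int, ...]
    is_parallel: bool = False
    dim: int = 0  # dimension along which the parameter is sharded
    # Maps a local (per-device) index range to its range in the full tensor.
    mapping: Dict[Tuple[int, int], Tuple[int, int]] = field(default_factory=dict)


def shard_meta(global_shape, dim, rank, world_size):
    """Build metadata for one rank, assuming an even split along `dim`."""
    size = global_shape[dim]
    if size % world_size != 0:
        raise ValueError("dimension is not divisible by the world size")
    local = size // world_size
    start = rank * local
    return ParamMetaSketch(
        global_shape=tuple(global_shape),
        is_parallel=True,
        dim=dim,
        mapping={(0, local): (start, start + local)},
    )


def local_shape(meta):
    """Per-device shape implied by the metadata."""
    if not meta.is_parallel:
        return meta.global_shape
    shape = list(meta.global_shape)
    # Sum the lengths of all local ranges (the keys of the mapping).
    shape[meta.dim] = sum(stop - start for start, stop in meta.mapping)
    return tuple(shape)


meta = shard_meta((4096, 1024), dim=0, rank=1, world_size=4)
```

With four ranks, rank 1 owns rows 1024 to 2047 of the full 4096-row parameter, and `local_shape(meta)` gives the shape that rank actually allocates.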

Step 4: Parallel Axis Solving (FX Graph Analysis)

Trace the model using torch.compile's Dynamo frontend to produce an FX graph. The graph is decomposed from high-level PyTorch operations into core ATen operators (a much smaller set). A backtracking search algorithm then finds valid parallel axis assignments for every tensor in the graph, respecting the propagation rules defined for each ATen operator.

What happens:

  • High-level ops are decomposed into core ATen ops via operator decomposition
  • In-place operations are functionalized (e.g., aten.add_ becomes aten.add)
  • Each ATen op has registered propagation rules defining how parallel axes flow through it
  • The search prioritizes parallelization on the head dimension (tensor parallel) or sequence dimension (sequence parallel)
  • Axis switching is only allowed around specific layers (nn.Linear, nn.Embedding)
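The search can be illustrated on a toy chain of operators. The sketch below is a simplified assumption, not the library's solver: the rule table in `candidates`, the three operator kinds, and the requirement that the final output be replicated are all made up for illustration. It does reflect two points from the list above: elementwise ops pass the axis through unchanged, and only linear layers may switch it.

```python
REPLICATED = None
LAST = -1  # shorthand for "sharded along the last (feature) dimension"


def candidates(op, in_axis):
    """Toy propagation rules: allowed output axes given the input axis."""
    if op == "input":
        return [REPLICATED]
    if op == "elementwise":
        return [in_axis]              # the axis flows through unchanged
    if op == "linear":
        if in_axis == REPLICATED:
            return [LAST, REPLICATED]  # prefer sharding, fall back to plain
        if in_axis == LAST:
            return [REPLICATED]        # partial results are summed afterwards
    return []                          # no valid assignment for this input


def solve(graph, idx=0, assignment=None):
    """Depth-first search with backtracking over per-node axis choices."""
    assignment = {} if assignment is None else assignment
    if idx == len(graph):
        # Assumed constraint: the graph output must be replicated.
        last_name = graph[-1][0]
        return dict(assignment) if assignment[last_name] == REPLICATED else None
    name, op, src = graph[idx]
    in_axis = assignment.get(src, REPLICATED)
    for axis in candidates(op, in_axis):
        assignment[name] = axis
        result = solve(graph, idx + 1, assignment)
        if result is not None:
            return result
        del assignment[name]           # undo and try the next candidate
    return None


# (node name, operator kind, name of the producing node)
graph = [
    ("x", "input", None),
    ("fc1", "linear", "x"),
    ("act", "elementwise", "fc1"),
    ("fc2", "linear", "act"),
    ("head", "linear", "fc2"),
]
solution = solve(graph)
```

Here the solver first tries to shard the output of `head`, fails the output constraint, backtracks, and settles on a replicated `head` while keeping `fc1` and `fc2` as a sharded pair.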

Step 5: Parallel Layer Annotation

Annotate the original (non-decomposed) graph nodes with the parallelization solution found in the previous step. The solution from the decomposed graph is traced back to the original graph's nodes, marking each layer with its parallel axis assignment. This annotation identifies which layers should be replaced with parallel variants.

Key considerations:

  • The trace-back maps decomposed node solutions to original graph nodes
  • Linear layers are annotated as column-parallel or row-parallel based on their axis assignment
  • Embedding layers are annotated for vocabulary parallelism
  • Cross-entropy loss nodes are identified for sharded replacement
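As a rough sketch of how an axis assignment could translate into a layer annotation, the function below classifies a linear layer from the sharded axis of its input and output. The function `annotate_linear`, the `solution` dictionary, and the layer names are hypothetical; the classification rule itself follows the usual Megatron-style convention (replicated input with sharded output features is column-parallel, sharded input features with replicated output is row-parallel).

```python
def annotate_linear(in_axis, out_axis, last_dim):
    """Classify a linear layer from its input and output parallel axes.

    None means replicated; an integer is the sharded tensor dimension.
    """
    if in_axis is None and out_axis == last_dim:
        return "column"  # output features are split across devices
    if in_axis == last_dim and out_axis is None:
        return "row"     # input features are split, outputs are all-reduced
    return None          # leave the original layer in place


# Hypothetical solver result for activations of shape (batch, seq, hidden).
solution = {"x": None, "fc1": 2, "act": 2, "fc2": None}
linears = {"fc1": "x", "fc2": "act"}  # layer name -> name of its input tensor

annotations = {
    name: annotate_linear(solution[src], solution[name], last_dim=2)
    for name, src in linears.items()
}
```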

Step 6: Parallel Layer Replacement

Replace the annotated layers with their Megatron-style parallel counterparts. Column-parallel linear layers split the output dimension across devices, while row-parallel linear layers split the input dimension. Vocabulary-parallel embeddings partition the vocabulary across devices. Hard-coded shape attributes in the graph are updated to reflect the parallelized dimensions.
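The column-then-row pairing can be checked with plain arithmetic. The sketch below uses nested lists instead of tensors and stores weights as (in, out) so that `y = x @ W`; note that nn.Linear stores them transposed. The "all-reduce" is simulated by summing the per-rank partial outputs. All names are illustrative.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_cols(m, parts):
    """Split a matrix into `parts` equal blocks of columns."""
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]


def split_rows(m, parts):
    """Split a matrix into `parts` equal blocks of rows."""
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]


def add(a, b):
    """Elementwise sum of two matrices of the same shape."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


world_size = 2
x = [[1, 2], [3, 4]]                       # (batch=2, in=2), replicated
w1 = [[1, 0, 2, 1], [0, 1, 1, 3]]          # (in=2, hidden=4)
w2 = [[1, 2], [0, 1], [2, 0], [1, 1]]      # (hidden=4, out=2)

reference = matmul(matmul(x, w1), w2)      # single-device result

# Column-parallel: each rank holds a slice of w1's output features.
w1_shards = split_cols(w1, world_size)
# Row-parallel: each rank holds the matching slice of w2's input features.
w2_shards = split_rows(w2, world_size)

partials = [
    matmul(matmul(x, w1_shards[r]), w2_shards[r]) for r in range(world_size)
]
# Summing the partial outputs plays the role of the all-reduce.
combined = add(partials[0], partials[1])
```

Because the hidden activations produced by the column-parallel layer are already split the way the row-parallel layer expects, no communication is needed between the two layers; a single reduction after the second layer recovers the full result.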

Key considerations:

  • ColumnParallelLinear: splits output features, each device computes a portion of the output
  • RowParallelLinear: splits input features, requires all-reduce to combine partial results
  • VocabParallelEmbedding: partitions vocabulary across devices
  • VocabParallelCrossEntropyLoss: computes loss without materializing the full logit tensor
  • Differentiable collective operations (all-reduce, all-gather, scatter) enable gradient flow
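To show why the full logit tensor need not be materialized, the sketch below computes cross-entropy from vocabulary shards using only a few scalars per rank. It assumes the common Megatron-style scheme (global max, global sum of exponentials, and the target logit from the rank that owns it); the library's VocabParallelCrossEntropyLoss may differ in detail, and the collectives are simulated with ordinary Python reductions.

```python
import math


def cross_entropy(logits, target):
    """Reference loss computed on the full logit vector."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum - logits[target]


def sharded_cross_entropy(shards, target):
    """Same loss computed from per-rank vocabulary shards."""
    # Stands in for an all-reduce with MAX over the per-rank maxima.
    global_max = max(max(s) for s in shards)
    # Stands in for an all-reduce with SUM over the per-rank partial sums.
    sum_exp = sum(sum(math.exp(v - global_max) for v in s) for s in shards)
    # Only the rank owning the target index contributes; others add zero.
    target_logit = 0.0
    offset = 0
    for s in shards:
        if offset <= target < offset + len(s):
            target_logit += s[target - offset]
        offset += len(s)
    return global_max + math.log(sum_exp) - target_logit


logits = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5]
shards = [logits[:3], logits[3:]]  # vocabulary split across two ranks
full_loss = cross_entropy(logits, target=4)
sharded_loss = sharded_cross_entropy(shards, target=4)
```

The two losses agree up to floating-point rounding, while each rank only ever touches its own slice of the vocabulary.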

Step 7: Sharded Weight Loading

Load the appropriate weight shards for each process from the checkpoint files. Using the weight map and parameter metadata, each process loads only the slice of each parameter that corresponds to its rank in the parallel group. This avoids loading the full model on each device.

Key considerations:

  • The weight map identifies which checkpoint file contains each parameter
  • ParameterMeta's slice mapping determines which portion each rank receives
  • Parameters that are not parallelized are loaded in full on each device
  • The loading respects the parallel group's rank assignment
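The loading rules above can be sketched with in-memory stand-ins for checkpoint files. Everything here is illustrative: `load_for_rank`, `shard_dims`, and the file and parameter names are hypothetical, tensors are nested lists, and an even split is assumed. Weights follow the nn.Linear (out, in) layout, so a column-parallel layer is split along dim 0 and a row-parallel layer along dim 1.

```python
def take_slice(tensor, dim, start, stop):
    """Slice a 2-D nested list along dim 0 (rows) or dim 1 (columns)."""
    if dim == 0:
        return [row[:] for row in tensor[start:stop]]
    return [row[start:stop] for row in tensor]


def load_for_rank(checkpoint_files, weight_map, shard_dims, rank, world_size):
    """Return this rank's parameters, slicing only the parallelized ones."""
    state = {}
    for name, file_name in weight_map.items():
        full = checkpoint_files[file_name][name]  # stands in for a file read
        dim = shard_dims.get(name)                # None: not parallelized
        if dim is None:
            state[name] = [row[:] for row in full]
            continue
        size = len(full) if dim == 0 else len(full[0])
        local = size // world_size
        state[name] = take_slice(full, dim, rank * local, (rank + 1) * local)
    return state


checkpoint_files = {
    "shard-a": {
        "fc1.weight": [[1, 2], [3, 4], [5, 6], [7, 8]],  # (out=4, in=2)
        "norm.weight": [[1, 1]],
    },
    "shard-b": {"fc2.weight": [[1, 2, 3, 4], [5, 6, 7, 8]]},  # (out=2, in=4)
}
weight_map = {
    "fc1.weight": "shard-a",
    "norm.weight": "shard-a",
    "fc2.weight": "shard-b",
}
shard_dims = {"fc1.weight": 0, "fc2.weight": 1}

rank1 = load_for_rank(checkpoint_files, weight_map, shard_dims, rank=1, world_size=2)
```

In this sketch the whole tensor is read before slicing; a checkpoint format that supports partial reads (safetensors, for example, exposes slice access) lets each rank read only its own portion from disk.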
