Overview
build_distributed_state_dict_from_consolidated is the main API function in the checkpoint converter module for tensor-parallel Llama models. It converts fairscale consolidated checkpoints into PyTorch Distributed (PT-D) compliant sharded state dictionaries suitable for loading into FSDP-wrapped models. The module provides helper functions for cross-rank verification, all-gather operations, tensor parallel shard detection, and row/column-wise unshard reconstruction.
Description
The checkpoint converter module bridges the gap between Meta's fairscale-based Llama checkpoints (consolidated format) and PyTorch's native distributed tensor abstractions (ShardedTensor and DTensor). This is essential for loading pre-trained Llama weights into a model wrapped with PyTorch FSDP (Fully Sharded Data Parallel).
Key Responsibilities
- FQN Verification:
_verify_fqn_across_ranks() uses dist.all_gather_object() to ensure all ranks process the same fully qualified parameter name
- All-Gather:
_all_gather_into_list() gathers tensor shards from all ranks in a model parallel group onto GPU
- Shard Detection:
_is_tp_sharded() determines if a parameter is tensor-parallel sharded by inspecting the FQN for attention, feed_forward, output, or tok_embeddings substrings
- Unshard Reconstruction:
_unshard_param() reconstructs full tensors from row-wise or column-wise shards using torch.vstack or torch.column_stack respectively
- State Dict Building:
build_distributed_state_dict_from_consolidated() iterates over the consolidated state dict, unshards tensor-parallel parameters, and chunks them into ShardedTensor or DTensor format for FSDP
Usage
from checkpoint_converter import build_distributed_state_dict_from_consolidated
# Convert a fairscale consolidated checkpoint to PT-D sharded format
MODEL_PARALLEL_SIZE = 8
ckpt_path = get_consolidated_ckpt_path(
ckpt_dir=PTH_65b, mp_rank=local_rank, mp_size=MODEL_PARALLEL_SIZE
)
state_dict = torch.load(ckpt_path)
# Build a local LLaMA with no parallelism
model = build_model(...)
sharded_state_dict = build_distributed_state_dict_from_consolidated(
model, state_dict, model_parallel_world_size=MODEL_PARALLEL_SIZE,
)
# Wrap model with PT-native FSDP and load
model = FSDP(model)
FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT)
model.load_state_dict(sharded_state_dict)
Code Reference
Source Location
| File | Lines | Description |
|---|---|---|
| examples/large_models/tp_llama/checkpoint_converter.py | L1-195 | Full module (195 lines) |
| examples/large_models/tp_llama/checkpoint_converter.py | L15-19 | _verify_fqn_across_ranks(fqn, grp_gloo) -- cross-rank FQN consistency check |
| examples/large_models/tp_llama/checkpoint_converter.py | L21-27 | _all_gather_into_list(data_tensor, model_parallel_group) -- gather shards from all ranks |
| examples/large_models/tp_llama/checkpoint_converter.py | L30-41 | _is_tp_sharded(fqn) -- detect tensor parallel sharded parameters |
| examples/large_models/tp_llama/checkpoint_converter.py | L43-90 | _unshard_param(...) -- reconstruct full tensor from row/column shards |
| examples/large_models/tp_llama/checkpoint_converter.py | L93-194 | build_distributed_state_dict_from_consolidated(...) -- main API |
Signature
def _verify_fqn_across_ranks(fqn, grp_gloo):
"""
Verify that all ranks are processing the same fully qualified name.
Uses dist.all_gather_object to collect FQNs from all ranks and
asserts they are identical.
Args:
fqn (str): Fully qualified parameter name.
grp_gloo (ProcessGroup): Gloo process group for communication.
"""
...
def _all_gather_into_list(data_tensor, model_parallel_group):
"""
All-gather a tensor from all ranks in the model parallel group.
Args:
data_tensor (Tensor): Local tensor shard.
model_parallel_group (ProcessGroup): Model parallel process group.
Returns:
list[Tensor]: List of gathered tensors from all ranks (on CUDA).
"""
...
def _is_tp_sharded(fqn: str) -> bool:
"""
Determine if a parameter is tensor-parallel sharded by FQN inspection.
Returns True if fqn contains 'attention', 'feed_forward', 'output',
or 'tok_embeddings'.
Args:
fqn (str): Fully qualified parameter name.
Returns:
bool: Whether the parameter is TP-sharded.
"""
...
def _unshard_param(
ref_state_dict, fqn, model_parallel_group, grp_gloo,
data_tensor, tp_sharded_shape,
):
"""
Reconstruct a full tensor from row-wise or column-wise shards.
For row-wise sharding: reshapes the shard, all-gathers it, and stacks the results with torch.vstack.
For column-wise sharding: reshapes the shard, all-gathers it, and stacks the results with torch.column_stack.
Args:
ref_state_dict: Reference model state dict for shape validation.
fqn (str): Fully qualified parameter name.
model_parallel_group: MP process group for all-gather.
grp_gloo: Gloo process group for FQN verification.
data_tensor (Tensor): Local shard tensor.
tp_sharded_shape: Shape of the local shard.
Returns:
tuple: (reconstructed_tensor, full_shape).
"""
...
def build_distributed_state_dict_from_consolidated(
model: nn.Module,
consolidated_state_dict: Dict[str, Tensor],
model_parallel_world_size: int,
offload_to_cpu: bool = False,
use_dtensor: bool = False,
) -> Dict[str, Union[Tensor, DTensor, ShardedTensor]]:
"""
Main API: convert fairscale consolidated checkpoint to PT-D sharded state dict.
Iterates over the consolidated state dict, unshards TP-sharded parameters,
and chunks each tensor into ShardedTensor or DTensor format compatible
with FSDP sharded_state_dict loading.
Args:
model (nn.Module): Model with no parallelism applied.
consolidated_state_dict (Dict[str, Tensor]): Fairscale consolidated checkpoint.
model_parallel_world_size (int): Number of MP ranks in the consolidated checkpoint.
offload_to_cpu (bool): Whether to offload to CPU. Default: False.
use_dtensor (bool): Use DTensor instead of ShardedTensor. Default: False.
Returns:
Dict[str, Union[Tensor, DTensor, ShardedTensor]]: PT-D compliant state dict.
"""
...
Import
# Module imports
import torch
import torch.distributed as dist
from torch import nn, Tensor
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.fsdp._fsdp_extensions import (
_ext_chunk_dtensor,
_ext_chunk_tensor,
)
I/O Contract
| Function | Input | Output | Notes |
|---|---|---|---|
| _verify_fqn_across_ranks(fqn, grp_gloo) | FQN string, Gloo process group | None (asserts) | Raises AssertionError if ranks disagree |
| _all_gather_into_list(data_tensor, group) | Local Tensor, MP process group | list[Tensor] on CUDA | Allocates zeros_like for each rank |
| _is_tp_sharded(fqn) | FQN string | bool | Substring matching against attention, feed_forward, output, tok_embeddings |
| _unshard_param(...) | Reference state dict, FQN, groups, shard tensor, shape | tuple(Tensor, Shape) | Uses vstack for row-wise, column_stack for col-wise |
| build_distributed_state_dict_from_consolidated(...) | nn.Module, consolidated Dict[str, Tensor], MP world size, optional flags | Dict[str, Union[Tensor, DTensor, ShardedTensor]] | Copies rope.freqs buffers through unchanged (cloned, not unsharded or chunked); creates DeviceMesh if use_dtensor=True |
Usage Examples
Example 1: TP Shard Detection
# From checkpoint_converter.py L30-41: _is_tp_sharded() checks FQN patterns
def _is_tp_sharded(fqn: str) -> bool:
return (
"attention" in fqn
or "feed_forward" in fqn
or "output" in fqn
or "tok_embeddings" in fqn
)
# Examples:
_is_tp_sharded("layers.0.attention.wq.weight") # True
_is_tp_sharded("layers.0.feed_forward.w1.weight") # True
_is_tp_sharded("norm.weight") # False
Example 2: Row-wise and Column-wise Unshard
# From checkpoint_converter.py L43-90: _unshard_param() reconstructs full tensors
# Row-wise case: ref_shape[0] != tp_sharded_shape[0]
# -> reshape, all_gather, torch.vstack
# Column-wise case: ref_shape[1] != tp_sharded_shape[1]
# -> reshape, all_gather, torch.column_stack
# Non-sharded case: shapes match exactly
Example 3: Full Conversion Pipeline
# From checkpoint_converter.py L93-194: Main conversion flow
def build_distributed_state_dict_from_consolidated(
model, consolidated_state_dict, model_parallel_world_size, ...
):
dist_state_dict = {}
ref_state_dict = model.state_dict()
grp_gloo = dist.new_group(backend="gloo")
mesh = DeviceMesh(...) if use_dtensor else None
model_parallel_group, _ = dist.new_subgroups(
group_size=model_parallel_world_size
)
for fqn, tensor in consolidated_state_dict.items():
if "rope.freqs" in fqn:
dist_state_dict[fqn] = tensor.clone()
continue
if _is_tp_sharded(fqn):
tensor, _ = _unshard_param(
ref_state_dict, fqn, model_parallel_group,
grp_gloo, tensor, tensor.shape,
)
if use_dtensor:
tensor = _ext_chunk_dtensor(tensor=tensor.contiguous(), ...)
else:
tensor = _ext_chunk_tensor(tensor=tensor.contiguous(), ...)
dist_state_dict[fqn] = tensor
return dist_state_dict
Related Pages