

Implementation:Pytorch Serve Llama2 Checkpoint Converter

From Leeroopedia

Overview

build_distributed_state_dict_from_consolidated is the main API function in the checkpoint converter module for tensor-parallel Llama models. It converts fairscale consolidated checkpoints into PyTorch Distributed (PT-D) compliant sharded state dictionaries suitable for loading into FSDP-wrapped models. The module provides helper functions for cross-rank verification, all-gather operations, tensor parallel shard detection, and row/column-wise unshard reconstruction.

Implementation Name: Llama2_Checkpoint_Converter
Type: Utility Module
Workflow: Tensor_Parallel_Checkpoint_Conversion
Domains: LLM_Serving, Distributed_Computing
Knowledge Sources: Pytorch_Serve
Last Updated: 2026-02-13 18:52 GMT

Description

The checkpoint converter module bridges the gap between Meta's fairscale-based Llama checkpoints (consolidated format) and PyTorch's native distributed tensor abstractions (ShardedTensor and DTensor). This is essential for loading pre-trained Llama weights into a model wrapped with PyTorch FSDP (Fully Sharded Data Parallel).

Key Responsibilities

  • FQN Verification: _verify_fqn_across_ranks() uses dist.all_gather_object() to ensure all ranks process the same fully qualified parameter name
  • All-Gather: _all_gather_into_list() gathers tensor shards from all ranks in a model parallel group onto the GPU
  • Shard Detection: _is_tp_sharded() determines if a parameter is tensor-parallel sharded by inspecting the FQN for attention, feed_forward, output, or tok_embeddings substrings
  • Unshard Reconstruction: _unshard_param() reconstructs full tensors from row-wise or column-wise shards using torch.vstack or torch.column_stack respectively
  • State Dict Building: build_distributed_state_dict_from_consolidated() iterates over the consolidated state dict, unshards tensor-parallel parameters, and chunks them into ShardedTensor or DTensor format for FSDP
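The two communication helpers follow a standard torch.distributed pattern: gather a Python object (the FQN) from every rank and compare, then all-gather tensor shards into pre-allocated buffers. The sketch below illustrates that pattern under stated assumptions: it runs a single-process gloo group on CPU so it works without GPUs, and verify_fqn_across_ranks / all_gather_into_list are illustrative stand-ins, not the module's private functions (which use a dedicated gloo group for the FQN check and gather onto CUDA).

```python
import os
import tempfile

import torch
import torch.distributed as dist


def verify_fqn_across_ranks(fqn: str, group=None) -> None:
    """Assert that every rank in `group` is processing the same parameter name."""
    gathered = [None] * dist.get_world_size(group)
    # all_gather_object pickles arbitrary Python objects, so plain strings work.
    dist.all_gather_object(gathered, fqn, group=group)
    assert all(name == fqn for name in gathered), f"FQN mismatch across ranks: {gathered}"


def all_gather_into_list(data_tensor: torch.Tensor, group=None) -> list:
    """Return one tensor per rank, ordered by rank."""
    # One zero-filled receive buffer per rank, same shape/dtype/device as the local shard.
    tensor_list = [torch.zeros_like(data_tensor) for _ in range(dist.get_world_size(group))]
    dist.all_gather(tensor_list, data_tensor, group=group)
    return tensor_list


# Assumption: a single-process gloo "cluster" so the sketch runs on a CPU-only machine.
store_path = os.path.join(tempfile.mkdtemp(), "filestore")
dist.init_process_group(
    backend="gloo", store=dist.FileStore(store_path, 1), rank=0, world_size=1
)

verify_fqn_across_ranks("layers.0.attention.wq.weight")
local_shard = torch.arange(6.0).reshape(2, 3)
gathered = all_gather_into_list(local_shard)

dist.destroy_process_group()
```

With more ranks, each entry of the gathered list holds that rank's shard, which is what the row-wise and column-wise stacking steps consume.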

Usage

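import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

from checkpoint_converter import build_distributed_state_dict_from_consolidated

# Assumes torch.distributed is already initialized (one process per GPU).
# get_consolidated_ckpt_path, PTH_65b, local_rank and build_model are not defined
# in this module; they stand for the caller's path helper, checkpoint directory,
# rank and model builder.

# Convert a fairscale consolidated checkpoint to PT-D sharded format
MODEL_PARALLEL_SIZE = 8
ckpt_path = get_consolidated_ckpt_path(
    ckpt_dir=PTH_65b, mp_rank=local_rank, mp_size=MODEL_PARALLEL_SIZE
)
state_dict = torch.load(ckpt_path)

# Build a local LLaMA with no parallelism
model = build_model(...)

sharded_state_dict = build_distributed_state_dict_from_consolidated(
    model, state_dict, model_parallel_world_size=MODEL_PARALLEL_SIZE,
)

# Wrap the model with PyTorch-native FSDP, switch it to sharded state dicts, and load
model = FSDP(model)
FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT)
model.load_state_dict(sharded_state_dict)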
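The usage snippet relies on a get_consolidated_ckpt_path helper that does not appear among the module functions listed in this page. A hypothetical version is sketched below, assuming Meta's released checkpoints are stored as one consolidated.XX.pth file per model-parallel rank (consolidated.00.pth through consolidated.07.pth for an MP size of 8). Both the helper and its keyword names simply mirror the call above; adapt them to your checkpoint layout.

```python
import os


def get_consolidated_ckpt_path(ckpt_dir: str, mp_rank: int = 0, mp_size: int = 1) -> str:
    """Hypothetical helper: path of the consolidated shard for one model-parallel rank.

    Assumes the `consolidated.XX.pth` naming with one file per MP rank.
    """
    if not 0 <= mp_rank < mp_size:
        raise ValueError(f"mp_rank={mp_rank} is outside [0, {mp_size})")
    return os.path.join(ckpt_dir, f"consolidated.{mp_rank:02d}.pth")


print(get_consolidated_ckpt_path("llama-65b", mp_rank=3, mp_size=8))
```

Each process loads only its own rank's file, which is why the converter must later all-gather the shards to reconstruct full tensors.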

Code Reference

Source Location

File: examples/large_models/tp_llama/checkpoint_converter.py (195 lines)

  • L1-195: full module
  • L15-19: _verify_fqn_across_ranks(fqn, grp_gloo), cross-rank FQN consistency check
  • L21-27: _all_gather_into_list(data_tensor, model_parallel_group), gathers shards from all ranks
  • L30-41: _is_tp_sharded(fqn), detects tensor-parallel sharded parameters
  • L43-90: _unshard_param(...), reconstructs the full tensor from row- or column-wise shards
  • L93-194: build_distributed_state_dict_from_consolidated(...), main API

Signature

def _verify_fqn_across_ranks(fqn, grp_gloo):
    """
    Verify that all ranks are processing the same fully qualified name.

    Uses dist.all_gather_object to collect FQNs from all ranks and
    asserts they are identical.

    Args:
        fqn (str): Fully qualified parameter name.
        grp_gloo (ProcessGroup): Gloo process group for communication.
    """
    ...

def _all_gather_into_list(data_tensor, model_parallel_group):
    """
    All-gather a tensor from all ranks in the model parallel group.

    Args:
        data_tensor (Tensor): Local tensor shard.
        model_parallel_group (ProcessGroup): Model parallel process group.

    Returns:
        list[Tensor]: List of gathered tensors from all ranks (on CUDA).
    """
    ...

def _is_tp_sharded(fqn: str) -> bool:
    """
    Determine if a parameter is tensor-parallel sharded by FQN inspection.

    Returns True if fqn contains 'attention', 'feed_forward', 'output',
    or 'tok_embeddings'.

    Args:
        fqn (str): Fully qualified parameter name.

    Returns:
        bool: Whether the parameter is TP-sharded.
    """
    ...

def _unshard_param(
    ref_state_dict, fqn, model_parallel_group, grp_gloo,
    data_tensor, tp_sharded_shape,
):
    """
    Reconstruct a full tensor from row-wise or column-wise shards.

    For row-wise sharding: reshapes the local shard, all-gathers it, and stacks the results with torch.vstack.
    For column-wise sharding: reshapes the local shard, all-gathers it, and stacks the results with torch.column_stack.

    Args:
        ref_state_dict: Reference model state dict for shape validation.
        fqn (str): Fully qualified parameter name.
        model_parallel_group: MP process group for all-gather.
        grp_gloo: Gloo process group for FQN verification.
        data_tensor (Tensor): Local shard tensor.
        tp_sharded_shape: Shape of the local shard.

    Returns:
        tuple: (reconstructed_tensor, full_shape).
    """
    ...

def build_distributed_state_dict_from_consolidated(
    model: nn.Module,
    consolidated_state_dict: Dict[str, Tensor],
    model_parallel_world_size: int,
    offload_to_cpu: bool = False,
    use_dtensor: bool = False,
) -> Dict[str, Union[Tensor, DTensor, ShardedTensor]]:
    """
    Main API: convert fairscale consolidated checkpoint to PT-D sharded state dict.

    Iterates over the consolidated state dict, unshards TP-sharded parameters,
    and chunks each tensor into ShardedTensor or DTensor format compatible
    with FSDP sharded_state_dict loading.

    Args:
        model (nn.Module): Model with no parallelism applied.
        consolidated_state_dict (Dict[str, Tensor]): Fairscale consolidated checkpoint.
        model_parallel_world_size (int): Number of MP ranks in the consolidated checkpoint.
        offload_to_cpu (bool): Whether to offload to CPU. Default: False.
        use_dtensor (bool): Use DTensor instead of ShardedTensor. Default: False.

    Returns:
        Dict[str, Union[Tensor, DTensor, ShardedTensor]]: PT-D compliant state dict.
    """
    ...

Import

# Module imports
from typing import Dict, Union

import torch
import torch.distributed as dist
from torch import nn, Tensor
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.fsdp._fsdp_extensions import (
    _ext_chunk_dtensor,
    _ext_chunk_tensor,
)

I/O Contract

  • _verify_fqn_across_ranks(fqn, grp_gloo): takes an FQN string and a Gloo process group; returns None. Raises AssertionError if ranks disagree.
  • _all_gather_into_list(data_tensor, model_parallel_group): takes the local Tensor and the MP process group; returns list[Tensor] on CUDA. Allocates a zeros_like buffer for each rank.
  • _is_tp_sharded(fqn): takes an FQN string; returns bool. Substring match against attention, feed_forward, output, tok_embeddings.
  • _unshard_param(...): takes the reference state dict, FQN, process groups, local shard tensor and shard shape; returns a tuple (Tensor, full shape). Uses torch.vstack for row-wise and torch.column_stack for column-wise shards.
  • build_distributed_state_dict_from_consolidated(...): takes an nn.Module, the consolidated Dict[str, Tensor], the MP world size and optional flags (offload_to_cpu, use_dtensor); returns Dict[str, Union[Tensor, DTensor, ShardedTensor]]. Copies rope.freqs buffers as-is (cloned, neither unsharded nor chunked); creates a DeviceMesh if use_dtensor=True.

Usage Examples

Example 1: TP Shard Detection

# From checkpoint_converter.py L30-41: _is_tp_sharded() checks FQN patterns
def _is_tp_sharded(fqn: str) -> bool:
    return (
        "attention" in fqn
        or "feed_forward" in fqn
        or "output" in fqn
        or "tok_embeddings" in fqn
    )

# Examples:
_is_tp_sharded("layers.0.attention.wq.weight")     # True
_is_tp_sharded("layers.0.feed_forward.w1.weight")  # True
_is_tp_sharded("norm.weight")                       # False

Example 2: Row-wise and Column-wise Unshard

# From checkpoint_converter.py L43-90: _unshard_param() reconstructs full tensors
# Row-wise case: ref_shape[0] != tp_sharded_shape[0]
#   -> reshape, all_gather, torch.vstack
# Column-wise case: ref_shape[1] != tp_sharded_shape[1]
#   -> reshape, all_gather, torch.column_stack
# Non-sharded case: shapes match exactly
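The three cases can be exercised without a process group. In this sketch a plain Python list of shards stands in for the result of the all-gather, and the real function's reshape of the flat local tensor and its cross-rank FQN check are omitted; unshard is a hypothetical name. The only input it uses to decide which dimension was split is the comparison between the reference shape and the local shard shape, as described above.

```python
import torch


def unshard(ref_shape, shards):
    """Rebuild a full tensor from per-rank shards using the shape-comparison rules above.

    `shards` plays the role of the list an all-gather would return (one tensor per MP rank).
    """
    ref_shape = tuple(ref_shape)
    local_shape = tuple(shards[0].shape)
    mp_size = len(shards)
    if ref_shape[0] != local_shape[0]:
        # Row-wise: each rank holds a contiguous block of rows.
        assert ref_shape[0] == local_shape[0] * mp_size
        full = torch.vstack(shards)
    elif len(ref_shape) > 1 and ref_shape[1] != local_shape[1]:
        # Column-wise: each rank holds a contiguous block of columns.
        assert ref_shape[1] == local_shape[1] * mp_size
        full = torch.column_stack(shards)
    else:
        # Shapes already match: not actually sharded, keep the local copy.
        assert ref_shape == local_shape
        full = shards[0]
    assert tuple(full.shape) == ref_shape
    return full


# A 4x6 "reference" weight split across 2 hypothetical MP ranks.
full_weight = torch.arange(24.0).reshape(4, 6)
row_shards = list(torch.chunk(full_weight, 2, dim=0))  # two 2x6 blocks
col_shards = list(torch.chunk(full_weight, 2, dim=1))  # two 4x3 blocks
norm_weight = torch.ones(6)                            # replicated, e.g. a norm weight

rebuilt_rows = unshard(full_weight.shape, row_shards)
rebuilt_cols = unshard(full_weight.shape, col_shards)
rebuilt_norm = unshard(norm_weight.shape, [norm_weight, norm_weight])
```

Checking dim 0 first means a tensor split along rows is never misread as column-split, and 1-D tensors skip the column test entirely.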

Example 3: Full Conversion Pipeline

# From checkpoint_converter.py L93-194: main conversion flow (abridged; "..." marks elided arguments)
def build_distributed_state_dict_from_consolidated(
    model, consolidated_state_dict, model_parallel_world_size,
    offload_to_cpu=False, use_dtensor=False,
):
    dist_state_dict = {}
    ref_state_dict = model.state_dict()
    grp_gloo = dist.new_group(backend="gloo")
    mesh = DeviceMesh(...) if use_dtensor else None
    model_parallel_group, _ = dist.new_subgroups(
        group_size=model_parallel_world_size
    )

    for fqn, tensor in consolidated_state_dict.items():
        if "rope.freqs" in fqn:
            dist_state_dict[fqn] = tensor.clone()
            continue
        if _is_tp_sharded(fqn):
            tensor, _ = _unshard_param(
                ref_state_dict, fqn, model_parallel_group,
                grp_gloo, tensor, tensor.shape,
            )
        if use_dtensor:
            tensor = _ext_chunk_dtensor(tensor=tensor.contiguous(), ...)
        else:
            tensor = _ext_chunk_tensor(tensor=tensor.contiguous(), ...)
        dist_state_dict[fqn] = tensor
    return dist_state_dict
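After unsharding, every rank holds the full tensor, and the last step passes it to _ext_chunk_tensor or _ext_chunk_dtensor so that each FSDP rank keeps only its own piece. The sketch below approximates the resulting layout without calling those private helpers. It assumes, based on our reading of FSDP's default chunking, that tensors are split evenly along dim 0 with torch.chunk; local_fsdp_chunk is a hypothetical name.

```python
import torch


def local_fsdp_chunk(full: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Piece of `full` that FSDP rank `rank` would keep under even dim-0 chunking (assumed)."""
    chunks = torch.chunk(full, world_size, dim=0)
    if rank < len(chunks):
        return chunks[rank].clone()
    # torch.chunk can return fewer than world_size pieces; trailing ranks get an empty shard.
    return full.new_empty((0, *full.shape[1:]))


full = torch.arange(10.0).reshape(5, 2)
pieces = [local_fsdp_chunk(full, r, world_size=4) for r in range(4)]
print([tuple(p.shape) for p in pieces])  # [(2, 2), (2, 2), (1, 2), (0, 2)]
```

Uneven splits such as this one are why a rank can end up with an empty local shard even though the global tensor is complete.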
