
Implementation:Bigscience workshop Petals AutoDistributedModelForCausalLM From Pretrained

From Leeroopedia


Knowledge Sources

  • Domains: Distributed_Computing, NLP, Model_Loading
  • Last Updated: 2026-02-09 14:00 GMT

Overview

A concrete tool from the Petals library for loading distributed causal language models.

Description

AutoDistributedModelForCausalLM is an auto-class that resolves the correct distributed model implementation (Llama, Bloom, Falcon, or Mixtral) based on the model configuration. It delegates to _AutoDistributedBase.from_pretrained, which looks up the model type in a class mapping populated at import time by each model subpackage.
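The import-time class mapping described above can be sketched as a plain registry keyed by the config's model type. This is an illustrative simplification with hypothetical names (`register_model`, `resolve_model_class`, `_CLASS_MAPPING`), not the actual Petals source:

```python
# Sketch of config-based auto-class dispatch (hypothetical names, not the
# real Petals code): each model subpackage registers its distributed class
# under its config's model_type when it is imported.

_CLASS_MAPPING = {}  # model_type -> distributed model class


def register_model(model_type, cls):
    """Called by each model subpackage at import time."""
    _CLASS_MAPPING[model_type] = cls


class DistributedLlamaForCausalLM:
    pass


class DistributedBloomForCausalLM:
    pass


register_model("llama", DistributedLlamaForCausalLM)
register_model("bloom", DistributedBloomForCausalLM)


def resolve_model_class(config):
    """Look up the distributed implementation for a model config."""
    model_type = config["model_type"]
    if model_type not in _CLASS_MAPPING:
        raise ValueError(f"No distributed implementation for {model_type!r}")
    return _CLASS_MAPPING[model_type]
```

With this pattern, `resolve_model_class({"model_type": "llama"})` returns the Llama class, and an unregistered model type raises a clear error instead of silently falling back.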

The actual loading logic lives in FromPretrainedMixin.from_pretrained, which:

  • Forces low_cpu_mem_usage=True to minimize RAM usage during shard loading
  • Sets torch_dtype="auto" if not specified
  • Calls the parent HuggingFace from_pretrained, which invokes __init__ on the distributed model class
  • Within __init__, the distributed model replaces the standard transformer block list with a RemoteSequential module connected to the hivemind DHT
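The first two steps amount to normalizing keyword arguments before delegating to the parent loader. A minimal sketch, assuming a hypothetical helper (`normalize_loading_kwargs`) that mirrors the described behavior rather than reproducing the actual mixin:

```python
# Sketch of the kwargs-forcing step (simplified; not the actual source):
# low_cpu_mem_usage defaults to True and torch_dtype defaults to "auto"
# unless the caller explicitly overrides them.

def normalize_loading_kwargs(low_cpu_mem_usage=None, torch_dtype=None, **kwargs):
    """Return the kwargs that would be passed to the parent from_pretrained."""
    if low_cpu_mem_usage is None:
        low_cpu_mem_usage = True   # minimize RAM while shards load
    if torch_dtype is None:
        torch_dtype = "auto"       # resolve dtype from the model config
    kwargs.update(low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype)
    return kwargs
```

Explicit caller values survive: `normalize_loading_kwargs(torch_dtype="float16")` keeps `"float16"` while still forcing `low_cpu_mem_usage=True`.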

Usage

Import this class when you need to load any supported large language model for distributed text generation. This is the primary entry point for client-side usage of Petals.

Code Reference

Source Location

  • Repository: petals
  • File: src/petals/utils/auto_config.py (AutoDistributedModelForCausalLM at L90-92, _AutoDistributedBase.from_pretrained at L32-52)
  • File: src/petals/client/from_pretrained.py (FromPretrainedMixin.from_pretrained at L17-39)

Signature

class AutoDistributedModelForCausalLM(_AutoDistributedBase):
    _mapping_field = "model_for_causal_lm"

class _AutoDistributedBase:
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        **kwargs,
    ) -> PreTrainedModel:
        """
        Resolves the correct distributed model class from the model config
        and delegates to its from_pretrained method.
        """

class FromPretrainedMixin:
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        **kwargs,
    ):
        """
        Overrides HuggingFace from_pretrained to force low_cpu_mem_usage=True
        and torch_dtype="auto" for distributed loading.
        """

Import

from petals import AutoDistributedModelForCausalLM

I/O Contract

Inputs

Name | Type | Required | Description
model_name_or_path | str | Yes | HuggingFace model repository name (e.g. "petals-team/StableBeluga2")
torch_dtype | str or torch.dtype | No | Data type for weights; defaults to "auto", which resolves the dtype from the model config
low_cpu_mem_usage | bool | No | Forced to True by Petals to minimize RAM during loading
token | Optional[str] | No | HuggingFace authentication token for gated models
dht | DHT | No | Pre-existing DHT instance; if None, one is auto-created from PUBLIC_INITIAL_PEERS

Outputs

Name | Type | Description
model | DistributedModelForCausalLM | Distributed model instance (e.g. DistributedLlamaForCausalLM) with RemoteSequential as its transformer layers; embeddings and LM head are loaded locally
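The shape of the output object can be illustrated with a toy mock: local embeddings and LM head, with the entire transformer block list replaced by a single proxy. All names here (`RemoteSequentialProxy`, `ToyDistributedModel`) are invented for illustration and stand in for Petals' real classes:

```python
# Toy mock of the resulting model structure (illustrative only): the
# transformer block list is swapped for one proxy object that, in Petals,
# would forward activations to remote servers over the DHT.

class RemoteSequentialProxy:
    """Stand-in for RemoteSequential: one object replacing N blocks."""
    def __call__(self, hidden_states):
        # Real Petals routes this through remote servers; here we just
        # pass activations through unchanged.
        return hidden_states


class ToyDistributedModel:
    def __init__(self):
        self.embed = lambda token_ids: [float(t) for t in token_ids]  # local
        self.layers = RemoteSequentialProxy()                         # remote
        self.lm_head = lambda h: [x * 2 for x in h]                   # local

    def forward(self, token_ids):
        return self.lm_head(self.layers(self.embed(token_ids)))


model = ToyDistributedModel()
print(model.forward([1, 2, 3]))  # → [2.0, 4.0, 6.0]
```

The point of this split is that only the (comparatively small) embeddings and LM head occupy local memory, while the bulk of the parameters live on swarm servers.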

Usage Examples

Basic Model Loading

from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer

model_name = "petals-team/StableBeluga2"

# Load distributed model - only downloads embeddings and LM head locally
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

# Load tokenizer normally
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The model's transformer layers are RemoteSequential proxies
print(type(model.model.layers))  # <class 'petals.client.remote_sequential.RemoteSequential'>

Loading with Custom DHT

from hivemind import DHT
from petals import AutoDistributedModelForCausalLM

# Connect to a private swarm
dht = DHT(initial_peers=["/ip4/1.2.3.4/tcp/31337/p2p/QmExample..."], start=True)

model = AutoDistributedModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-chat-hf",
    dht=dht,
    token="hf_YOUR_TOKEN",
)

Related Pages

Implements Principle

Requires Environment
