Implementation: bigscience-workshop/petals AutoDistributedModelForCausalLM.from_pretrained
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, NLP, Model_Loading |
| Last Updated | 2026-02-09 14:00 GMT |
Overview
Concrete tool for loading distributed causal language models provided by the Petals library.
Description
AutoDistributedModelForCausalLM is an auto-class that resolves the correct distributed model implementation (Llama, Bloom, Falcon, or Mixtral) based on the model configuration. It delegates to _AutoDistributedBase.from_pretrained, which looks up the model type in a class mapping populated at import time by each model subpackage.
The actual loading logic lives in FromPretrainedMixin.from_pretrained, which:
- Forces low_cpu_mem_usage=True to minimize RAM usage during shard loading
- Sets torch_dtype="auto" if not specified
- Calls the parent HuggingFace from_pretrained, which triggers __init__ on the distributed model class
The distributed model's __init__ then replaces the standard transformer block list with a RemoteSequential module connected to the hivemind DHT.
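The kwargs-forcing behavior of the first two steps can be sketched like this. This is a hedged illustration of the override pattern, not Petals' source: `BaseFromPretrained` is a hypothetical stand-in for the HuggingFace parent class, and it echoes its kwargs so the override's effect is visible.

```python
# Sketch of an override that forces/defaults kwargs before delegating
# to the parent from_pretrained. Names are illustrative, not Petals' API.

class BaseFromPretrained:
    @classmethod
    def from_pretrained(cls, model_name_or_path, **kwargs):
        return kwargs  # stand-in for HF loading; echo kwargs for visibility


class ForcedKwargsMixin(BaseFromPretrained):
    @classmethod
    def from_pretrained(cls, model_name_or_path, *,
                        low_cpu_mem_usage=None, torch_dtype=None, **kwargs):
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True   # minimize RAM during shard loading
        if torch_dtype is None:
            torch_dtype = "auto"       # resolve dtype from the model config
        return super().from_pretrained(
            model_name_or_path,
            low_cpu_mem_usage=low_cpu_mem_usage,
            torch_dtype=torch_dtype,
            **kwargs,
        )
```

Because the mixin sits before the HuggingFace base class in the MRO, every call path through from_pretrained picks up these defaults without callers needing to know about them.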
Usage
Import this class when you need to load any supported large language model for distributed text generation. This is the primary entry point for client-side usage of Petals.
Code Reference
Source Location
- Repository: petals
- File: src/petals/utils/auto_config.py (AutoDistributedModelForCausalLM at L90-92, _AutoDistributedBase.from_pretrained at L32-52)
- File: src/petals/client/from_pretrained.py (FromPretrainedMixin.from_pretrained at L17-39)
Signature
class AutoDistributedModelForCausalLM(_AutoDistributedBase):
    _mapping_field = "model_for_causal_lm"

class _AutoDistributedBase:
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        **kwargs,
    ) -> PreTrainedModel:
        """
        Resolves the correct distributed model class from the model config
        and delegates to its from_pretrained method.
        """

class FromPretrainedMixin:
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        **kwargs,
    ):
        """
        Overrides HuggingFace from_pretrained to force low_cpu_mem_usage=True
        and torch_dtype="auto" for distributed loading.
        """
Import
from petals import AutoDistributedModelForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_name_or_path | str or os.PathLike | Yes | HuggingFace model repository name (e.g. "petals-team/StableBeluga2") |
| torch_dtype | str or torch.dtype | No | Data type for weights; defaults to "auto" which resolves from model config |
| low_cpu_mem_usage | bool | No | Forced to True by Petals to minimize RAM during loading |
| token | Optional[str] | No | HuggingFace authentication token for gated models |
| dht | DHT | No | Pre-existing DHT instance; if None, auto-created from PUBLIC_INITIAL_PEERS |
Outputs
| Name | Type | Description |
|---|---|---|
| model | DistributedModelForCausalLM | Distributed model instance (e.g. DistributedLlamaForCausalLM) with RemoteSequential as transformer layers, local embeddings and LM head loaded |
Usage Examples
Basic Model Loading
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer
model_name = "petals-team/StableBeluga2"
# Load distributed model - only downloads embeddings and LM head locally
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
# Load tokenizer normally
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The model's transformer layers are RemoteSequential proxies
print(type(model.model.layers)) # <class 'petals.client.remote_sequential.RemoteSequential'>
Loading with Custom DHT
from hivemind import DHT
from petals import AutoDistributedModelForCausalLM
# Connect to a private swarm
dht = DHT(initial_peers=["/ip4/1.2.3.4/tcp/31337/p2p/QmExample..."], start=True)
model = AutoDistributedModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-chat-hf",
dht=dht,
token="hf_YOUR_TOKEN",
)