Principle: Microsoft DeepSpeedExamples ZeRO Model Configuration
Sources
- Doc: HuggingFace AutoConfig -- huggingface.co/docs/transformers/model_doc/auto
- Doc: HuggingFace AutoTokenizer -- huggingface.co/docs/transformers/model_doc/auto#AutoTokenizer
Domains
- NLP
- Model_Architecture
- Inference
Overview
A configuration pattern for initializing model metadata and tokenizer settings before distributed weight loading.
Description
Before initializing a model in ZeRO-Inference, the model configuration and tokenizer must be loaded independently of the model weights. This separation is essential because:
- Memory estimation: The configuration provides the architectural parameters (hidden size, number of layers, number of attention heads) needed to calculate memory requirements and configure ZeRO Stage 3 buffer sizes before any weights are loaded.
- Config overrides: Some models require configuration modifications. For example, OPT-175B's configuration is not publicly available on HuggingFace, so it must be constructed by loading OPT-66B's configuration and overriding the dimension parameters to match the 175B variant.
- Tokenizer compatibility: The tokenizer must be loaded from a compatible model when the exact model weights are unavailable. OPT-175B uses the OPT-66B tokenizer since the 175B variant is not hosted on HuggingFace.
- Model type detection: The configuration's `model_type` attribute determines which model class to instantiate (`BloomForCausalLM`, `OPTForCausalLM`, `LlamaForCausalLM`, or `AutoModelForCausalLM` for Mixtral).
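The `model_type` dispatch described above can be sketched as a small lookup table; the helper name `select_model_class` is illustrative and not from the DeepSpeedExamples repository:

```python
# Sketch: mapping config.model_type to a model class, per the list above.
# The helper name is illustrative, not from the DeepSpeedExamples repo.
from transformers import (
    AutoModelForCausalLM,
    BloomForCausalLM,
    LlamaForCausalLM,
    OPTForCausalLM,
)

_MODEL_CLASSES = {
    "bloom": BloomForCausalLM,
    "opt": OPTForCausalLM,
    "llama": LlamaForCausalLM,
}

def select_model_class(model_type):
    # Mixtral (and any unlisted type) falls back to AutoModelForCausalLM.
    return _MODEL_CLASSES.get(model_type, AutoModelForCausalLM)
```

In practice this would be called as `select_model_class(config.model_type)` after the configuration has been loaded with `AutoConfig.from_pretrained`.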
Configuration Flow
The configuration loading follows this sequence:
- Load `AutoConfig` from the HuggingFace model identifier (or an override source for special cases).
- Apply any necessary dimension overrides (e.g., OPT-175B).
- Ensure `model_type` is correctly set (e.g., explicitly set to `'bloom'` for BLOOM models).
- Load the tokenizer from the model identifier (or a compatible substitute).
- Set `pad_token = eos_token` for consistent padding behavior.
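The steps above can be sketched as one helper. The function name and the injectable loader parameters are illustrative (they allow the flow to be exercised without downloading weights); the underlying entry points are the real `AutoConfig.from_pretrained` and `AutoTokenizer.from_pretrained`:

```python
# Sketch of the configuration flow above. The loader parameters default
# to the real HuggingFace entry points but can be swapped out for tests.
from transformers import AutoConfig, AutoTokenizer

def load_config_and_tokenizer(model_name, config_source=None, overrides=None,
                              config_loader=AutoConfig.from_pretrained,
                              tokenizer_loader=AutoTokenizer.from_pretrained):
    # 1. Load the config, optionally from an override source
    #    (e.g. OPT-175B reads facebook/opt-66b's config).
    config = config_loader(config_source or model_name)
    # 2./3. Apply dimension or model_type overrides.
    for key, value in (overrides or {}).items():
        setattr(config, key, value)
    # 4. Load the tokenizer, from a compatible substitute if needed.
    tokenizer = tokenizer_loader(config_source or model_name)
    # 5. Use EOS as the padding token for consistent padding behavior.
    tokenizer.pad_token = tokenizer.eos_token
    return config, tokenizer
```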
Supported Model Types
| Model Type | Config Source | Tokenizer Source | Model Class |
|---|---|---|---|
| OPT (up to 66B) | Direct from HuggingFace | Direct from HuggingFace | `OPTForCausalLM` |
| OPT-175B | `facebook/opt-66b` + overrides | `facebook/opt-66b` | `OPTForCausalLM` |
| BLOOM | Direct from HuggingFace | Direct from HuggingFace | `BloomForCausalLM` |
| LLaMA 2 | Direct from HuggingFace | Direct from HuggingFace | `LlamaForCausalLM` |
| Mixtral | Direct from HuggingFace (`trust_remote_code=True`) | Direct from HuggingFace | `AutoModelForCausalLM` |
Theoretical Basis
Model configuration precedes weight loading in distributed settings because memory must be pre-allocated across devices. In ZeRO Stage 3, the DeepSpeed configuration requires knowledge of the model's hidden size to set:
- `stage3_prefetch_bucket_size`: `2 * hidden_size * hidden_size` -- the size of prefetch buffers for parameter gathering.
- `stage3_param_persistence_threshold`: `hidden_size` -- parameters smaller than this threshold are kept on all ranks.
- `stage3_max_live_parameters`: `2 * hidden_size * hidden_size` -- the maximum number of parameters materialized at any time.
These values directly depend on architectural dimensions extracted from the configuration, making configuration loading a prerequisite for DeepSpeed initialization.
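As a minimal sketch, these three settings can be computed directly from the loaded config's hidden size; the helper name is illustrative, and the dict shown is only the `zero_optimization` fragment of a full DeepSpeed config:

```python
# Sketch: sizing ZeRO Stage 3 buffers from the model's hidden size,
# using the formulas above. The helper name is illustrative.
def zero3_buffer_settings(hidden_size):
    return {
        "stage": 3,
        "stage3_prefetch_bucket_size": 2 * hidden_size * hidden_size,
        "stage3_param_persistence_threshold": hidden_size,
        "stage3_max_live_parameters": 2 * hidden_size * hidden_size,
    }

# For OPT-175B, hidden_size = 12288 (read from the loaded AutoConfig).
settings = zero3_buffer_settings(12288)
```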
OPT-175B Configuration Derivation
The OPT-175B model follows the OPT architecture scaling conventions. Its configuration is derived from OPT-66B with the following overrides:
| Parameter | OPT-66B Value | OPT-175B Value |
|---|---|---|
| `hidden_size` | 9216 | 12288 |
| `word_embed_proj_dim` | 9216 | 12288 |
| `ffn_dim` | 36864 (9216 * 4) | 49152 (12288 * 4) |
| `num_attention_heads` | 72 | 96 |
| `num_hidden_layers` | 64 | 96 |
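A minimal sketch of the derivation, assuming the OPT-66B configuration has already been loaded via `AutoConfig.from_pretrained("facebook/opt-66b")`; the override dict mirrors the table, and `apply_overrides` is an illustrative helper name:

```python
# Overrides that turn an OPT-66B config into OPT-175B, per the table above.
OPT_175B_OVERRIDES = {
    "hidden_size": 12288,
    "word_embed_proj_dim": 12288,
    "ffn_dim": 4 * 12288,        # 49152, keeping the 4x FFN convention
    "num_attention_heads": 96,
    "num_hidden_layers": 96,
}

def apply_overrides(config, overrides):
    # `config` is expected to be the loaded facebook/opt-66b AutoConfig.
    for key, value in overrides.items():
        setattr(config, key, value)
    return config
```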
Key Configuration Attributes
| Attribute | Type | Usage in ZeRO-Inference |
|---|---|---|
| `hidden_size` | `int` | Determines ZeRO Stage 3 buffer sizes; used for memory estimation |
| `num_hidden_layers` | `int` | Used for KV cache size calculation: `cache = 2 * batch * seq_len * layers * hidden` |
| `num_attention_heads` | `int` | Determines attention computation shape |
| `vocab_size` | `int` | Used in total model size estimation (embedding layer) |
| `model_type` | `str` | Determines which model class and NVMe buffer configuration to use |
| `torch_dtype` | `torch.dtype` | Determines FP16 vs BF16 inference; defaults to `torch.float16` if not set |
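For example, the KV cache formula in the table can be turned into a byte estimate from the loaded config; the helper name and the 2-byte FP16 element size are assumptions for illustration:

```python
# Sketch: KV cache size from config attributes, per the table's formula
# cache = 2 * batch * seq_len * layers * hidden (the 2 covers the key
# and value tensors), multiplied by bytes per element for a byte count.
def kv_cache_bytes(config, batch_size, seq_len, bytes_per_element=2):
    return (2 * batch_size * seq_len
            * config.num_hidden_layers
            * config.hidden_size
            * bytes_per_element)
```

For OPT-175B dimensions (96 layers, hidden size 12288) at batch size 1 and sequence length 2048, this gives roughly 9.7 GB of FP16 KV cache.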