Implementation:Microsoft DeepSpeedExamples Load Model SuperOffload
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation |
| Title | Load_Model_SuperOffload |
| Repository | Microsoft/DeepSpeedExamples |
| Type | Direct Function |
| Code Reference | File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 127-158 |
| Import | Direct functions in finetune_zero3.py |
| Related Principle | Principle:Microsoft_DeepSpeedExamples_Large_Model_Loading |
Overview
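Concrete tool for loading HuggingFace causal language models for SuperOffload fine-tuning. Models are loaded in BF16 with a configurable attention backend (e.g., Flash Attention 2), and gradient checkpointing is optional. The page documents three functions (load_tokenizer, load_model, and setup_model_training) and one helper (detect_moe_model).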
Function: load_tokenizer
Signature
```python
def load_tokenizer(model_name: str, logger: logging.Logger) -> AutoTokenizer:
```
Code Reference: File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 127-134
Description
Loads a HuggingFace tokenizer by model name and ensures the pad_token is set. If the tokenizer does not define a pad_token, it is set to eos_token.
Implementation
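```python
import logging
from transformers import AutoTokenizer

def load_tokenizer(model_name: str, logger: logging.Logger) -> AutoTokenizer:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Some causal LM tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.debug(f"Set pad_token to eos_token: {tokenizer.eos_token}")
    return tokenizer
```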
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| model_name | str | HuggingFace model identifier (e.g., "meta-llama/Llama-3.1-8B") |
| logger | logging.Logger | Logger instance for debug output |
Returns: a tokenizer instance with pad_token guaranteed to be set. AutoTokenizer resolves the concrete class, for example PreTrainedTokenizerFast.
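The pad_token fallback matters because a batch of sequences must be padded to a common length. The following plain-Python sketch shows what that padding looks like. It is not code from finetune_zero3.py: pad_batch is a hypothetical helper, and the token IDs are made up. Because pad_token reuses eos_token, the pad ID equals the EOS ID. As a result, the attention mask distinguishes padding from a genuine end-of-sequence token; the token ID alone cannot.

```python
from typing import List, Tuple

def pad_batch(seqs: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-ID sequences to the longest length (hypothetical helper).

    Returns (input_ids, attention_mask). The mask is 1 for real tokens and 0 for padding.
    """
    max_len = max(len(s) for s in seqs)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return input_ids, attention_mask

EOS_ID = 2  # made-up ID; with pad_token = eos_token, the pad ID equals EOS_ID
ids, mask = pad_batch([[5, 6, 7, EOS_ID], [8, EOS_ID]], pad_id=EOS_ID)
print(ids)   # [[5, 6, 7, 2], [8, 2, 2, 2]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 0, 0]]
```

When building labels, padding positions are commonly set to -100 so that Hugging Face causal-LM losses ignore them. This page does not show whether finetune_zero3.py does so.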
Function: load_model
Signature
```python
def load_model(
    model_name: str,
    attn_implementation: str,
    logger: logging.Logger
) -> AutoModelForCausalLM:
```
Code Reference: File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 137-147
Description
Loads a HuggingFace causal language model with BF16 precision and the specified attention implementation.
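The BF16 load can be sized roughly as follows. Each bfloat16 parameter takes 2 bytes, so weight memory is about n_params × 2 bytes. The helper below is a hypothetical sketch. The 8e9 figure is a round approximation for an "8B" model. The estimate covers weights only; gradients, optimizer state, and activations are outside what load_model allocates.

```python
def weight_memory_gib(n_params: float, bytes_per_param: int = 2) -> float:
    """Approximate weight memory in GiB (bfloat16 = 2 bytes per parameter)."""
    return n_params * bytes_per_param / 2**30

# Round figure for an "8B" model; real parameter counts differ slightly.
print(f"bf16: {weight_memory_gib(8e9):.1f} GiB")     # bf16: 14.9 GiB
print(f"fp32: {weight_memory_gib(8e9, 4):.1f} GiB")  # fp32: 29.8 GiB
```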
Implementation
```python
import logging
import torch
from transformers import AutoModelForCausalLM

def load_model(model_name: str, attn_implementation: str,
               logger: logging.Logger) -> AutoModelForCausalLM:
    logger.debug(f"Loading model: {model_name}")
    logger.debug(f"Attention implementation: {attn_implementation}")
    # Load weights directly in bfloat16 with the requested attention backend.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation
    )
    return model
```
I/O Contract
| Parameter | Type | Description | Default |
|---|---|---|---|
| model_name | str | HuggingFace model identifier | (required) |
| attn_implementation | str | Attention backend: "eager", "sdpa", or "flash_attention_2" | (required) |
| logger | logging.Logger | Logger instance | (required) |
Returns: the loaded causal LM with BF16 weights and the specified attention implementation. AutoModelForCausalLM selects the concrete PreTrainedModel subclass, such as LlamaForCausalLM.
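Requesting "flash_attention_2" when the flash-attn package is not installed makes from_pretrained raise an error. A caller may therefore want to validate the choice before calling load_model. The function choose_attn_implementation below is a hypothetical helper and is not part of finetune_zero3.py. It checks for the package with importlib.util.find_spec("flash_attn") and falls back to "sdpa". Note that find_spec only shows that the package is importable. It does not confirm that the GPU or dtype is supported.

```python
import importlib.util
from typing import Optional

VALID_BACKENDS = ("eager", "sdpa", "flash_attention_2")

def choose_attn_implementation(requested: str,
                               flash_attn_available: Optional[bool] = None) -> str:
    """Return a usable attention backend name (hypothetical helper)."""
    if requested not in VALID_BACKENDS:
        raise ValueError(f"unknown attn_implementation: {requested!r}")
    if flash_attn_available is None:
        # Only checks that flash_attn is importable, not hardware/dtype support.
        flash_attn_available = importlib.util.find_spec("flash_attn") is not None
    if requested == "flash_attention_2" and not flash_attn_available:
        return "sdpa"  # fall back to PyTorch scaled dot-product attention
    return requested

print(choose_attn_implementation("flash_attention_2", flash_attn_available=False))  # sdpa
```

The result could then be passed as the attn_implementation argument of load_model.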
Attention Implementation Options
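| Implementation | Description | Attention memory (N = sequence length) | Speed vs. eager |
|---|---|---|---|
| eager | Standard PyTorch attention; materializes the full N x N score matrix | O(N^2) | Baseline |
| sdpa | PyTorch 2.0+ scaled dot-product attention (torch.nn.functional.scaled_dot_product_attention) | O(N) when a fused (flash or memory-efficient) kernel is selected; O(N^2) with the math fallback | Roughly ~2x faster; varies with hardware, dtype, and sequence length |
| flash_attention_2 | Flash Attention 2; requires the flash-attn package, a supported GPU, and fp16/bf16 weights | O(N) | Roughly ~2-4x faster; varies with hardware and sequence length |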
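A small calculation makes the O(N^2) entry concrete. Eager attention materializes a score tensor of shape (batch, heads, N, N). The sketch below computes the size of that tensor. The 32-head figure matches num_attention_heads in Llama-3.1-8B's config, and bf16 (2 bytes) matches the dtype used by load_model. The result counts only one score tensor for one layer and ignores softmax outputs and other intermediates. Treat it as a lower bound, not a measured value.

```python
def attn_score_bytes(batch: int, heads: int, seq_len: int,
                     bytes_per_elem: int = 2) -> int:
    """Bytes for one (batch, heads, N, N) attention-score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

for n in (2048, 8192, 32768):
    gib = attn_score_bytes(batch=1, heads=32, seq_len=n) / 2**30
    print(f"N={n}: {gib:g} GiB per layer")
# N=2048: 0.25 GiB per layer
# N=8192: 4 GiB per layer
# N=32768: 64 GiB per layer
```

Quadrupling N multiplies this figure by 16. This growth is why fused kernels that never store the full matrix matter for long-context fine-tuning.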
Function: setup_model_training
Signature
```python
def setup_model_training(
    model: torch.nn.Module,
    use_activation_checkpointing: bool = True,
    logger: logging.Logger = None
) -> None:
```
Code Reference: File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 150-158
Description
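Configures a loaded model for training, modifying it in place. When use_activation_checkpointing is True, the function does two things. First, it disables the KV cache (model.config.use_cache = False) if the config defines use_cache; the cache only speeds up autoregressive generation and conflicts with gradient checkpointing during training. Second, it enables gradient checkpointing (also called activation checkpointing) in non-reentrant mode. When use_activation_checkpointing is False, the model is left unchanged.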
Implementation
```python
import logging
import torch

def setup_model_training(model: torch.nn.Module,
                         use_activation_checkpointing: bool = True,
                         logger: logging.Logger = None) -> None:
    if use_activation_checkpointing:
        if logger:
            logger.debug("Enabling gradient checkpointing...")
        # The KV cache is only useful for generation; turn it off for training.
        if hasattr(model.config, 'use_cache'):
            model.config.use_cache = False
        # Non-reentrant activation checkpointing (torch.utils.checkpoint).
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
```
I/O Contract
| Parameter | Type | Description | Default |
|---|---|---|---|
| model | torch.nn.Module | The loaded model to configure | (required) |
| use_activation_checkpointing | bool | Whether to enable gradient checkpointing (and disable the KV cache) | True |
| logger | logging.Logger | Optional logger instance | None |
Returns: None (modifies model in-place)
Side effects (only when use_activation_checkpointing is True):
- Sets model.config.use_cache = False (disables the KV cache), if the config defines use_cache
- Calls model.gradient_checkpointing_enable() with use_reentrant=False
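The conditional behavior can be checked without downloading a model. The sketch below reproduces the documented logic and runs it against FakeModel, a hypothetical stand-in that only records what it is asked to do. FakeModel is not a Hugging Face class. The sketch shows that with use_activation_checkpointing=False the model is left untouched.

```python
import logging
from types import SimpleNamespace
from typing import Optional

class FakeModel:
    """Hypothetical stand-in exposing only what setup_model_training touches."""
    def __init__(self) -> None:
        self.config = SimpleNamespace(use_cache=True)
        self.checkpointing_kwargs = None  # recorded when checkpointing is enabled

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None) -> None:
        self.checkpointing_kwargs = gradient_checkpointing_kwargs

def setup_model_training(model, use_activation_checkpointing: bool = True,
                         logger: Optional[logging.Logger] = None) -> None:
    # Same logic as the documented implementation.
    if use_activation_checkpointing:
        if logger:
            logger.debug("Enabling gradient checkpointing...")
        if hasattr(model.config, "use_cache"):
            model.config.use_cache = False
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

on, off = FakeModel(), FakeModel()
setup_model_training(on)
setup_model_training(off, use_activation_checkpointing=False)
print(on.config.use_cache, on.checkpointing_kwargs)    # False {'use_reentrant': False}
print(off.config.use_cache, off.checkpointing_kwargs)  # True None
```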
Helper Function: detect_moe_model
Code Reference: File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 106-115
```python
from transformers import AutoModelForCausalLM

def detect_moe_model(model: AutoModelForCausalLM, model_name: str) -> bool:
    # Config attributes that MoE architectures define; any match counts as MoE.
    moe_config_attrs = [
        'num_local_experts', 'moe_layers', 'num_experts',
        'expert_capacity', 'router_aux_loss_coef'
    ]
    for attr in moe_config_attrs:
        if hasattr(model.config, attr):
            return True
    return False
```
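This function checks whether the loaded model is a Mixture-of-Experts (MoE) architecture by looking for MoE-specific attributes on model.config. It returns True as soon as any one of them is present. The model_name argument is not used in the code shown.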
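The check can be exercised without loading weights by passing objects whose config mimics a dense model and an MoE model. The configs below are hypothetical SimpleNamespace stand-ins, not real Hugging Face config classes. The function body is copied from the listing, with the type hint relaxed so that transformers is not needed.

```python
from types import SimpleNamespace

def detect_moe_model(model, model_name: str) -> bool:
    # Any one of these config attributes is treated as evidence of MoE.
    moe_config_attrs = [
        'num_local_experts', 'moe_layers', 'num_experts',
        'expert_capacity', 'router_aux_loss_coef'
    ]
    for attr in moe_config_attrs:
        if hasattr(model.config, attr):
            return True
    return False

# Hypothetical stand-ins: only the .config attribute is inspected.
dense = SimpleNamespace(config=SimpleNamespace(hidden_size=4096))
moe = SimpleNamespace(config=SimpleNamespace(num_local_experts=8))
print(detect_moe_model(dense, "dense-model"))  # False
print(detect_moe_model(moe, "moe-model"))      # True
```

Because hasattr only tests whether an attribute exists, a config that defines one of these attributes with a None or zero value would still be reported as MoE.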
Invocation in Main Script
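In the main() function (Lines 239-245), the script loads the tokenizer and model, optionally registers a ZeRO-3 leaf module, and prepares the model for training. set_z3_leaf_modules (from deepspeed.utils) flags modules of the given class so that ZeRO-3 stops installing hooks recursively inside them. Instead, ZeRO-3 prefetches all of a flagged module's parameters when entering it. This is typically used for MoE expert blocks, whose execution order varies between forward passes.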
```python
tokenizer = load_tokenizer(args.model_name, logger)
model = load_model(args.model_name, args.attn_implementation, logger)
if args.leaf_module:
    from deepspeed.utils import set_z3_leaf_modules
    logger.debug(f"Setting leaf_module to: {args.leaf_module}")
    set_z3_leaf_modules(model, [args.leaf_module])
setup_model_training(model, args.activation_checkpointing, logger)
```
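main() reads four attributes: args.model_name, args.attn_implementation, args.leaf_module, and args.activation_checkpointing. The parser below is a hypothetical sketch of how those attributes could be defined. The flag spellings, defaults, and help text are assumptions, so check finetune_zero3.py for its actual command-line interface.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Hypothetical CLI exposing the attributes main() uses (flags and defaults assumed)."""
    p = argparse.ArgumentParser(description="SuperOffload fine-tuning (sketch)")
    p.add_argument("--model_name", required=True,
                   help="HuggingFace model identifier")
    p.add_argument("--attn_implementation", default="sdpa",
                   choices=["eager", "sdpa", "flash_attention_2"])
    p.add_argument("--leaf_module", default=None,
                   help="Module class name to mark as a ZeRO-3 leaf (optional)")
    p.add_argument("--activation_checkpointing", action="store_true",
                   help="Enable gradient (activation) checkpointing")
    return p

args = build_parser().parse_args(
    ["--model_name", "meta-llama/Llama-3.1-8B",
     "--attn_implementation", "flash_attention_2",
     "--activation_checkpointing"]
)
print(args.attn_implementation, args.leaf_module, args.activation_checkpointing)
# flash_attention_2 None True
```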
Usage Example
```python
import logging

# Assumes load_tokenizer, load_model and setup_model_training
# (defined in finetune_zero3.py) are in scope.
logging.basicConfig(level=logging.DEBUG)  # show the functions' debug messages
logger = logging.getLogger("finetune_zero3")
# Step 1: Load tokenizer
tokenizer = load_tokenizer("meta-llama/Llama-3.1-8B", logger)
# Step 2: Load model with Flash Attention 2 in BF16
model = load_model("meta-llama/Llama-3.1-8B", "flash_attention_2", logger)
# Step 3: Configure for training (enable gradient checkpointing, disable KV cache)
setup_model_training(model, use_activation_checkpointing=True, logger=logger)
# Model is now ready for DeepSpeed initialization
```