Implementation:Microsoft DeepSpeedExamples Load Model SuperOffload

From Leeroopedia


Metadata

Field | Value
Page Type | Implementation
Title | Load_Model_SuperOffload
Repository | Microsoft/DeepSpeedExamples
Type | Direct Function
Code Reference | File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 127-158
Import | Direct functions in finetune_zero3.py
Related Principle | Principle:Microsoft_DeepSpeedExamples_Large_Model_Loading

Overview

Concrete tool for loading HuggingFace causal language models for SuperOffload fine-tuning, with a selectable attention backend (including Flash Attention 2) and optional gradient checkpointing. The page documents three functions, load_tokenizer, load_model, and setup_model_training, plus the helper detect_moe_model.

Function: load_tokenizer

Signature

def load_tokenizer(model_name: str, logger: logging.Logger) -> AutoTokenizer:

Code Reference: File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 127-134

Description

Loads a HuggingFace tokenizer by model name and ensures the pad_token is set. If the tokenizer does not define a pad_token, it is set to eos_token.

Implementation

def load_tokenizer(model_name: str, logger: logging.Logger) -> AutoTokenizer:
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.debug(f"Set pad_token to eos_token: {tokenizer.eos_token}")

    return tokenizer

I/O Contract

Parameter | Type | Description
model_name | str | HuggingFace model identifier (e.g., "meta-llama/Llama-3.1-8B")
logger | logging.Logger | Logger instance for debug output

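Returns: the tokenizer instance produced by AutoTokenizer.from_pretrained. Its pad_token is set whenever the checkpoint defines either a pad_token or an eos_token; if it defines neither, pad_token remains None.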
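
The fallback above can be tried without downloading a checkpoint. The following is a minimal sketch, assuming stand-in objects in place of real tokenizers; ensure_pad_token is a hypothetical helper that mirrors the conditional in load_tokenizer and is not part of finetune_zero3.py.

```python
import logging
from types import SimpleNamespace


def ensure_pad_token(tokenizer, logger: logging.Logger) -> bool:
    """Apply the same fallback as load_tokenizer; report whether pad_token ends up set."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.debug(f"Set pad_token to eos_token: {tokenizer.eos_token}")
    return tokenizer.pad_token is not None


logger = logging.getLogger("pad_token_demo")

# Stand-in objects (assumption): they only carry the two attributes the logic reads.
no_pad = SimpleNamespace(pad_token=None, eos_token="</s>")
has_pad = SimpleNamespace(pad_token="<pad>", eos_token="</s>")
no_eos = SimpleNamespace(pad_token=None, eos_token=None)

results = {
    "no_pad": ensure_pad_token(no_pad, logger),    # falls back to eos_token
    "has_pad": ensure_pad_token(has_pad, logger),  # existing pad_token is kept
    "no_eos": ensure_pad_token(no_eos, logger),    # nothing to fall back to
}
print(results, no_pad.pad_token, has_pad.pad_token)
```

The third case shows why callers that pad batches may want to check pad_token before training when a checkpoint defines no eos_token.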

Function: load_model

Signature

def load_model(
    model_name: str,
    attn_implementation: str,
    logger: logging.Logger
) -> AutoModelForCausalLM:

Code Reference: File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 137-147

Description

Loads a HuggingFace causal language model with BF16 precision and the specified attention implementation.

Implementation

def load_model(model_name: str, attn_implementation: str,
               logger: logging.Logger) -> AutoModelForCausalLM:
    logger.debug(f"Loading model: {model_name}")
    logger.debug(f"Attention implementation: {attn_implementation}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation
    )

    return model

I/O Contract

Parameter | Type | Description | Default
model_name | str | HuggingFace model identifier | (required)
attn_implementation | str | Attention backend: "eager", "sdpa", or "flash_attention_2" | (required)
logger | logging.Logger | Logger instance | (required)

Returns: the causal language model that AutoModelForCausalLM.from_pretrained resolves for the checkpoint, loaded in BF16 precision with the specified attention implementation.
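
Because attn_implementation is required, the caller has to decide which backend to request. One possible approach, sketched below, is to fall back to "sdpa" when the flash-attn package is not installed. pick_attn_implementation is a hypothetical helper, not part of finetune_zero3.py, and it only checks that the flash_attn module can be found, not that the GPU supports it.

```python
import importlib.util


def pick_attn_implementation(preferred: str = "flash_attention_2",
                             fallback: str = "sdpa") -> str:
    """Return `preferred` unless it needs flash-attn and that package is absent."""
    if preferred == "flash_attention_2":
        # The flash-attn distribution installs a top-level module named `flash_attn`.
        if importlib.util.find_spec("flash_attn") is None:
            return fallback
    return preferred


chosen = pick_attn_implementation()
print(f"attn_implementation={chosen}")
```

The returned string can then be passed as the second argument of load_model.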

Attention Implementation Options

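Implementation | Description | Attention memory (sequence length N) | Speed (approximate; depends on hardware and workload)
eager | Standard PyTorch attention; materializes the full N x N score matrix | O(N^2) | Baseline
sdpa | PyTorch 2.0+ Scaled Dot-Product Attention | O(N) when a fused kernel is selected; O(N^2) on the math fallback | ~2x faster
flash_attention_2 | Flash Attention 2 (requires the flash-attn package and a supported GPU) | O(N) | ~2-4x faster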
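
To make the O(N^2) entry concrete, the sketch below estimates the size of the attention score matrix that eager attention materializes for one layer. The batch size of 1 and the 32 heads are illustrative assumptions, and the figure ignores every other activation.

```python
def attention_scores_bytes(batch: int, heads: int, seq_len: int,
                           bytes_per_element: int = 2) -> int:
    """Bytes for one (batch, heads, seq_len, seq_len) score matrix; BF16 uses 2 bytes."""
    return batch * heads * seq_len * seq_len * bytes_per_element


for n in (2048, 4096, 8192):
    gib = attention_scores_bytes(batch=1, heads=32, seq_len=n) / 2**30
    print(f"seq_len={n}: {gib:.2f} GiB per layer")
# seq_len=2048: 0.25 GiB per layer
# seq_len=4096: 1.00 GiB per layer
# seq_len=8192: 4.00 GiB per layer
```

Doubling the sequence length quadruples this buffer, which is the cost the fused kernels avoid by never storing the full matrix.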

Function: setup_model_training

Signature

def setup_model_training(
    model: torch.nn.Module,
    use_activation_checkpointing: bool = True,
    logger: logging.Logger = None
) -> None:

Code Reference: File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 150-158

Description

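Configures a loaded model for training. When use_activation_checkpointing is True, it disables the KV cache (sets model.config.use_cache = False if the config has that attribute) and enables non-reentrant gradient checkpointing. When it is False, the function does nothing. All changes are made to the model in-place.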

Implementation

def setup_model_training(model: torch.nn.Module,
                         use_activation_checkpointing: bool = True,
                         logger: logging.Logger = None) -> None:
    if use_activation_checkpointing:
        if logger:
            logger.debug("Enabling gradient checkpointing...")
        if hasattr(model.config, 'use_cache'):
            model.config.use_cache = False
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

I/O Contract

Parameter | Type | Description | Default
model | torch.nn.Module | The loaded model to configure | (required)
use_activation_checkpointing | bool | Whether to enable gradient checkpointing | True
logger | logging.Logger | Optional logger instance | None

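Returns: None (modifies model in-place)

Side effects (only when use_activation_checkpointing is True):

  • Sets model.config.use_cache = False if the config defines use_cache (disables the KV cache, which is not needed during training and is incompatible with gradient checkpointing)
  • Calls model.gradient_checkpointing_enable() with use_reentrant=False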
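
The side effects listed above can be observed without loading a real model. This is a minimal sketch, assuming a stand-in FakeModel class (hypothetical) that records the call; the body of setup_model_training is copied from the listing above with the type hints relaxed so that the stub is accepted.

```python
from types import SimpleNamespace


class FakeModel:
    """Stand-in that records calls; not a real transformers model."""

    def __init__(self):
        self.config = SimpleNamespace(use_cache=True)
        self.checkpointing_kwargs = None

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.checkpointing_kwargs = gradient_checkpointing_kwargs


def setup_model_training(model, use_activation_checkpointing=True, logger=None):
    # Same body as the documented function, with type hints relaxed for the stub.
    if use_activation_checkpointing:
        if logger:
            logger.debug("Enabling gradient checkpointing...")
        if hasattr(model.config, 'use_cache'):
            model.config.use_cache = False
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )


enabled, skipped = FakeModel(), FakeModel()
setup_model_training(enabled, use_activation_checkpointing=True)
setup_model_training(skipped, use_activation_checkpointing=False)

print(enabled.config.use_cache, enabled.checkpointing_kwargs)  # False {'use_reentrant': False}
print(skipped.config.use_cache, skipped.checkpointing_kwargs)  # True None
```

With the flag set to False the model is left untouched, including use_cache.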

Helper Function: detect_moe_model

Code Reference: File: training/DeepSpeed-SuperOffload/finetune_zero3.py, Lines 106-115

def detect_moe_model(model: AutoModelForCausalLM, model_name: str) -> bool:
    moe_config_attrs = [
        'num_local_experts', 'moe_layers', 'num_experts',
        'expert_capacity', 'router_aux_loss_coef'
    ]

    for attr in moe_config_attrs:
        if hasattr(model.config, attr):
            return True
    return False

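This function reports whether the loaded model looks like a Mixture-of-Experts (MoE) architecture. It returns True as soon as the model config has any of the listed MoE-specific attributes, and False otherwise. The check uses hasattr, so an attribute that is present but set to None still counts. The model_name parameter is accepted but not used in the body shown.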
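
The detection rule can be exercised with stand-in configs. The sketch below copies the function body from the listing above, with the model type hint dropped so that transformers is not needed; the stub objects and their names are assumptions made for illustration.

```python
from types import SimpleNamespace


def detect_moe_model(model, model_name: str) -> bool:
    # Same body as the documented helper; the model type hint is relaxed for the stubs.
    moe_config_attrs = [
        'num_local_experts', 'moe_layers', 'num_experts',
        'expert_capacity', 'router_aux_loss_coef'
    ]

    for attr in moe_config_attrs:
        if hasattr(model.config, attr):
            return True
    return False


# Stand-ins (assumption): only the `config` attribute is needed.
dense = SimpleNamespace(config=SimpleNamespace(hidden_size=4096))
moe = SimpleNamespace(config=SimpleNamespace(num_local_experts=8))
none_valued = SimpleNamespace(config=SimpleNamespace(num_experts=None))

print(detect_moe_model(dense, "dense-stub"))        # False
print(detect_moe_model(moe, "moe-stub"))            # True
print(detect_moe_model(none_valued, "none-stub"))   # True: hasattr ignores the value
```

The last case is worth keeping in mind for configs that declare an MoE field on dense models and leave it empty.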

Invocation in Main Script

In the main() function (Lines 239-245):

tokenizer = load_tokenizer(args.model_name, logger)
model = load_model(args.model_name, args.attn_implementation, logger)
if args.leaf_module:
    from deepspeed.utils import set_z3_leaf_modules
    logger.debug(f"Setting leaf_module to: {args.leaf_module}")
    set_z3_leaf_modules(model, [args.leaf_module])
setup_model_training(model, args.activation_checkpointing, logger)
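
The snippet above reads four attributes from args. The sketch below shows one way such a namespace could be produced with argparse. The flag spellings, defaults, and the ExampleExpertBlock value are hypothetical and chosen only to match the attribute names used in main(); the actual script's argument parser may differ.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical flags matching the attribute names used in main();
    # the real script's flag names, defaults, and help text may differ.
    parser = argparse.ArgumentParser(description="Sketch of the arguments main() reads")
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--attn_implementation", type=str, default="sdpa",
                        choices=["eager", "sdpa", "flash_attention_2"])
    parser.add_argument("--leaf_module", type=str, default=None)
    parser.add_argument("--activation_checkpointing", action="store_true")
    return parser


# Parse an explicit list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args([
    "--model_name", "meta-llama/Llama-3.1-8B",
    "--leaf_module", "ExampleExpertBlock",  # placeholder class name
])
print(args.attn_implementation, args.leaf_module, args.activation_checkpointing)
```

With a namespace like this, the leaf-module branch runs only when --leaf_module is given, and setup_model_training receives False unless --activation_checkpointing is passed.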

Usage Example

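import logging

# Assumption: run from training/DeepSpeed-SuperOffload/ so the script is importable.
from finetune_zero3 import load_tokenizer, load_model, setup_model_training

logging.basicConfig(level=logging.DEBUG)  # make the logger.debug messages visible
logger = logging.getLogger("finetune_zero3")

# Step 1: Load tokenizer (pad_token falls back to eos_token if unset)
tokenizer = load_tokenizer("meta-llama/Llama-3.1-8B", logger)

# Step 2: Load model with Flash Attention 2 in BF16
# (requires the flash-attn package; pass "sdpa" or "eager" otherwise)
model = load_model("meta-llama/Llama-3.1-8B", "flash_attention_2", logger)

# Step 3: Configure for training (enable gradient checkpointing, disable KV cache)
setup_model_training(model, use_activation_checkpointing=True, logger=logger)

# Model is now ready for DeepSpeed initialization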
