
Implementation:Huggingface Optimum MetaAwareMethodsPatcher

From Leeroopedia

Overview

A context manager class that monkey-patches nn.Linear and nn.Embedding to force all parameter allocation onto PyTorch's meta device during model construction. Also patches forward methods to handle meta tensor inputs during FX tracing.

Source

Property | Value
File     | optimum/fx/parallelization/utils.py
Lines    | L214-295
Module   | optimum.fx.parallelization.utils

Class Definition

class MetaAwareMethodsPatcher:
    def __enter__(self):
        # Patches nn.Linear and nn.Embedding init and forward
        ...
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Unpatches __init__ only (forward patches remain)
        ...

Import

from optimum.fx.parallelization.utils import MetaAwareMethodsPatcher

Patched Methods

The context manager patches four methods across two PyTorch module classes:

Original Method       | Replacement                                  | Restored on Exit
nn.Linear.__init__    | Forces device="meta" via meta_init wrapper   | Yes
nn.Embedding.__init__ | Forces device="meta" via meta_init wrapper   | Yes
nn.Linear.forward     | meta_aware_linear_forward (L170-181)         | No
nn.Embedding.forward  | meta_aware_embedding_forward (L184-211)      | No

Note: The forward patches are not restored on exit. This is intentional, as the patched forward methods are needed during subsequent FX tracing to handle meta tensor shape propagation correctly.

Helper Function: meta_init

# Located at optimum/fx/parallelization/utils.py L161-167
import functools  # module-level import in utils.py

def meta_init(init_fn):
    """Wrapper that intercepts __init__ calls and forces device='meta'."""
    @functools.wraps(init_fn)
    def wrapper(*args, **kwargs):
        kwargs["device"] = "meta"
        return init_fn(*args, **kwargs)
    return wrapper
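The effect of the wrapper can be illustrated with a plain-Python sketch. `DummyModule` below is a hypothetical stand-in for `nn.Linear`, not part of Optimum:

```python
import functools

def meta_init(init_fn):
    """Wrapper that intercepts __init__ calls and forces device='meta'."""
    @functools.wraps(init_fn)
    def wrapper(*args, **kwargs):
        kwargs["device"] = "meta"
        return init_fn(*args, **kwargs)
    return wrapper

class DummyModule:
    # Hypothetical stand-in for nn.Linear.
    def __init__(self, features, device=None):
        self.features = features
        self.device = device

# Patch, construct, then restore -- any caller-supplied device is overridden.
original_init = DummyModule.__init__
DummyModule.__init__ = meta_init(original_init)
m = DummyModule(8, device="cpu")
DummyModule.__init__ = original_init

print(m.device)  # "meta"
```

Note that the override is unconditional: even an explicit `device="cpu"` is replaced, which is exactly what the patcher relies on during model construction.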

Patched Forward Methods

meta_aware_linear_forward (L170-181)

Handles nn.Linear.forward when parameters are on the meta device. Instead of performing the actual matrix multiplication, it computes the correct output shape and returns a meta tensor with that shape.
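A minimal sketch of this shape-propagation idea (not the actual Optimum implementation; it assumes a Linear-style weight of shape `(out_features, in_features)`):

```python
import torch

def meta_linear_forward_sketch(input, weight, bias=None):
    # Output shape of input @ weight.T: replace the last dim with out_features.
    out_shape = input.shape[:-1] + (weight.shape[0],)
    # Return a meta tensor with that shape instead of computing the matmul.
    return torch.empty(out_shape, dtype=input.dtype, device="meta")

x = torch.empty(2, 5, 8, device="meta")   # (batch, seq, in_features)
w = torch.empty(16, 8, device="meta")     # (out_features, in_features)
y = meta_linear_forward_sketch(x, w)
print(y.shape)  # torch.Size([2, 5, 16])
```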

meta_aware_embedding_forward (L184-211)

Handles nn.Embedding.forward when the embedding table is on the meta device. Returns a meta tensor with the correct output shape (input shape + embedding dimension) and appropriate dtype.
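The same idea for embeddings can be sketched as follows (a hypothetical simplification, not Optimum's code): the output shape appends the embedding dimension to the index tensor's shape, and the dtype follows the table rather than the integer indices:

```python
import torch

def meta_embedding_forward_sketch(input, weight):
    # Output shape is the index tensor's shape plus the embedding dimension.
    out_shape = input.shape + (weight.shape[1],)
    # dtype follows the embedding table, not the integer indices.
    return torch.empty(out_shape, dtype=weight.dtype, device="meta")

ids = torch.empty(4, 10, dtype=torch.long, device="meta")  # token indices
table = torch.empty(32000, 768, device="meta")             # (vocab, dim)
out = meta_embedding_forward_sketch(ids, table)
print(out.shape)  # torch.Size([4, 10, 768])
```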

Behavior

On Enter (__enter__)

  1. Save references to the original __init__ and forward methods for both nn.Linear and nn.Embedding.
  2. Replace __init__ methods with meta_init-wrapped versions.
  3. Replace forward methods with meta-aware versions.

On Exit (__exit__)

  1. Restore only the original __init__ methods for both nn.Linear and nn.Embedding.
  2. Leave forward patches in place for FX tracing compatibility.
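The asymmetric enter/exit behavior can be sketched as a generic context manager. `Target` and the patched methods below are hypothetical placeholders standing in for `nn.Linear` and its patched `__init__`/`forward`:

```python
class Target:
    # Hypothetical stand-in for nn.Linear.
    def __init__(self):
        self.tag = "original_init"
    def forward(self):
        return "original_forward"

def patched_init(self):
    self.tag = "meta_init"

def patched_forward(self):
    return "meta_aware_forward"

class AsymmetricPatcher:
    def __enter__(self):
        # 1. Save references to the original methods.
        self._init, self._fwd = Target.__init__, Target.forward
        # 2-3. Replace both __init__ and forward.
        Target.__init__, Target.forward = patched_init, patched_forward
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore __init__ only; forward stays patched for later tracing.
        Target.__init__ = self._init

with AsymmetricPatcher():
    inside = Target()   # built with the patched __init__

after = Target()        # built with the restored __init__
print(inside.tag, after.tag, after.forward())
# meta_init original_init meta_aware_forward
```

Restoring `__init__` but not `forward` is what lets modules constructed later behave normally while FX tracing over the meta-device model still gets shape-only forwards.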

Example Usage

from transformers import AutoConfig, AutoModelForCausalLM
from optimum.fx.parallelization.utils import MetaAwareMethodsPatcher

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

with MetaAwareMethodsPatcher():
    # All nn.Linear and nn.Embedding modules will be on "meta" device
    model = AutoModelForCausalLM.from_config(config)

# Verify: all parameters are meta tensors
for name, param in model.named_parameters():
    assert param.device.type == "meta"
