Implementation: Hugging Face Optimum MetaAwareMethodsPatcher
Overview
A context manager class that monkey-patches nn.Linear and nn.Embedding to force all parameter allocation onto PyTorch's meta device during model construction. Also patches forward methods to handle meta tensor inputs during FX tracing.
Source
| Property | Value |
|---|---|
| File | optimum/fx/parallelization/utils.py |
| Lines | L214-295 |
| Module | optimum.fx.parallelization.utils |
Class Definition
```python
class MetaAwareMethodsPatcher:
    def __enter__(self):
        # Patches nn.Linear and nn.Embedding __init__ and forward
        ...

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Unpatches __init__ only (forward patches remain)
        ...
```
Import
```python
from optimum.fx.parallelization.utils import MetaAwareMethodsPatcher
```
Patched Methods
The context manager patches four methods across two PyTorch module classes:
| Original Method | Replacement | Restored on Exit |
|---|---|---|
| nn.Linear.__init__ | Forces device="meta" via meta_init wrapper | Yes |
| nn.Embedding.__init__ | Forces device="meta" via meta_init wrapper | Yes |
| nn.Linear.forward | meta_aware_linear_forward (L170-181) | No |
| nn.Embedding.forward | meta_aware_embedding_forward (L184-211) | No |
Note: The forward patches are not restored on exit. This is intentional, as the patched forward methods are needed during subsequent FX tracing to handle meta tensor shape propagation correctly.
Helper Function: meta_init
```python
# Located at optimum/fx/parallelization/utils.py L161-167
import functools

def meta_init(init_fn):
    """Wrapper that intercepts __init__ calls and forces device='meta'."""
    @functools.wraps(init_fn)
    def wrapper(*args, **kwargs):
        kwargs["device"] = "meta"
        return init_fn(*args, **kwargs)
    return wrapper
```
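To illustrate how such a wrapper is applied and reverted, the standalone sketch below patches nn.Linear.__init__ directly, constructs a layer, and restores the original. This mirrors what the context manager does for __init__ on enter and exit; it is an illustrative simplification, not the patcher's actual code.

```python
import functools

import torch.nn as nn

def meta_init(init_fn):
    """Force device='meta' on every intercepted __init__ call."""
    @functools.wraps(init_fn)
    def wrapper(*args, **kwargs):
        kwargs["device"] = "meta"
        return init_fn(*args, **kwargs)
    return wrapper

# Patch, construct, then restore.
original_init = nn.Linear.__init__
nn.Linear.__init__ = meta_init(original_init)
try:
    # Allocated on the meta device: shape and dtype only, no real memory.
    layer = nn.Linear(16, 32)
finally:
    nn.Linear.__init__ = original_init
```

Because meta tensors carry only metadata, even very large layers can be "constructed" this way without allocating memory.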
Patched Forward Methods
meta_aware_linear_forward (L170-181)
Handles nn.Linear.forward when parameters are on the meta device. Instead of performing the actual matrix multiplication, it computes the correct output shape and returns a meta tensor with that shape.
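The shape-propagation idea can be sketched as follows. This is a simplified reconstruction under the description above, not the actual Optimum implementation; the function name is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def meta_linear_forward(module: nn.Linear, input: torch.Tensor) -> torch.Tensor:
    """Sketch: propagate only the output shape when weights are on meta."""
    if module.weight.device.type == "meta":
        # Linear output shape is (*input.shape[:-1], out_features);
        # no matrix multiplication is performed.
        return torch.empty(
            *input.shape[:-1], module.out_features,
            device="meta", dtype=input.dtype,
        )
    return F.linear(input, module.weight, module.bias)
```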
meta_aware_embedding_forward (L184-211)
Handles nn.Embedding.forward when the embedding table is on the meta device. Returns a meta tensor with the correct output shape (input shape + embedding dimension) and appropriate dtype.
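A corresponding sketch for the embedding case, again a simplified illustration with a hypothetical name rather than the real implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def meta_embedding_forward(module: nn.Embedding, input: torch.Tensor) -> torch.Tensor:
    """Sketch: lookup output has shape (*input.shape, embedding_dim)."""
    if module.weight.device.type == "meta":
        # No table lookup; return a meta tensor with the right shape
        # and the embedding table's dtype.
        return torch.empty(
            *input.shape, module.embedding_dim,
            device="meta", dtype=module.weight.dtype,
        )
    return F.embedding(input, module.weight)
```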
Behavior
On Enter (__enter__)
- Save references to the original __init__ and forward methods for both nn.Linear and nn.Embedding.
- Replace the __init__ methods with meta_init-wrapped versions.
- Replace the forward methods with meta-aware versions.
On Exit (__exit__)
- Restore only the original __init__ methods for both nn.Linear and nn.Embedding.
- Leave the forward patches in place for FX tracing compatibility.
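The enter/exit cycle can be sketched as a minimal context manager. This is a hypothetical simplification that patches only __init__ and restores it on exit; the real class also swaps the forward methods on enter and deliberately leaves them patched.

```python
import functools

import torch.nn as nn

def meta_init(init_fn):
    """Force device='meta' on every intercepted __init__ call."""
    @functools.wraps(init_fn)
    def wrapper(*args, **kwargs):
        kwargs["device"] = "meta"
        return init_fn(*args, **kwargs)
    return wrapper

class MiniMetaPatcher:
    """Hypothetical sketch of the __init__ patch/unpatch cycle."""

    def __enter__(self):
        # Save the originals so __exit__ can restore them.
        self._orig = {
            nn.Linear: nn.Linear.__init__,
            nn.Embedding: nn.Embedding.__init__,
        }
        for cls, init_fn in self._orig.items():
            cls.__init__ = meta_init(init_fn)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore __init__ only (the real patcher keeps forward patched).
        for cls, init_fn in self._orig.items():
            cls.__init__ = init_fn
```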
Example Usage
```python
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

from optimum.fx.parallelization.utils import MetaAwareMethodsPatcher

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

with MetaAwareMethodsPatcher():
    # All nn.Linear and nn.Embedding modules will be on the "meta" device
    model = AutoModelForCausalLM.from_config(config)

# Verify: the patched module types hold meta tensors
# (other modules, e.g. normalization layers, are not patched)
for module in model.modules():
    if isinstance(module, (nn.Linear, nn.Embedding)):
        assert module.weight.device.type == "meta"
```