Implementation:Huggingface Optimum Symbolic Trace Wrapper
Overview
This is a Wrapper Doc -- the actual implementation resides in the transformers library. The Optimum FX optimization module depends on transformers.utils.fx.symbolic_trace as a prerequisite for all graph transformations.
Source
External: transformers.utils.fx.symbolic_trace
This function is not part of the Optimum codebase itself. It is imported from the HuggingFace transformers library and used as the entry point for converting PreTrainedModel instances into torch.fx.GraphModule objects suitable for FX graph optimization.
API
transformers.utils.fx.symbolic_trace(
    model: PreTrainedModel,
    input_names: List[str],
    disable_check: bool = False,
    ...
) -> torch.fx.GraphModule
| Parameter | Type | Description |
|---|---|---|
| model | PreTrainedModel | The HuggingFace model to trace |
| input_names | List[str] | Names of the inputs to trace (e.g., ["input_ids", "attention_mask"]) |
| disable_check | bool | Whether to disable the output check that verifies traced model matches original |
Returns: A torch.fx.GraphModule containing the traced graph IR and executable code.
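One quick way to see what the returned graph IR contains is to count its nodes by opcode. The helper below is a minimal sketch: count_node_ops is a hypothetical name, not part of transformers or Optimum, and it relies only on the graph.nodes and node.op attributes that torch.fx exposes. The stand-in object exists only so the snippet runs without torch installed.

```python
from collections import Counter
from types import SimpleNamespace
from typing import Dict


def count_node_ops(graph_module) -> Dict[str, int]:
    """Count graph nodes per opcode (placeholder, call_module, output, ...)."""
    return dict(Counter(node.op for node in graph_module.graph.nodes))


# Stand-in for a traced module (assumption: illustration only).
# With a real trace, pass the GraphModule returned by symbolic_trace instead.
fake_nodes = [
    SimpleNamespace(op="placeholder"),
    SimpleNamespace(op="call_module"),
    SimpleNamespace(op="call_module"),
    SimpleNamespace(op="output"),
]
fake_traced = SimpleNamespace(graph=SimpleNamespace(nodes=fake_nodes))

counts = count_node_ops(fake_traced)
print(counts)
```

On a real GraphModule, the generated Python source is also available through its code attribute, which is handy when checking what a transformation changed.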
Import
from transformers.utils.fx import symbolic_trace
How Optimum Uses This
The Optimum FX optimization module uses symbolic_trace as a prerequisite step before applying any graph transformations. The transformers version of symbolic_trace supports HuggingFace-specific patterns that the standard torch.fx.symbolic_trace cannot handle:
- Optional outputs -- Resolves config-dependent return values at trace time
- Config-dependent branches -- Evaluates model configuration flags to select a single execution path
- HuggingFace module conventions -- Properly traces through the nested module hierarchies used by Transformer architectures
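The idea behind the first two points can be illustrated with a toy, pure-Python tracer. This is only a conceptual sketch and not how the transformers tracer is implemented: ToyConfig, Proxy, toy_forward, and toy_trace are all hypothetical names. It shows that a configuration flag is an ordinary Python value during tracing, so only the branch it selects ends up in the recorded graph.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a model config flag that changes the outputs.
    return_extra: bool = False


class Proxy:
    """Records the operations applied to it instead of computing values."""

    def __init__(self, name: str, graph: List[Tuple[str, str, str]]):
        self.name = name
        self.graph = graph

    def _record(self, op: str, other) -> "Proxy":
        other_name = other.name if isinstance(other, Proxy) else repr(other)
        out = Proxy(f"v{len(self.graph)}", self.graph)
        self.graph.append((out.name, op, f"{self.name}, {other_name}"))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def toy_forward(x, config: ToyConfig):
    hidden = x * 2
    if config.return_extra:  # plain bool, evaluated once while tracing
        return hidden, hidden + 1
    return hidden


def toy_trace(config: ToyConfig) -> List[Tuple[str, str, str]]:
    """Run toy_forward on a Proxy and return the recorded operations."""
    graph: List[Tuple[str, str, str]] = []
    toy_forward(Proxy("x", graph), config)
    return graph


print(toy_trace(ToyConfig(return_extra=False)))
print(toy_trace(ToyConfig(return_extra=True)))
```

Each trace contains a single execution path, so a model traced under one configuration does not carry the operations of the other branch.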
Optimum checks whether the required FX features are available via optimum.fx.utils.are_fx_features_available() before attempting any FX operations.
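Caller code can apply the same guard before tracing. The following is a minimal sketch, assuming are_fx_features_available takes no arguments and returns a truthy value, as the text above suggests. fx_tracing_supported is a hypothetical helper name, and the ImportError fallback is an added convenience rather than Optimum behavior.

```python
def fx_tracing_supported() -> bool:
    """Return True only if Optimum reports that FX features are usable."""
    try:
        # Imported lazily so this helper also works when optimum is absent.
        from optimum.fx.utils import are_fx_features_available
    except ImportError:
        return False
    return bool(are_fx_features_available())


supported = fx_tracing_supported()
if supported:
    print("FX features available: safe to call symbolic_trace")
else:
    print("FX features unavailable: skip graph optimization")
```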
Usage Example
from transformers import AutoModel
from transformers.utils.fx import symbolic_trace
model = AutoModel.from_pretrained("bert-base-uncased")
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
# traced is now a torch.fx.GraphModule ready for optimization
A more complete example showing the typical pattern used in Optimum's test suite:
from transformers import BertModel
from transformers.utils.fx import symbolic_trace
from optimum.fx.optimization import MergeLinears
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
)
# Now apply transformations on the traced GraphModule
transformation = MergeLinears()
transformed_model = transformation(traced)
Related
- implements -> Principle:Huggingface_Optimum_Model_Symbolic_Tracing