Implementation:Huggingface Optimum Get Transformed Nodes
Overview
Implementation of the transformation tracking and validation mechanisms in the Transformation base class. These methods enable inspecting which nodes a transformation modified and verifying correctness.
Source
File: optimum/fx/optimization/transformations.py
APIs
get_transformed_nodes (L161-172)
Returns the list of graph nodes that were modified by this specific transformation instance.
def get_transformed_nodes(self, graph_module: GraphModule) -> List[Node]:
    return [node for node in graph_module.graph.nodes if self.transformed(node)]
| Parameter | Type | Description |
|---|---|---|
| graph_module | torch.fx.GraphModule | The graph module to inspect |
Returns: List[torch.fx.Node] -- Nodes that were transformed by this transformation instance.
transformed (L149-159)
Checks whether a specific node was transformed by this transformation instance.
def transformed(self, node: Node) -> bool:
    return self.signature in getattr(node, "transformations", set())
Returns: bool -- True if the node's transformations set contains this transformation's signature.
signature (L127-135)
Property that returns a hash identifying this transformation by its class and its instance attributes.
@property
def signature(self):
    attributes_to_use_for_hashing = vars(self)
    attributes_to_use_for_hashing[""] = self.__class__
    hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes_to_use_for_hashing.items())
    return hash(hash_str)
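Returns: int -- A hash value derived from this transformation's class and its instance attributes. Two instances of the same class with equal attributes share a signature; as with any hash, distinct configurations are very unlikely, but not guaranteed, to differ.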
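The hashing scheme can be tried outside Optimum with a plain Python class. The sketch below is a stand-in, not the library code: `SketchTransformation`, `OtherSketchTransformation`, and the `some_argument` attribute are hypothetical names, and only the body of `signature` mirrors the listing above. It also illustrates a side effect worth knowing: because `vars(self)` returns the instance's own `__dict__`, the first access stores the class on the instance under the empty-string key.

```python
class SketchTransformation:
    """Hypothetical stand-in that copies only the signature logic shown above."""

    def __init__(self, some_argument=0):
        self.some_argument = some_argument  # illustrative attribute, not an Optimum name

    @property
    def signature(self):
        attributes_to_use_for_hashing = vars(self)  # the live instance __dict__, not a copy
        attributes_to_use_for_hashing[""] = self.__class__
        hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes_to_use_for_hashing.items())
        return hash(hash_str)


class OtherSketchTransformation(SketchTransformation):
    """Same attributes as the parent, but a different class."""


a = SketchTransformation(some_argument=1)
b = SketchTransformation(some_argument=1)
c = SketchTransformation(some_argument=2)
d = OtherSketchTransformation(some_argument=1)

same_config_matches = a.signature == b.signature
different_argument_differs = a.signature != c.signature
different_class_differs = a.signature != d.signature
# After the first access, the class is stored on the instance under the "" key.
class_key_stored_on_instance = vars(a).get("") is SketchTransformation
```

Python salts string hashes per process unless `PYTHONHASHSEED` is fixed, so these integers are comparable within one run but should not be expected to match across runs.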
mark_as_transformed (L137-147)
Marks a node as having been transformed by this transformation.
def mark_as_transformed(self, node: Node):
    node_transformations = getattr(node, "transformations", set())
    node_transformations.add(self.signature)
    node.transformations = node_transformations
This method is called by concrete transformation implementations during their transform() method to record which nodes they have modified.
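How marking, checking, and listing fit together can be seen without torch in a small stand-in. This is a sketch under stated assumptions: `SketchTransformation` and its `tag` attribute are hypothetical, `SimpleNamespace` objects play the role of `torch.fx.Node`, a plain list replaces `graph_module.graph.nodes`, and the signature is simplified to a tuple hash.

```python
from types import SimpleNamespace


class SketchTransformation:
    """Hypothetical stand-in reproducing the tracking helpers listed above."""

    def __init__(self, tag):
        self.tag = tag  # illustrative attribute so two instances can differ

    @property
    def signature(self):
        # Simplified on purpose: the real property hashes vars(self) plus the class.
        return hash((self.__class__, self.tag))

    def mark_as_transformed(self, node):
        node_transformations = getattr(node, "transformations", set())
        node_transformations.add(self.signature)
        node.transformations = node_transformations

    def transformed(self, node):
        return self.signature in getattr(node, "transformations", set())

    def get_transformed_nodes(self, nodes):
        # The real method iterates graph_module.graph.nodes; a plain list stands in here.
        return [node for node in nodes if self.transformed(node)]


# SimpleNamespace objects accept new attributes, as the marking code requires.
nodes = [SimpleNamespace(name=n) for n in ("linear_1", "linear_2", "relu")]

first = SketchTransformation(tag="first")
second = SketchTransformation(tag="second")

first.mark_as_transformed(nodes[0])
first.mark_as_transformed(nodes[1])
second.mark_as_transformed(nodes[1])

first_names = [n.name for n in first.get_transformed_nodes(nodes)]
second_names = [n.name for n in second.get_transformed_nodes(nodes)]
print(first_names)   # ['linear_1', 'linear_2']
print(second_names)  # ['linear_2']
```

Because each node carries a set of signatures, a node touched by several transformations is reported by each of them, while a node that was never marked has no `transformations` attribute at all.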
__call__ (L108-125)
The callable interface for applying the transformation. It calls transform() and then, when lint_and_recompile is True (the default), lints the graph and recompiles the module.
def __call__(self, graph_module: GraphModule, lint_and_recompile: bool = True) -> GraphModule:
    graph_module = self.transform(graph_module)
    if lint_and_recompile:
        graph_module.graph.lint()
        graph_module.recompile()
    return graph_module
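The effect of the lint_and_recompile flag can be illustrated with recording stand-ins. Everything in this sketch is hypothetical except the body of `__call__`, which follows the listing above: `RecordingGraph` and `RecordingGraphModule` merely log the calls that a real `torch.fx.GraphModule` would execute.

```python
class RecordingGraph:
    """Hypothetical stand-in for torch.fx.Graph that logs lint() calls."""

    def __init__(self, log):
        self.log = log

    def lint(self):
        self.log.append("lint")


class RecordingGraphModule:
    """Hypothetical stand-in for torch.fx.GraphModule that logs recompile() calls."""

    def __init__(self):
        self.log = []
        self.graph = RecordingGraph(self.log)

    def recompile(self):
        self.log.append("recompile")


class SketchTransformation:
    def transform(self, graph_module):
        # A concrete transformation would rewrite graph nodes here.
        graph_module.log.append("transform")
        return graph_module

    def __call__(self, graph_module, lint_and_recompile=True):
        graph_module = self.transform(graph_module)
        if lint_and_recompile:
            graph_module.graph.lint()
            graph_module.recompile()
        return graph_module


default_log = SketchTransformation()(RecordingGraphModule()).log
skipped_log = SketchTransformation()(RecordingGraphModule(), lint_and_recompile=False).log
print(default_log)  # ['transform', 'lint', 'recompile']
print(skipped_log)  # ['transform']
```

One plausible use of `lint_and_recompile=False` is when several transformations are applied back to back and a single lint and recompile at the end is sufficient.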
Import
from optimum.fx.optimization import Transformation
Validation Pattern
The following pattern, derived from Optimum's test suite (tests/fx/optimization/test_transformations.py), demonstrates how to validate transformations:
import torch
from transformers import AutoTokenizer, BertModel
from transformers.utils.fx import symbolic_trace
from optimum.fx.optimization import MergeLinears
def as_tensors(output):
    # Iterating a dict-like ModelOutput yields its keys, so compare the values instead.
    values = output.values() if isinstance(output, dict) else output
    return [value for value in values if isinstance(value, torch.Tensor)]
# Set up original model and inputs
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("This is a test.", return_tensors="pt")
# Get original output
original_output = model(**inputs)
# Apply transformation
transformation = MergeLinears()
transformed_model = transformation(traced)
# Inspect transformed nodes
transformed_nodes = transformation.get_transformed_nodes(transformed_model)
print(f"Number of nodes modified: {len(transformed_nodes)}")
for node in transformed_nodes:
    print(f"  Node: {node.name}, Op: {node.op}, Target: {node.target}")
# Validate numerical equivalence (only if preserves_computation is True)
if transformation.preserves_computation:
    transformed_output = transformed_model(**inputs)
    for orig, trans in zip(as_tensors(original_output), as_tensors(transformed_output)):
        assert torch.allclose(orig, trans, atol=1e-5), "Output mismatch detected!"
Signature Behavior Examples
From the test suite, the signature system has the following properties:
# Same class, same attributes -> same signature
t1 = DummyTransformation()
t2 = DummyTransformation()
assert t1.signature == t2.signature
# Same class, different attributes -> different signature
t1 = DummyTransformation()
t2 = DummyTransformation(some_argument=1)
assert t1.signature != t2.signature
# Different class, same attributes -> different signature
t2 = DummyTransformation(some_argument=1)
t3 = DifferentTransformation(some_argument=1)
assert t2.signature != t3.signature
Related
- implements -> Principle:Huggingface_Optimum_Transformation_Validation