Implementation:Huggingface Optimum Get Transformed Nodes

From Leeroopedia

Overview

This page covers the transformation tracking and validation mechanisms implemented in the Transformation base class. These methods let you inspect which nodes a transformation modified and verify that the transformed graph is still correct.

Source

File: optimum/fx/optimization/transformations.py

APIs

get_transformed_nodes (L161-172)

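Returns the list of graph nodes that have been marked with this transformation's signature. Because the signature depends only on the class and the instance attributes, nodes marked by another instance of the same class with equal attributes are also returned.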

def get_transformed_nodes(self, graph_module: GraphModule) -> List[Node]:
    return [node for node in graph_module.graph.nodes if self.transformed(node)]

| Parameter | Type | Description |
|---|---|---|
| graph_module | torch.fx.GraphModule | The graph module to inspect |

Returns: List[torch.fx.Node] -- Nodes whose transformations set contains this transformation's signature.

transformed (L149-159)

Checks whether a specific node has been marked with this transformation's signature.

def transformed(self, node: Node) -> bool:
    return self.signature in getattr(node, "transformations", set())

Returns: bool -- True if the node's transformations set contains this transformation's signature.

signature (L127-135)

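Property that returns a hash identifying this transformation by its class and its instance attributes. Two instances of the same class with equal attributes share the same signature.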

@property
def signature(self):
    attributes_to_use_for_hashing = vars(self)
    attributes_to_use_for_hashing[""] = self.__class__
    hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes_to_use_for_hashing.items())
    return hash(hash_str)

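Returns: int -- A hash value derived from this transformation's class and its instance attributes. Because it relies on Python's built-in hash(), every attribute value must be hashable, and signatures should only be compared within a single interpreter process.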
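Two details of the listing above are easy to miss, and the following sketch tries to make them visible. SketchTransformation and its some_argument attribute are hypothetical stand-ins, not Optimum classes; only the body of the signature property is copied from the listing. The sketch shows that vars(self) returns the instance's live attribute dictionary, so the first access stores the class under the empty-string key, and that an unhashable attribute makes the property raise.

```python
class SketchTransformation:
    """Hypothetical stand-in; only the signature property is copied from above."""

    def __init__(self, some_argument=0):
        self.some_argument = some_argument

    @property
    def signature(self):
        attributes_to_use_for_hashing = vars(self)
        attributes_to_use_for_hashing[""] = self.__class__
        hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes_to_use_for_hashing.items())
        return hash(hash_str)


t = SketchTransformation()
keys_before = list(vars(t))  # attribute names before the first signature access
first = t.signature
keys_after = list(vars(t))   # vars() is the live __dict__, so "" is now stored on t
second = t.signature         # repeated access within one process gives the same value

# An unhashable attribute value makes hash(v) fail inside the property.
t.some_argument = [1, 2, 3]
try:
    t.signature
    raised_type_error = False
except TypeError:
    raised_type_error = True

print(keys_before, keys_after, first == second, raised_type_error)
```

In this sketch the extra key does not change later signatures, since it is rewritten with the same class on every access. Code that iterates over vars() of a transformation would still see it.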

mark_as_transformed (L137-147)

Marks a node as having been transformed by this transformation.

def mark_as_transformed(self, node: Node):
    node_transformations = getattr(node, "transformations", set())
    node_transformations.add(self.signature)
    node.transformations = node_transformations

This method is called by concrete transformation implementations during their transform() method to record which nodes they have modified.
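How mark_as_transformed, transformed, and get_transformed_nodes cooperate can be tried without PyTorch in a minimal sketch. FakeNode, SketchTransformation, and the SimpleNamespace-based graph module are hypothetical stand-ins, not Optimum or torch.fx classes. The three method bodies are copied from the listings above, and the fixed signature value is an assumption made to keep the example short.

```python
from types import SimpleNamespace


class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; it only carries a name."""

    def __init__(self, name):
        self.name = name


class SketchTransformation:
    """Hypothetical stand-in that copies the tracking methods listed above."""

    signature = 1234  # assumption: a fixed value instead of the hashed property

    def mark_as_transformed(self, node):
        node_transformations = getattr(node, "transformations", set())
        node_transformations.add(self.signature)
        node.transformations = node_transformations

    def transformed(self, node):
        return self.signature in getattr(node, "transformations", set())

    def get_transformed_nodes(self, graph_module):
        return [node for node in graph_module.graph.nodes if self.transformed(node)]


# Fake graph module exposing graph.nodes, the only attribute the methods read.
nodes = [FakeNode("query"), FakeNode("key"), FakeNode("value")]
graph_module = SimpleNamespace(graph=SimpleNamespace(nodes=nodes))

transformation = SketchTransformation()
transformation.mark_as_transformed(nodes[0])
transformation.mark_as_transformed(nodes[2])

marked_names = [node.name for node in transformation.get_transformed_nodes(graph_module)]
print(marked_names)  # ['query', 'value']
```

Unmarked nodes never receive a transformations attribute, which is why transformed reads it through getattr with an empty set as the default.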

__call__ (L108-125)

The callable interface for applying the transformation. Handles linting and recompilation after the transformation is applied.

def __call__(self, graph_module: GraphModule, lint_and_recompile: bool = True) -> GraphModule:
    graph_module = self.transform(graph_module)
    if lint_and_recompile:
        graph_module.graph.lint()
        graph_module.recompile()
    return graph_module
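The control flow of __call__ can be illustrated with a small sketch that records which steps run. RecordingGraph, RecordingGraphModule, and SketchTransformation are hypothetical stand-ins, not torch.fx or Optimum classes; only the body of __call__ is copied from the listing above, and the trivial transform is an assumption.

```python
class RecordingGraph:
    """Hypothetical stand-in for torch.fx.Graph that records lint() calls."""

    def __init__(self, calls):
        self.calls = calls

    def lint(self):
        self.calls.append("lint")


class RecordingGraphModule:
    """Hypothetical stand-in for torch.fx.GraphModule that records recompile()."""

    def __init__(self):
        self.calls = []
        self.graph = RecordingGraph(self.calls)

    def recompile(self):
        self.calls.append("recompile")


class SketchTransformation:
    def transform(self, graph_module):
        # Assumption: a no-op transform that only records that it ran.
        graph_module.calls.append("transform")
        return graph_module

    def __call__(self, graph_module, lint_and_recompile=True):
        graph_module = self.transform(graph_module)
        if lint_and_recompile:
            graph_module.graph.lint()
            graph_module.recompile()
        return graph_module


default_calls = SketchTransformation()(RecordingGraphModule()).calls
skipped_calls = SketchTransformation()(RecordingGraphModule(), lint_and_recompile=False).calls
print(default_calls)  # ['transform', 'lint', 'recompile']
print(skipped_calls)  # ['transform']
```

In torch.fx itself, Graph.lint() checks that the graph is well formed and GraphModule.recompile() regenerates the module's forward code from its graph, so a caller that passes lint_and_recompile=False needs to run both before executing the module.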

Import

from optimum.fx.optimization import Transformation

Validation Pattern

The following pattern, derived from Optimum's test suite (tests/fx/optimization/test_transformations.py), demonstrates how to validate transformations:

from transformers import BertModel, AutoTokenizer
from transformers.utils.fx import symbolic_trace
from optimum.fx.optimization import MergeLinears
import torch

# Set up original model and inputs
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("This is a test.", return_tensors="pt")

# Get original output
original_output = model(**inputs)

# Apply transformation
transformation = MergeLinears()
transformed_model = transformation(traced)

# Inspect transformed nodes
transformed_nodes = transformation.get_transformed_nodes(transformed_model)
print(f"Number of nodes modified: {len(transformed_nodes)}")
for node in transformed_nodes:
    print(f"  Node: {node.name}, Op: {node.op}, Target: {node.target}")

# Validate numerical equivalence (only if preserves_computation is True)
if transformation.preserves_computation:
    transformed_output = transformed_model(**inputs)
    # Iterating a ModelOutput or dict yields its keys, not tensors, so compare by key.
    # Assumes the traced model returns a dict-like output with the same keys.
    for key, orig in original_output.items():
        if isinstance(orig, torch.Tensor):
            assert torch.allclose(orig, transformed_output[key], atol=1e-5), f"Output mismatch in {key!r}!"

Signature Behavior Examples

From the test suite, the signature system has the following properties:

# Same class, same attributes -> same signature
t1 = DummyTransformation()
t2 = DummyTransformation()
assert t1.signature == t2.signature

# Same class, different attributes -> different signature
t1 = DummyTransformation()
t2 = DummyTransformation(some_argument=1)
assert t1.signature != t2.signature

# Different class, same attributes -> different signature
t2 = DummyTransformation(some_argument=1)
t3 = DifferentTransformation(some_argument=1)
assert t2.signature != t3.signature
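The snippets above do not define DummyTransformation or DifferentTransformation. To try the three properties without installing Optimum, one possible setup is sketched below. SignatureMixin and both classes are hypothetical stand-ins written for this illustration, not the definitions used in the test suite; only the signature property body is copied from the API section.

```python
class SignatureMixin:
    """Hypothetical helper carrying a copy of the signature property shown earlier."""

    @property
    def signature(self):
        attributes_to_use_for_hashing = vars(self)
        attributes_to_use_for_hashing[""] = self.__class__
        hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes_to_use_for_hashing.items())
        return hash(hash_str)


class DummyTransformation(SignatureMixin):
    def __init__(self, some_argument=0):
        self.some_argument = some_argument


class DifferentTransformation(SignatureMixin):
    def __init__(self, some_argument=0):
        self.some_argument = some_argument


# Same class, same attributes -> same signature
same_class_same_attrs = DummyTransformation().signature == DummyTransformation().signature

# Same class, different attributes -> different signature
same_class_diff_attrs = DummyTransformation().signature != DummyTransformation(some_argument=1).signature

# Different class, same attributes -> different signature
diff_class_same_attrs = (
    DummyTransformation(some_argument=1).signature
    != DifferentTransformation(some_argument=1).signature
)

print(same_class_same_attrs, same_class_diff_attrs, diff_class_same_attrs)
```

The two inequality checks rely on hash values not colliding, which is overwhelmingly likely but not guaranteed by Python.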
