

Implementation:Huggingface Optimum ReversibleTransformation Reverse

From Leeroopedia

Overview

Implementation of the ReversibleTransformation base class and its concrete reverse() implementations that undo graph transformations applied to PyTorch FX graph modules.

Source

File: optimum/fx/optimization/transformations.py

ReversibleTransformation Base Class (L175-234)

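from abc import ABC, abstractmethod

from torch.fx import GraphModule, Node

# Transformation is the non-reversible base class defined earlier in the same file.


class ReversibleTransformation(Transformation, ABC):
    @abstractmethod
    def reverse(self, graph_module: GraphModule) -> GraphModule:
        # Subclasses must undo exactly what their transform() did.
        raise NotImplementedError("The reverse transform method needs to be implemented.")

    def __call__(
        self, graph_module: GraphModule, lint_and_recompile: bool = True, reverse: bool = False
    ) -> GraphModule:
        # Pick the direction: apply the transformation or undo it.
        func = self.transform if not reverse else self.reverse
        graph_module = func(graph_module)
        if lint_and_recompile:
            # Check that the graph is well-formed, then regenerate the module's forward() code.
            graph_module.graph.lint()
            graph_module.recompile()
        return graph_module

    def mark_as_restored(self, node: Node):
        # Each node carries the set of signatures of the transformations applied to it.
        node_transformations = getattr(node, "transformations", set())
        if self.signature not in node_transformations:
            raise ValueError("The node was not transformed by this transformation.")
        node_transformations.remove(self.signature)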
Method            Line range  Description
reverse           L184-195    Abstract method; subclasses implement it to undo their transformation.
__call__          L197-220    Dispatches to self.transform or self.reverse depending on the reverse flag, then lints and recompiles the graph unless lint_and_recompile=False.
mark_as_restored  L222-234    Removes the transformation's signature from a node; raises ValueError if the node was not previously transformed.
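The signature bookkeeping that mark_as_restored relies on can be modeled without PyTorch. The sketch below is a hypothetical simplification, not the library's code: ToyNode stands in for torch.fx.Node, the signature is a caller-supplied string rather than a value derived from the transformation object, and mark_as_transformed / get_transformed_nodes are toy versions of the base-class helpers that record and look up those signatures.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node: a name plus a signature set."""
    name: str
    transformations: set = field(default_factory=set)


class ToyReversibleTransformation:
    """Simplified model of the per-node signature bookkeeping."""

    def __init__(self, signature: str):
        # Assumption: a plain string instead of the library's computed signature.
        self.signature = signature

    def mark_as_transformed(self, node: ToyNode) -> None:
        node.transformations.add(self.signature)

    def mark_as_restored(self, node: ToyNode) -> None:
        # Same guard as the library method: refuse to restore an untouched node.
        if self.signature not in node.transformations:
            raise ValueError("The node was not transformed by this transformation.")
        node.transformations.remove(self.signature)

    def get_transformed_nodes(self, nodes):
        return [n for n in nodes if self.signature in n.transformations]


nodes = [ToyNode("linear_a"), ToyNode("linear_b")]
t = ToyReversibleTransformation("merge_linears")
t.mark_as_transformed(nodes[0])
print([n.name for n in t.get_transformed_nodes(nodes)])  # ['linear_a']
t.mark_as_restored(nodes[0])
print(t.get_transformed_nodes(nodes))  # []
try:
    t.mark_as_restored(nodes[1])  # never transformed, so this raises
except ValueError as err:
    print(err)
```

Because reverse() implementations iterate over get_transformed_nodes, only nodes carrying this transformation's signature are ever touched, and the ValueError guards against restoring a node twice.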

Dispatch Mechanism (L215)

The key line that determines forward vs. reverse execution:

func = self.transform if not reverse else self.reverse

This enables a unified calling interface where the same transformation object handles both application and reversal.
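The same dispatch pattern can be sketched on a toy "graph". In this hypothetical example the graph is a list of op names, post_process is a placeholder for graph.lint() plus recompile(), and DivToMul is an invented transformation; only the one-line func selection is taken from the library.

```python
from abc import ABC, abstractmethod


class ToyReversible(ABC):
    """Mirrors the shape of ReversibleTransformation.__call__ on a list of ops."""

    def __init__(self):
        self.post_process_calls = 0  # counts how often the lint/recompile stand-in ran

    @abstractmethod
    def transform(self, ops: list) -> list: ...

    @abstractmethod
    def reverse(self, ops: list) -> list: ...

    def post_process(self, ops: list) -> None:
        # Placeholder for graph.lint() + recompile(): a basic sanity check.
        assert all(isinstance(op, str) for op in ops)
        self.post_process_calls += 1

    def __call__(self, ops: list, lint_and_recompile: bool = True, reverse: bool = False) -> list:
        # Same selection as the library: one object, two directions.
        func = self.transform if not reverse else self.reverse
        ops = func(ops)
        if lint_and_recompile:
            self.post_process(ops)
        return ops


class DivToMul(ToyReversible):
    """Rewrites 'div' to 'mul_inv' and back, editing the list in place."""

    def transform(self, ops):
        for i, op in enumerate(ops):
            if op == "div":
                ops[i] = "mul_inv"
        return ops

    def reverse(self, ops):
        for i, op in enumerate(ops):
            if op == "mul_inv":
                ops[i] = "div"
        return ops


t = DivToMul()
graph = ["add", "div", "relu"]
print(t(graph))                # ['add', 'mul_inv', 'relu']
print(t(graph, reverse=True))  # ['add', 'div', 'relu']
```

Unlike the library, this toy reverse() does not record which ops it rewrote, so a "mul_inv" that existed before transform() would also be turned into "div". The per-node signatures checked by mark_as_restored exist to prevent exactly that.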

Concrete Reverse Implementations

MergeLinears.reverse (L386-389)

Splits merged linear layers back into their original individual linear layers.

def reverse(self, graph_module: GraphModule) -> GraphModule:
    for node in self.get_transformed_nodes(graph_module):
        self._unmerge_linears(graph_module, node, graph_module.get_submodule(node.target))
    return graph_module

The _unmerge_linears static method (L322-365) performs the actual restoration:

  1. Reads the original linear targets from merged_linear_node.linear_node_targets.
  2. Sorts output nodes by their slice start index to match the original concatenation order.
  3. Creates new individual nn.Linear modules with weights extracted from the merged linear.
  4. Restores each output node's op to "call_module" and its target to the original module path.
  5. Deletes the merged linear module and erases its node from the graph.
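The arithmetic behind merging and unmerging can be illustrated with plain Python lists. This is a hypothetical simplification, not the library's code: each weight is a list of rows (one row per output feature), merging stacks the rows and records one slice per original layer, and unmerging visits the slices in ascending start order (as in step 2) and cuts the stacked matrix back apart (as in step 3). Biases are omitted.

```python
def merge_linears(weights: dict):
    """Stack row-major weight matrices that share an input; return the merged
    matrix and each original layer's row slice (hypothetical helper)."""
    merged, slices, start = [], {}, 0
    for name, w in weights.items():
        merged.extend(w)
        slices[name] = slice(start, start + len(w))
        start += len(w)
    return merged, slices


def unmerge_linears(merged, slices: dict):
    """Undo merge_linears: sort slices by start index, then cut the merged
    matrix back into per-layer matrices."""
    restored = {}
    for name, s in sorted(slices.items(), key=lambda item: item[1].start):
        restored[name] = merged[s]
    return restored


def matvec(w, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


weights = {
    "query": [[1.0, 0.0], [0.0, 1.0]],
    "key": [[2.0, 1.0]],
    "value": [[0.5, 0.5], [1.0, -1.0], [3.0, 0.0]],
}
merged, slices = merge_linears(weights)
x = [1.0, 2.0]

# One merged matmul followed by slicing gives the same outputs as the separate layers.
y = matvec(merged, x)
for name, s in slices.items():
    assert y[s] == matvec(weights[name], x)

restored = unmerge_linears(merged, slices)
print(restored == weights)  # True
```

Sorting by slice start is what keeps the result correct even when the slices arrive in a different order than the one used for concatenation.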

FuseBiasInLinear.reverse (L429-443)

Restores the original bias and input connections for linear layers that had their bias fused into the weight.

def reverse(self, graph_module: GraphModule) -> GraphModule:
    for node in self.get_transformed_nodes(graph_module):
        node.args = (node.start_node,)
        n = node.end_node
        while n is not node.start_node:
            if n not in node.nodes_to_ignore:
                graph_module.graph.erase_node(n)
            n = n.prev
        self.mark_as_restored(node)
        module = graph_module.get_submodule(node.target)
        new_weight = torch.nn.Parameter(module.weight[:, :-1])
        new_bias = torch.nn.Parameter(module.weight[:, -1].squeeze())
        module.weight = new_weight
        module.bias = new_bias
    return graph_module

How it works:

  1. Restores the original input connection by setting node.args back to (node.start_node,).
  2. Erases the nodes the transformation inserted to append a column of ones to the input, walking backward from end_node to start_node and skipping the nodes in node.nodes_to_ignore (those that existed before the transformation).
  3. Removes the transformation's signature via mark_as_restored.
  4. Extracts the original weight (all columns except the last) and bias (the last column) from the extended weight matrix.
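The identity that makes this reversal possible is that appending the bias b as an extra column of W, and a trailing 1 to the input x, gives [W | b] · [x; 1] = Wx + b. The pure-Python sketch below checks that identity and the reverse slicing; it uses lists instead of tensors, and fuse_bias, unfuse_bias, and linear are hypothetical helper names.

```python
def fuse_bias(weight, bias):
    """Append the bias as an extra last column of the weight (row-major lists)."""
    return [row + [b] for row, b in zip(weight, bias)]


def unfuse_bias(fused):
    """Mirror of the reverse step: weight = all columns but the last, bias = last column."""
    weight = [row[:-1] for row in fused]
    bias = [row[-1] for row in fused]
    return weight, bias


def linear(weight, x, bias=None):
    out = [sum(w * xi for w, xi in zip(row, x)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


W = [[1.0, 2.0], [3.0, 4.0]]
b = [0.5, -1.0]
x = [2.0, 1.0]

fused = fuse_bias(W, b)
# The transformed graph feeds the linear an input with a trailing 1 appended.
x_ext = x + [1.0]
print(linear(fused, x_ext) == linear(W, x, b))  # True
print(unfuse_bias(fused) == (W, b))  # True
```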

ChangeTrueDivToMulByInverse.reverse (L467-474)

Restores true-division operations that were rewritten as multiplication by the constant divisor's inverse.

def reverse(self, graph_module: GraphModule) -> GraphModule:
    for node in self.get_transformed_nodes(graph_module):
        node.target = operator.truediv
        x, y = node.args
        node.args = (x, 1 / y)
        self.mark_as_restored(node)
    return graph_module

How it works:

  1. Changes node.target from operator.mul back to operator.truediv.
  2. Inverts the constant argument back (from 1/c to c).
  3. Marks the node as restored.
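A toy round trip of this rewrite, using the real operator.truediv and operator.mul on a hypothetical node object with target and args. Assumptions in this sketch: ToyCallNode and the SIG string are invented, the forward rewrite only fires when the divisor is a plain number (since 1 / y must be computable ahead of time), and reverse() only touches nodes carrying the signature.

```python
import operator
from dataclasses import dataclass, field


@dataclass
class ToyCallNode:
    """Hypothetical stand-in for a call_function node computing target(*args)."""
    target: object
    args: tuple
    transformations: set = field(default_factory=set)

    def run(self):
        return self.target(*self.args)


SIG = "truediv_to_mul"  # hypothetical signature string


def transform(nodes):
    for node in nodes:
        x, y = node.args
        # Only constant divisors can be pre-inverted.
        if node.target is operator.truediv and isinstance(y, (int, float)):
            node.target = operator.mul
            node.args = (x, 1 / y)
            node.transformations.add(SIG)
    return nodes


def reverse(nodes):
    for node in nodes:
        if SIG in node.transformations:  # only nodes this rewrite touched
            node.target = operator.truediv
            x, y = node.args
            node.args = (x, 1 / y)  # 1 / (1 / c) -> c
            node.transformations.remove(SIG)
    return nodes


nodes = [ToyCallNode(operator.truediv, (10.0, 4.0)), ToyCallNode(operator.mul, (3.0, 0.5))]
before = [n.run() for n in nodes]
transform(nodes)
print(nodes[0].target is operator.mul, nodes[0].args)  # True (10.0, 0.25)
reverse(nodes)
print(nodes[0].target is operator.truediv, nodes[0].args)  # True (10.0, 4.0)
print(nodes[1].target is operator.mul)  # True: the pre-existing mul was left alone
```

Restoring computes 1 / (1 / c), which in floating point need not reproduce c bit-for-bit in general; the divisor 4.0 is used here because its reciprocal is exact in binary.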

Import

from optimum.fx.optimization import ReversibleTransformation

Usage Example

import torch
from transformers import BertModel
from transformers.utils.fx import symbolic_trace
from optimum.fx.optimization import MergeLinears

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
)

# Apply the transformation (this edits `traced` in place and returns it)
transformation = MergeLinears()
transformed_model = transformation(traced)

# Reverse the transformation using the same instance
restored_model = transformation(transformed_model, reverse=True)

# Verify: the restored parameters should match the original model's parameters
orig_params = dict(model.named_parameters())
restored_params = dict(restored_model.named_parameters())
assert set(orig_params.keys()) == set(restored_params.keys())
for name in orig_params:
    assert torch.allclose(orig_params[name], restored_params[name])
