Implementation:Huggingface Optimum ReversibleTransformation Reverse
Overview
Implementation of the ReversibleTransformation base class and its concrete reverse() implementations that undo graph transformations applied to PyTorch FX graph modules.
Source
File: optimum/fx/optimization/transformations.py
ReversibleTransformation Base Class (L175-234)
class ReversibleTransformation(Transformation, ABC):
    @abstractmethod
    def reverse(self, graph_module: GraphModule) -> GraphModule:
        raise NotImplementedError("The reverse transform method needs to be implemented.")

    def __call__(
        self, graph_module: GraphModule, lint_and_recompile: bool = True, reverse: bool = False
    ) -> GraphModule:
        func = self.transform if not reverse else self.reverse
        graph_module = func(graph_module)
        if lint_and_recompile:
            graph_module.graph.lint()
            graph_module.recompile()
        return graph_module

    def mark_as_restored(self, node: Node):
        node_transformations = getattr(node, "transformations", set())
        if self.signature not in node_transformations:
            raise ValueError("The node was not transformed by this transformation.")
        node_transformations.remove(self.signature)
| Method | Line Range | Description |
|---|---|---|
| reverse | L184-195 | Abstract method; must be implemented by subclasses to undo their transformation |
| __call__ | L197-220 | Dispatch: calls self.transform or self.reverse based on the reverse parameter |
| mark_as_restored | L222-234 | Removes the transformation's signature from a node; raises ValueError if the node was not previously transformed |
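The signature bookkeeping behind mark_as_restored can be illustrated without PyTorch. The following is a simplified sketch, not the library code: FakeNode, mark_as_transformed, and the string signature are hypothetical names introduced for illustration. Only the transformations attribute and the ValueError message follow the class excerpt above.

```python
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; it only needs to accept attributes."""


def mark_as_transformed(node, signature):
    # Assumed counterpart of mark_as_restored: record the signature on the node.
    transformations = getattr(node, "transformations", set())
    transformations.add(signature)
    node.transformations = transformations


def mark_as_restored(node, signature):
    # Same checks as the excerpt: refuse to restore a node that was never marked.
    transformations = getattr(node, "transformations", set())
    if signature not in transformations:
        raise ValueError("The node was not transformed by this transformation.")
    transformations.remove(signature)


node = FakeNode()
mark_as_transformed(node, "demo-signature")
mark_as_restored(node, "demo-signature")
remaining = node.transformations  # the set is empty again after restoration

# Restoring a second time fails because the signature is no longer present.
try:
    mark_as_restored(node, "demo-signature")
    second_restore_failed = False
except ValueError:
    second_restore_failed = True
```

Storing signatures in a per-node set is what lets a transformation later find only the nodes it changed itself.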
Dispatch Mechanism (L215)
The key line that determines forward vs. reverse execution:
func = self.transform if not reverse else self.reverse
This enables a unified calling interface where the same transformation object handles both application and reversal.
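A PyTorch-free sketch of the same dispatch pattern may make the calling convention easier to see. ToyReversible and AddOne are hypothetical classes that operate on plain integers instead of graph modules; only the selection expression is taken from the source.

```python
from abc import ABC, abstractmethod


class ToyReversible(ABC):
    """Hypothetical, simplified analogue of the dispatch in __call__."""

    @abstractmethod
    def transform(self, value):
        ...

    @abstractmethod
    def reverse(self, value):
        ...

    def __call__(self, value, reverse=False):
        # Same selection expression as in the source excerpt.
        func = self.transform if not reverse else self.reverse
        return func(value)


class AddOne(ToyReversible):
    def transform(self, value):
        return value + 1

    def reverse(self, value):
        return value - 1


t = AddOne()
forward = t(41)                      # runs transform
restored = t(forward, reverse=True)  # runs reverse on the same instance
```

Because the direction is a keyword argument, callers do not need a separate "undo" object to reverse a transformation.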
Concrete Reverse Implementations
MergeLinears.reverse (L386-389)
Splits merged linear layers back into their original individual linear layers.
def reverse(self, graph_module: GraphModule) -> GraphModule:
    for node in self.get_transformed_nodes(graph_module):
        self._unmerge_linears(graph_module, node, graph_module.get_submodule(node.target))
    return graph_module
The _unmerge_linears static method (L322-365) performs the actual restoration:
- Reads the original linear targets from merged_linear_node.linear_node_targets.
- Sorts output nodes by their slice start index to match the original concatenation order.
- Creates new individual nn.Linear modules with weights extracted from the merged linear.
- Restores each output node's op to "call_module" and its target to the original module path.
- Deletes the merged linear module and erases its node from the graph.
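The weight extraction step can be sketched with plain Python lists, assuming that merging stacks each layer's weight rows and bias entries along the output dimension. The helpers merge and unmerge and the layer names query and key are hypothetical. The real method works on nn.Linear modules and FX nodes rather than nested lists.

```python
def linear(weight, bias, x):
    """Compute weight @ x + bias for a list-of-rows weight matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def merge(layers):
    """Stack the (weight, bias) pairs of several layers along the output dimension."""
    weight = [row for w, _ in layers for row in w]
    bias = [b for _, bs in layers for b in bs]
    sizes = [len(w) for w, _ in layers]  # output size of each original layer
    return weight, bias, sizes


def unmerge(weight, bias, sizes):
    """Slice the stacked rows back into one (weight, bias) pair per original layer."""
    layers, start = [], 0
    for size in sizes:
        layers.append((weight[start:start + size], bias[start:start + size]))
        start += size
    return layers


# Two layers that share the same 2-dimensional input.
query = ([[1.0, 2.0], [3.0, 4.0]], [0.5, -0.5])
key = ([[5.0, 6.0]], [1.0])

merged_w, merged_b, sizes = merge([query, key])
restored = unmerge(merged_w, merged_b, sizes)

x = [1.0, 1.0]
merged_out = linear(merged_w, merged_b, x)  # outputs of both layers, concatenated
```

Slicing in the original concatenation order is what makes the restored layers line up with the output nodes, which is why the real method sorts those nodes by slice start index first.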
FuseBiasInLinear.reverse (L429-443)
Restores the original bias and input connections for linear layers that had their bias fused into the weight.
def reverse(self, graph_module: GraphModule) -> GraphModule:
    for node in self.get_transformed_nodes(graph_module):
        node.args = (node.start_node,)
        n = node.end_node
        while n is not node.start_node:
            if n not in node.nodes_to_ignore:
                graph_module.graph.erase_node(n)
            n = n.prev
        self.mark_as_restored(node)
        module = graph_module.get_submodule(node.target)
        new_weight = torch.nn.Parameter(module.weight[:, :-1])
        new_bias = torch.nn.Parameter(module.weight[:, -1].squeeze())
        module.weight = new_weight
        module.bias = new_bias
    return graph_module
How it works:
- Restores the original input connection by setting node.args back to (node.start_node,).
- Erases the inserted concatenation nodes (walking backward from end_node to start_node, skipping nodes that existed before the transformation).
- Removes the transformation's signature via mark_as_restored.
- Extracts the original weight (all columns except the last) and bias (the last column) from the extended weight matrix.
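The last step relies on the identity W x + b = [W | b] [x; 1]. Below is a small numeric sketch in plain Python, assuming the forward transformation appends the bias as the last weight column and concatenates a constant 1 onto the input. The values and the matvec helper are illustrative only.

```python
def matvec(weight, x):
    """Multiply a list-of-rows matrix by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


weight = [[1.0, 2.0], [3.0, 4.0]]
bias = [0.5, -1.0]
x = [2.0, 1.0]

# Fused form (assumed layout): bias stored as an extra last column,
# input extended with a trailing 1 so the column acts as the bias.
fused_weight = [row + [b] for row, b in zip(weight, bias)]
fused_out = matvec(fused_weight, x + [1.0])

# Reverse: all columns except the last are the weight, the last column is the bias.
restored_weight = [row[:-1] for row in fused_weight]
restored_bias = [row[-1] for row in fused_weight]
plain_out = [y + b for y, b in zip(matvec(restored_weight, x), restored_bias)]
```

The two column slices here correspond to module.weight[:, :-1] and module.weight[:, -1] in the source.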
ChangeTrueDivToMulByInverse.reverse (L467-474)
Restores division operations that were converted to multiplication by inverse.
def reverse(self, graph_module: GraphModule) -> GraphModule:
    for node in self.get_transformed_nodes(graph_module):
        node.target = operator.truediv
        x, y = node.args
        node.args = (x, 1 / y)
        self.mark_as_restored(node)
    return graph_module
How it works:
- Changes node.target from operator.mul back to operator.truediv.
- Inverts the constant argument back (from 1/c to c).
- Marks the node as restored.
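The round trip can be tried on a toy node. ToyNode, to_mul, and to_truediv are hypothetical stand-ins for an FX call_function node and the two directions of the transformation. The constant 8.0 is chosen because its reciprocal is exactly representable in binary floating point.

```python
import operator
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass
class ToyNode:
    """Hypothetical stand-in for an FX node with a constant right operand."""
    target: Callable
    args: Tuple[float, float]

    def run(self):
        return self.target(*self.args)


def to_mul(node):
    # Forward direction (assumed): x / c becomes x * (1 / c).
    x, y = node.args
    node.target = operator.mul
    node.args = (x, 1 / y)


def to_truediv(node):
    # Reverse direction, mirroring the source: x * (1 / c) becomes x / c.
    x, y = node.args
    node.target = operator.truediv
    node.args = (x, 1 / y)


node = ToyNode(operator.truediv, (10.0, 8.0))
before = node.run()
to_mul(node)
during = node.run()
to_truediv(node)
after = node.run()
```

For constants whose reciprocal is not exactly representable, inverting twice may differ from the original value in the last bits, so a restored graph is best compared with a tolerance rather than exact equality.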
Import
from optimum.fx.optimization import ReversibleTransformation
Usage Example
import torch
from transformers import BertModel
from transformers.utils.fx import symbolic_trace
from optimum.fx.optimization import MergeLinears

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
)

# Apply transformation
transformation = MergeLinears()
transformed_model = transformation(traced)

# Reverse transformation using the same instance
restored_model = transformation(transformed_model, reverse=True)

# Verify: original parameters should match restored parameters
orig_params = dict(model.named_parameters())
restored_params = dict(restored_model.named_parameters())
assert set(orig_params.keys()) == set(restored_params.keys())
for name in orig_params:
    assert torch.allclose(orig_params[name], restored_params[name])
Related
- implements -> Principle:Huggingface_Optimum_Reversible_Transformation