Principle:Huggingface Optimum Reversible Transformation
Overview
Extension of graph transformations that supports undoing applied changes to restore the original model graph.
Description
Some graph transformations can be reversed, restoring the graph to its pre-transformation state. The ReversibleTransformation class extends Transformation with a reverse() method that undoes the changes applied by transform().
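The class relationship can be sketched in plain Python. This is a minimal, hypothetical model, not Optimum's code. The graph is a list of (op, constant, tagged) tuples instead of a torch.fx GraphModule, and SwapDivToMul is an illustrative toy loosely modeled on ChangeTrueDivToMulByInverse. The tag on each rewritten node lets reverse() tell the nodes it changed apart from nodes that were already multiplications.

```python
from abc import ABC, abstractmethod


class Transformation(ABC):
    """One-way graph rewrite (hypothetical stand-in)."""

    @abstractmethod
    def transform(self, graph):
        """Return the rewritten graph."""


class ReversibleTransformation(Transformation):
    """A Transformation that can also undo its own changes."""

    @abstractmethod
    def reverse(self, graph):
        """Return the graph as it was before transform() ran."""


class SwapDivToMul(ReversibleTransformation):
    # Toy graph: a list of (op, constant, tagged) tuples. Hypothetical, not Optimum's format.
    def transform(self, graph):
        # Rewrite x / c as x * (1 / c) and tag the node so reverse() can find it later.
        return [("mul", 1 / c, True) if op == "div" else (op, c, tag) for op, c, tag in graph]

    def reverse(self, graph):
        # Touch only tagged nodes; a "mul" that was already in the graph stays a "mul".
        return [("div", 1 / c, False) if tag else (op, c, tag) for op, c, tag in graph]
```

For example, with graph = [("div", 4.0, False), ("mul", 2.0, False)], applying transform() and then reverse() returns the original list, and the pre-existing "mul" node is never touched.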
Why Reversibility Matters
Reversible transformations are useful for several scenarios:
- Experimentation -- Try an optimization, benchmark it, then undo if not beneficial. This avoids the need to re-trace the model from scratch.
- Debugging -- Isolate which transformation causes issues by selectively applying and reverting individual transformations.
- Export -- Apply optimizations for inference, then reverse for continued training. Some optimizations (like merging Q/K/V projections) are beneficial for inference but undesirable during training.
- A/B testing -- Compare model behavior with and without specific optimizations using the same traced graph.
Reversible vs. Irreversible Transformations
| Transformation | Reversible | Reason |
|---|---|---|
| MergeLinears | Yes | Stores original split sizes; can reconstruct individual linear layers from the merged one |
| FuseBiasInLinear | Yes | Stores references to inserted nodes; can extract the bias back from the extended weight matrix |
| ChangeTrueDivToMulByInverse | Yes | Swaps operator.truediv for operator.mul (and back) and replaces the constant divisor with its reciprocal |
| FuseBatchNorm2dInConv2d | No | Destructively combines BN parameters into Conv2d weights; BN running statistics are lost |
| FuseBatchNorm1dInLinear | No | Destructively combines BN parameters into Linear weights; BN running statistics are lost |
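The arithmetic behind FuseBiasInLinear's reversibility fits in a few lines. The sketch below assumes plain Python lists instead of torch tensors and uses hypothetical helper names (matvec, fuse_bias, unfuse_bias). The bias becomes an extra last column of the weight, the input is extended with a constant 1, and reversal slices that column back off.

```python
def matvec(w, x):
    # y = W x for a weight stored as a list of rows.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def fuse_bias(weight, bias):
    # Append each output's bias as an extra last column of the weight matrix.
    return [row + [b] for row, b in zip(weight, bias)]


def unfuse_bias(fused):
    # Reverse: split the last column back out as the bias vector.
    return [row[:-1] for row in fused], [row[-1] for row in fused]


weight = [[1.0, 2.0], [3.0, 4.0]]
bias = [0.5, -1.0]
x = [10.0, 20.0]

fused = fuse_bias(weight, bias)
y_original = [a + b for a, b in zip(matvec(weight, x), bias)]
# The fused layer sees the input extended with a constant 1.
y_fused = matvec(fused, x + [1.0])
assert y_original == y_fused
# Nothing was lost, so the original weight and bias come back exactly.
assert unfuse_bias(fused) == (weight, bias)
```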
Dispatch Mechanism
The ReversibleTransformation.__call__ method accepts a reverse parameter. When reverse=True, it dispatches to the reverse() method instead of transform(). This provides a clean, unified calling interface:
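```python
from optimum.fx.optimization import ChangeTrueDivToMulByInverse

# `model` is assumed to be a torch.fx GraphModule, e.g. from transformers.utils.fx.symbolic_trace
transformation = ChangeTrueDivToMulByInverse()
# Forward transformation
transformed = transformation(model)
# Reverse transformation (same callable, different flag)
restored = transformation(transformed, reverse=True)
```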
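A stripped-down model of the dispatch might look like the following. It assumes only the reverse flag matters (any other parameters of the real method are left out), and AddNodeSuffix with its list-of-names graph is a hypothetical toy.

```python
from abc import ABC, abstractmethod


class ReversibleTransformation(ABC):
    # Hypothetical stand-in: only the __call__ dispatch logic is modeled.
    @abstractmethod
    def transform(self, graph):
        """Apply the rewrite."""

    @abstractmethod
    def reverse(self, graph):
        """Undo the rewrite."""

    def __call__(self, graph, reverse=False):
        # One entry point: the flag selects which method runs.
        func = self.reverse if reverse else self.transform
        return func(graph)


class AddNodeSuffix(ReversibleTransformation):
    # Toy "graph": a list of node names. transform() renames, reverse() strips the suffix.
    def transform(self, graph):
        return [name + "_opt" for name in graph]

    def reverse(self, graph):
        return [name[: -len("_opt")] if name.endswith("_opt") else name for name in graph]


t = AddNodeSuffix()
transformed = t(["linear", "relu"])
restored = t(transformed, reverse=True)
```

Keeping the choice in one flag means callers such as a composition helper can forward reverse=True without knowing which concrete transformation they hold.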
Usage
Use a ReversibleTransformation whenever an optimization may need to be undone later; calling the same instance with reverse=True restores the graph without re-tracing the original model.
Integration with Composition
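When all transformations passed to compose() are reversible, the composed result is also reversible. Reversing it applies the individual reverse steps in the opposite order: compose(t1, t2, t3) runs t1, t2, then t3 forward, and undoes t3, t2, then t1 on reversal, correctly unwinding the chain of modifications.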
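The ordering rule can be illustrated with a hypothetical compose() that returns a plain function rather than a transformation object, plus a Recorder transformation that only logs its calls. This is a sketch of the ordering, not Optimum's implementation.

```python
class Recorder:
    # Hypothetical reversible transformation that logs each call into a shared list.
    def __init__(self, name, log):
        self.name, self.log = name, log

    def __call__(self, graph, reverse=False):
        self.log.append(f"{self.name}.{'reverse' if reverse else 'transform'}")
        return graph


def compose(*transformations):
    # Forward runs in argument order; reverse runs the same steps backwards with reverse=True.
    def composed(graph, reverse=False):
        order = reversed(transformations) if reverse else transformations
        for t in order:
            graph = t(graph, reverse=reverse)
        return graph

    return composed


log = []
pipeline = compose(Recorder("merge", log), Recorder("fuse_bias", log), Recorder("div_to_mul", log))
pipeline("graph")
pipeline("graph", reverse=True)
print(log)
# ['merge.transform', 'fuse_bias.transform', 'div_to_mul.transform',
#  'div_to_mul.reverse', 'fuse_bias.reverse', 'merge.reverse']
```

Undoing the last change first matters because later transformations may have rewritten nodes that earlier ones inserted.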
Theoretical Basis
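Bijective transformations on computation graphs. A transformation T is reversible if a single inverse T^(-1) exists such that, for every graph G that T can be applied to and its image G' = T(G), T^(-1)(G') = G. In other words, T must be a bijection between its input graphs and the graphs it produces.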
This requires the transformation to be information-preserving: any information that is needed to reconstruct the original graph must be stored during the forward transformation.
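Information preservation can be checked mechanically: apply T, then T^(-1), and compare the result with the original. In this hypothetical sketch the graph is a list of (op, constant) tuples, and fold_adds merges consecutive additions, somewhat as MergeLinears merges layers. With keep_parts=True it stores the merged constants and the round trip succeeds; with keep_parts=False the constants are discarded and unfold_adds has nothing to rebuild them from.

```python
import copy


def check_round_trip(transform, reverse, graph):
    # True if reverse(transform(g)) reproduces g exactly. Works on deep copies
    # so an in-place transformation cannot mutate the caller's graph.
    original = copy.deepcopy(graph)
    restored = reverse(transform(copy.deepcopy(graph)))
    return restored == original


def fold_adds(graph, keep_parts):
    # Merge consecutive ("add", c) nodes into one node carrying their sum.
    # With keep_parts=True the original constants are stored on the merged node.
    out = []
    for op, c in graph:
        if op == "add" and out and out[-1][0] == "add":
            _, total, parts = out[-1]
            out[-1] = ("add", total + c, parts + (c,))
        else:
            out.append((op, c, (c,)))
    if not keep_parts:
        out = [(op, c, None) for op, c, _ in out]
    return out


def unfold_adds(graph):
    # Rebuild the original nodes from stored parts; without them, keep the folded node.
    out = []
    for op, c, parts in graph:
        if parts is None:
            out.append((op, c))
        else:
            out.extend((op, p) for p in parts)
    return out


g = [("mul", 2), ("add", 1), ("add", 2)]
print(check_round_trip(lambda x: fold_adds(x, keep_parts=True), unfold_adds, g))   # True
print(check_round_trip(lambda x: fold_adds(x, keep_parts=False), unfold_adds, g))  # False
```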
| Transformation | Information Preserved | How |
|---|---|---|
| MergeLinears | Original linear targets and split sizes | Stored as linear_node_targets attribute on the merged node; sizes derived from slice operations |
| FuseBiasInLinear | Original input connections and inserted nodes | Stored as start_node, end_node, and nodes_to_ignore attributes on the node |
| ChangeTrueDivToMulByInverse | Original operation and constant | Trivially recoverable by inverting the constant and swapping the operator |
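MergeLinears' bookkeeping can be sketched with bias-free layers stored as lists of rows that all read the same input. Per the table above, Optimum keeps linear_node_targets and derives sizes from slice operations; this hypothetical sketch simplifies that by returning split_sizes directly, and the helper names are illustrative.

```python
def matvec(w, x):
    # y = W x for a weight stored as a list of rows.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def merge_linears(weights):
    # Stack the rows of several layers and remember how many rows each contributed.
    merged = [row for w in weights for row in w]
    split_sizes = [len(w) for w in weights]
    return merged, split_sizes


def split_linears(merged, split_sizes):
    # Reverse: slice the merged rows back into one weight matrix per original layer.
    weights, start = [], 0
    for size in split_sizes:
        weights.append(merged[start:start + size])
        start += size
    return weights


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[2.0, 2.0]]
v = [[3.0, 1.0], [1.0, 3.0], [0.0, 0.0]]
x = [1.0, 2.0]

merged, sizes = merge_linears([q, k, v])
# One matmul on the merged weight yields all three outputs, concatenated.
assert matvec(merged, x) == matvec(q, x) + matvec(k, x) + matvec(v, x)
# The stored split sizes are enough to undo the merge exactly.
assert split_linears(merged, sizes) == [q, k, v]
```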
Batch norm fusion is irreversible because it destructively combines batch normalization parameters (running mean, running variance, weight, bias) into the preceding convolution or linear layer's weight and bias. The original BN module is deleted from the graph, and its running statistics cannot be recovered from the fused parameters.
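The loss of information can be made concrete for a single output channel. The sketch uses the standard per-channel folding, scale = gamma / sqrt(var + eps), w' = w * scale, b' = (b - mean) * scale + beta, with scalars standing in for tensors and a hypothetical helper name. Two different BatchNorm states yield identical fused parameters, so no function of (w', b') can tell which one was fused. eps=0 is used only to keep the toy arithmetic exact.

```python
import math


def fuse_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    # Fold y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta into one conv.
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


conv_w, conv_b = 3.0, 0.5  # one output channel, 1x1 kernel, one input channel
fused_a = fuse_bn_into_conv(conv_w, conv_b, gamma=2.0, beta=1.0, mean=0.0, var=4.0, eps=0.0)
fused_b = fuse_bn_into_conv(conv_w, conv_b, gamma=1.0, beta=0.0, mean=-1.0, var=1.0, eps=0.0)
print(fused_a, fused_b)  # (3.0, 1.5) (3.0, 1.5)
```

Because the fusion map is many-to-one, it has no inverse, which is exactly the bijection requirement failing.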
Restoration Tracking
The mark_as_restored method removes the transformation's signature from a node's transformations set, effectively recording that the node has been returned to its original state. This complements mark_as_transformed and ensures that the tracking system accurately reflects the current state of each node.
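A minimal sketch of this bookkeeping follows. It assumes each node carries a transformations set and that a transformation's signature is simply its class name; the real signature may be computed differently, and Node, ChangeDivToMul, and FuseBias are hypothetical names.

```python
class Node:
    # Hypothetical stand-in for a torch.fx Node; only the tracking set matters here.
    def __init__(self, name):
        self.name = name
        self.transformations = set()


class TrackedTransformation:
    @property
    def signature(self):
        # Assumption for this sketch: the class name identifies the transformation.
        return type(self).__name__

    def mark_as_transformed(self, node):
        node.transformations.add(self.signature)

    def transformed(self, node):
        return self.signature in node.transformations

    def mark_as_restored(self, node):
        # Drop only this transformation's signature; marks left by others are kept.
        # discard() silently ignores nodes that were never marked (a choice of this sketch).
        node.transformations.discard(self.signature)


class ChangeDivToMul(TrackedTransformation):
    pass


class FuseBias(TrackedTransformation):
    pass


node = Node("node_1")
div_to_mul, fuse_bias = ChangeDivToMul(), FuseBias()
div_to_mul.mark_as_transformed(node)
fuse_bias.mark_as_transformed(node)
div_to_mul.mark_as_restored(node)
print(sorted(node.transformations))  # ['FuseBias']
```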
Related
- implemented_by -> Implementation:Huggingface_Optimum_ReversibleTransformation_Reverse