Implementation:Huggingface Optimum Compose
Overview
The compose function combines multiple individual transformations into a single composite transformation that applies them sequentially, in the order given. It is defined in optimum/fx/optimization/transformations.py.
Source
File: optimum/fx/optimization/transformations.py L721-801
Signature
def compose(*args: Transformation, inplace: bool = True) -> Transformation:
| Parameter | Type | Default | Description |
|---|---|---|---|
| *args | Transformation | (required) | The transformations to compose together, applied left-to-right |
| inplace | bool | True | Whether the composition modifies the graph in place or works on a deep copy |
Returns: A Transformation (or ReversibleTransformation if all inputs are reversible) that applies all given transformations in sequence.
Internal Implementation
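The compose function uses functools.reduce to chain the transformations together. The reducing helper combines two callables f and g into one that computes f(g(x)), so g runs before f. Reducing over the reversed list [tN, ..., t2, t1] therefore produces tN(...(t2(t1(x)))), which applies the transformations left to right, in the order they were passed to compose.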
def compose(*args: Transformation, inplace: bool = True) -> Transformation:
    # Abridged excerpt: reduce_fn and make_reduce_fn are helpers from the full source, omitted here
    transformations = list(reversed(args))
    composition_preserves_computation = all(t.preserves_computation for t in transformations)
    composition_is_reversible = all(isinstance(t, ReversibleTransformation) for t in transformations)
    if not inplace:
        # Appended to the reversed list, so DeepCopy is the first step of the forward pass
        transformations.append(DeepCopy())
    if not composition_is_reversible:
        # Creates a Transformation subclass with _composition via functools.reduce
        class ComposeTransformation(Transformation):
            preserves_computation = composition_preserves_computation
            _composition = functools.reduce(reduce_fn, transformations)
            def transform(self, graph_module):
                return ComposeTransformation._composition(graph_module)
    else:
        # Creates a ReversibleTransformation subclass with both _composition and _reverse_composition
        class ComposeTransformation(ReversibleTransformation):
            preserves_computation = composition_preserves_computation
            _composition = functools.reduce(make_reduce_fn(False), transformations)
            _reverse_composition = functools.reduce(make_reduce_fn(True), reversed(transformations))
            def transform(self, graph_module):
                return ComposeTransformation._composition(graph_module)
            def reverse(self, graph_module):
                return ComposeTransformation._reverse_composition(graph_module)
    return ComposeTransformation()
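The ordering argument above can be checked with a small, library-independent sketch. Plain functions on a list stand in for graph transformations, and reduce_fn, make_step, and compose_sketch are hypothetical names chosen for illustration; only the reversed-list functools.reduce pattern is taken from the excerpt.

```python
import functools
from typing import Callable, List

Step = Callable[[List[str]], List[str]]


def reduce_fn(f: Step, g: Step) -> Step:
    """Hypothetical stand-in: return a function computing f(g(x)), so g runs first."""
    def composition(x: List[str]) -> List[str]:
        return f(g(x))
    return composition


def make_step(name: str) -> Step:
    """Build a toy 'transformation' that records its name in a new list."""
    def step(x: List[str]) -> List[str]:
        return x + [name]
    return step


def compose_sketch(*steps: Step) -> Step:
    # Reverse first so reduce nests the calls as steps[-1](...steps[1](steps[0](x))).
    return functools.reduce(reduce_fn, list(reversed(steps)))


pipeline = compose_sketch(make_step("div_to_mul"), make_step("merge_linears"), make_step("fuse_bias"))
print(pipeline([]))  # ['div_to_mul', 'merge_linears', 'fuse_bias']

# Without the reversal, the last step passed would run first.
naive = functools.reduce(reduce_fn, [make_step("a"), make_step("b")])
print(naive([]))  # ['b', 'a']
```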
Key Design Details
- lint_and_recompile=False is passed to the intermediate transformations within the composition, avoiding redundant linting and recompilation between steps. The composed transformation's own __call__ lints and recompiles once at the end.
- The reverse composition reduces over reversed(transformations), so reversal is applied in the opposite order of the forward pass: the last transformation is undone first.
- When inplace=False, DeepCopy (L685-702) is appended to the reversed list, which makes it the first step of the forward pass and protects the original graph module.
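A toy model of the first two design points, assuming each step accepts lint_and_recompile and reverse keyword arguments in the way the composed closures forward them. ToyReversible and run_composed are hypothetical, and make_reduce_fn here is a guess at the shape of the helper named in the excerpt, not its actual source. A list of strings plays the role of the graph, and the string "lint" marks where a lint/recompile would happen.

```python
import functools
from typing import List


class ToyReversible:
    """Hypothetical reversible step; the 'graph' is a list that logs each call."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __call__(self, graph: List[str], lint_and_recompile: bool = True, reverse: bool = False) -> List[str]:
        graph.append(f"undo {self.name}" if reverse else self.name)
        if lint_and_recompile:
            graph.append("lint")  # a standalone call would lint/recompile after every step
        return graph


def make_reduce_fn(reverse: bool):
    """Hypothetical helper: nest f(g(x)) and forward both flags to every step."""
    def reduce_fn(f, g):
        def composition(graph, lint_and_recompile=False, reverse=reverse):
            out = g(graph, lint_and_recompile=lint_and_recompile, reverse=reverse)
            return f(out, lint_and_recompile=lint_and_recompile, reverse=reverse)
        return composition
    return reduce_fn


def run_composed(steps, graph: List[str], reverse: bool = False) -> List[str]:
    transformations = list(reversed(steps))
    if reverse:
        # Undo in the opposite order of the forward pass.
        fn = functools.reduce(make_reduce_fn(True), reversed(transformations))
    else:
        fn = functools.reduce(make_reduce_fn(False), transformations)
    graph = fn(graph, lint_and_recompile=False, reverse=reverse)  # no per-step linting
    graph.append("lint")  # one lint/recompile after the whole chain
    return graph


steps = [ToyReversible("div_to_mul"), ToyReversible("merge_linears"), ToyReversible("fuse_bias")]
print(run_composed(steps, []))
# ['div_to_mul', 'merge_linears', 'fuse_bias', 'lint']
print(run_composed(steps, [], reverse=True))
# ['undo fuse_bias', 'undo merge_linears', 'undo div_to_mul', 'lint']
```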
Supporting Internal Classes
DeepCopy (L685-702)
import copy
from torch.fx import GraphModule
from optimum.fx.optimization import ReversibleTransformation
class DeepCopy(ReversibleTransformation):
    preserves_computation = True
    def transform(self, graph_module: GraphModule) -> GraphModule:
        clone = copy.deepcopy(graph_module)
        # Copies transformation metadata that deepcopy does not handle
        for n1, n2 in zip(graph_module.graph.nodes, clone.graph.nodes):
            if hasattr(n1, "transformations"):
                n2.transformations = n1.transformations
        return clone
    def reverse(self, graph_module: GraphModule) -> GraphModule:
        # Reversing a copy is just another copy
        return self.transform(graph_module)
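The metadata-copying loop in DeepCopy can be illustrated without torch. This is a rough analogy that assumes, as the comment in the excerpt states, that the deep copy yields fresh node objects lacking the custom transformations attribute; ToyNode and deepcopy_with_metadata are hypothetical names.

```python
import copy
from typing import List


class ToyNode:
    """Hypothetical node whose deepcopy rebuilds it from `name` only."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __deepcopy__(self, memo):
        # Extra attributes such as `transformations` are not carried over.
        return ToyNode(self.name)


def deepcopy_with_metadata(nodes: List[ToyNode]) -> List[ToyNode]:
    clone = copy.deepcopy(nodes)
    # Same pattern as DeepCopy.transform: pair nodes by position, copy the attribute back.
    for n1, n2 in zip(nodes, clone):
        if hasattr(n1, "transformations"):
            n2.transformations = n1.transformations
    return clone


original = [ToyNode("x"), ToyNode("linear")]
original[1].transformations = {"MergeLinears"}
print(hasattr(copy.deepcopy(original)[1], "transformations"))  # False
print(deepcopy_with_metadata(original)[1].transformations)  # {'MergeLinears'}
```

Pairing by position with zip relies on the copy preserving node order, which is what the loop in DeepCopy also assumes.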
LintAndRecompile (L705-718)
class LintAndRecompile(ReversibleTransformation):
    preserves_computation = True
    def transform(self, graph_module: GraphModule) -> GraphModule:
        graph_module.graph.lint()
        graph_module.recompile()
        return graph_module
    def reverse(self, graph_module: GraphModule) -> GraphModule:
        return self.transform(graph_module)
Import
from optimum.fx.optimization import compose
Usage Example
from optimum.fx.optimization import compose, MergeLinears, FuseBiasInLinear, ChangeTrueDivToMulByInverse
# traced_model is assumed to be a torch.fx.GraphModule, e.g. from transformers.utils.fx.symbolic_trace(model)
# Compose multiple optimizations into a single pipeline
optimization = compose(
    ChangeTrueDivToMulByInverse(),
    MergeLinears(),
    FuseBiasInLinear(),
    inplace=False,  # Don't modify the original graph module
)
# Apply all transformations at once, left to right
optimized_model = optimization(traced_model)
# Reverse all transformations (works because all three are reversible)
restored_model = optimization(optimized_model, reverse=True)
A composition that includes a non-reversible transformation:
from optimum.fx.optimization import compose, MergeLinears, FuseBatchNorm2dInConv2d
# This composition is NOT reversible because FuseBatchNorm2dInConv2d is irreversible
optimization = compose(
    MergeLinears(),
    FuseBatchNorm2dInConv2d(),
)
optimized_model = optimization(traced_model)
# optimization(optimized_model, reverse=True) would raise an error:
# the result is a plain Transformation, which has no reverse method
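The rule that decides the return type (reversible only if every input is a ReversibleTransformation) can be sketched with minimal stand-in classes. ToyTransformation, ToyReversibleTransformation, AddOne, Double, and toy_compose are hypothetical simplifications, not optimum's classes, and toy_compose folds over the steps with an initial value instead of nesting closures as the excerpt does. In this toy, asking a non-reversible composition to reverse fails with a TypeError because its __call__ has no reverse parameter.

```python
import functools


class ToyTransformation:
    """Hypothetical one-way transformation."""

    def __call__(self, graph, lint_and_recompile=True):
        # lint_and_recompile only mirrors the call shape; it is ignored here.
        return self.transform(graph)


class ToyReversibleTransformation(ToyTransformation):
    """Hypothetical reversible transformation whose __call__ accepts reverse=..."""

    def __call__(self, graph, lint_and_recompile=True, reverse=False):
        return self.reverse(graph) if reverse else self.transform(graph)


class AddOne(ToyReversibleTransformation):
    def transform(self, x):
        return x + 1

    def reverse(self, x):
        return x - 1


class Double(ToyTransformation):  # one-way: no reverse()
    def transform(self, x):
        return x * 2


def toy_compose(*steps):
    # Same test as the excerpt: reversible only if every step is reversible.
    if all(isinstance(s, ToyReversibleTransformation) for s in steps):
        class Composed(ToyReversibleTransformation):
            def transform(self, x):
                return functools.reduce(lambda acc, s: s.transform(acc), steps, x)

            def reverse(self, x):
                return functools.reduce(lambda acc, s: s.reverse(acc), reversed(steps), x)
    else:
        class Composed(ToyTransformation):
            def transform(self, x):
                return functools.reduce(lambda acc, s: s.transform(acc), steps, x)
    return Composed()


reversible = toy_compose(AddOne(), AddOne())
print(reversible(3), reversible(5, reverse=True))  # 5 3
one_way = toy_compose(AddOne(), Double())
print(one_way(3))  # 8
try:
    one_way(8, reverse=True)
except TypeError as err:
    print("not reversible:", err)
```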
Related
- implements -> Principle:Huggingface_Optimum_Transformation_Composition