Implementation:Huggingface Optimum MergeLinears
Overview
Implementation of the FX graph transformation classes in the Optimum library. This module provides both the base transformation framework and concrete transformation implementations for optimizing traced PyTorch models.
Source
File: optimum/fx/optimization/transformations.py
Base Classes
Transformation (L84-172)
Abstract base class for all graph transformations.
class Transformation(ABC):
    """
    Abstract base class for `torch.fx` graph transformations.

    Subclasses implement `transform`. Calling an instance applies `transform` and, by
    default, lints the resulting graph and recompiles the `GraphModule`.
    """

    # Whether applying the transformation is meant to leave the model outputs unchanged.
    preserves_computation: bool = False

    @abstractmethod
    def transform(self, graph_module: GraphModule) -> GraphModule:
        """Apply the transformation to `graph_module` and return the resulting module."""
        raise NotImplementedError("The transform method needs to be implemented.")

    def __call__(self, graph_module: GraphModule, lint_and_recompile: bool = True) -> GraphModule:
        """
        Apply `transform` to `graph_module`.

        Args:
            graph_module: the traced module to transform.
            lint_and_recompile: when True, run `graph.lint()` and `recompile()` on the
                result so that the generated Python code reflects the edited graph.

        Returns:
            The module returned by `transform`.
        """
        graph_module = self.transform(graph_module)
        if lint_and_recompile:
            graph_module.graph.lint()
            graph_module.recompile()
        return graph_module

    @property
    def signature(self):
        """
        Return an integer hash identifying this transformation.

        The hash covers the instance attributes and the concrete class. Every attribute
        value must be hashable. The final `hash()` is taken over a string, so the value
        varies between interpreter runs unless PYTHONHASHSEED is fixed.
        """
        # Copy first: `vars(self)` is the live instance `__dict__`. Writing the class
        # entry into it directly would add a stray "" attribute to the instance on
        # every access.
        attributes_to_use_for_hashing = dict(vars(self))
        attributes_to_use_for_hashing[""] = self.__class__
        hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes_to_use_for_hashing.items())
        return hash(hash_str)

    def mark_as_transformed(self, node: Node): ...

    def transformed(self, node: Node) -> bool: ...

    def get_transformed_nodes(self, graph_module: GraphModule) -> List[Node]: ...
ReversibleTransformation (L175-234)
Extension of Transformation that supports undoing applied changes.
class ReversibleTransformation(Transformation):
    """
    A `Transformation` whose effect can be undone.

    Subclasses implement both `transform` and `reverse`. Calling an instance with
    `reverse=True` runs `reverse` instead of `transform`.
    """

    @abstractmethod
    def reverse(self, graph_module: GraphModule) -> GraphModule:
        """Undo the changes that `transform` made to `graph_module` and return the module."""
        raise NotImplementedError("The reverse transform method needs to be implemented.")

    def __call__(self, graph_module: GraphModule, lint_and_recompile: bool = True, reverse: bool = False) -> GraphModule:
        """
        Apply either `transform` or `reverse` to `graph_module`.

        Args:
            graph_module: the traced module to edit.
            lint_and_recompile: when True, lint the edited graph and recompile the module.
            reverse: when True, undo the transformation instead of applying it.

        Returns:
            The module returned by the selected method.
        """
        if reverse:
            apply = self.reverse
        else:
            apply = self.transform
        result = apply(graph_module)
        if lint_and_recompile:
            result.graph.lint()
            result.recompile()
        return result

    def mark_as_restored(self, node: Node): ...
Concrete Transformations
MergeLinears (L237-389)
Merges parallel linear layers (e.g., Q/K/V projections in attention) that share the same input into one larger nn.Linear.
class MergeLinears(ReversibleTransformation):
    """
    Merge `nn.Linear` layers that consume the same input node into one wider `nn.Linear`.

    A typical case is the Q/K/V projections of an attention block. The body of
    `transform` is elided in this excerpt; the comments inside it summarize its steps.
    """

    # Declared by the class: merging is meant to leave the model outputs unchanged.
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        # Finds all call_module nodes that are nn.Linear
        # Groups them by shared input node
        # Merges groups with >1 linear into a single larger linear
        # Replaces individual outputs with slicing from the merged output
        ...

    def reverse(self, graph_module: GraphModule) -> GraphModule:
        """Split every merged linear produced by `transform` back into its original layers."""
        for node in self.get_transformed_nodes(graph_module):
            # `node.target` names the merged submodule. The splitting itself is done by
            # `_unmerge_linears`, which is not shown in this excerpt. NOTE(review):
            # presumably it also calls `mark_as_restored` on the node; TODO confirm.
            self._unmerge_linears(graph_module, node, graph_module.get_submodule(node.target))
        return graph_module
How it works:
- Scans the graph for `call_module` nodes that wrap `nn.Linear`.
- Groups linear nodes by their shared input node (e.g., all Q/K/V projections fed by the same hidden state).
- For each group with more than one linear, creates a new merged `nn.Linear` with concatenated weights and biases.
- Replaces the original linear nodes with slice operations (`operator.getitem`) on the merged output.
- Stores the original split sizes for reversibility.
FuseBiasInLinear (L393-443)
Fuses bias into the weight matrix of nn.Linear layers by appending a column of ones to the input and a column of bias values to the weight.
class FuseBiasInLinear(ReversibleTransformation):
    """
    Fold the bias of each `nn.Linear` into its weight matrix.

    Each affected linear receives its input with an extra column of ones, and its weight
    gains an extra column holding the bias, so the bias can be dropped. Both method
    bodies are elided in this excerpt.
    """

    # Declared by the class: the rewrite is meant to leave the model outputs unchanged.
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        # For each nn.Linear with bias:
        # Inserts concat([input, ones], dim=-1) before the linear
        # Extends weight matrix: new_weight = cat([weight, bias[:, None]], dim=1)
        # Sets bias to None
        ...

    def reverse(self, graph_module: GraphModule) -> GraphModule:
        # Restores original input connections (bypassing the inserted concat)
        # Splits weight back into weight and bias (last column becomes the bias)
        ...
ChangeTrueDivToMulByInverse (L447-474)
Replaces `operator.truediv` calls whose denominator is a constant (i.e. not a `torch.fx.Node`) with `operator.mul` by the reciprocal of that constant.
class ChangeTrueDivToMulByInverse(ReversibleTransformation):
    """
    Rewrite `x / c` as `x * (1 / c)` when the denominator `c` is a constant.

    A denominator counts as constant when it is not a `torch.fx.Node`. Multiplying by a
    reciprocal is not always bit-identical to dividing in floating point, so results may
    differ in the last ulp.
    """

    preserves_computation = True

    # Key under which the original denominator is stored in `node.meta`. This lets
    # `reverse` restore it exactly instead of recomputing `1 / (1 / c)`, which is not
    # exact in floating point (e.g. 1 / (1 / 49) == 49.00000000000001).
    _ORIGINAL_DENOMINATOR_KEY = "change_truediv_original_denominator"

    def transform(self, graph_module: GraphModule) -> GraphModule:
        """
        Replace every `operator.truediv` call with a constant denominator by `operator.mul`.

        The new right-hand operand is the reciprocal of the denominator. Nodes whose
        denominator is a graph node, or a numeric zero, are left untouched.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function" or node.target != operator.truediv:
                continue
            x, y = node.args
            if isinstance(y, torch.fx.Node):
                # The denominator is only known at runtime: nothing to precompute.
                continue
            if isinstance(y, (int, float)) and y == 0:
                # `1 / 0` would raise ZeroDivisionError here. Leave the node as-is so the
                # graph keeps its original runtime behavior for this division.
                continue
            node.meta[self._ORIGINAL_DENOMINATOR_KEY] = y
            node.target = operator.mul
            node.args = (x, 1 / y)
            self.mark_as_transformed(node)
        return graph_module

    def reverse(self, graph_module: GraphModule) -> GraphModule:
        """Turn every node rewritten by `transform` back into a `truediv` by its original denominator."""
        for node in self.get_transformed_nodes(graph_module):
            node.target = operator.truediv
            x, inverse = node.args
            if self._ORIGINAL_DENOMINATOR_KEY in node.meta:
                # Exact original value saved by `transform`.
                y = node.meta.pop(self._ORIGINAL_DENOMINATOR_KEY)
            else:
                # Fallback for nodes marked without the stored value (may be inexact).
                y = 1 / inverse
            node.args = (x, y)
            self.mark_as_restored(node)
        return graph_module
FuseBatchNorm2dInConv2d (L478-557)
Folds nn.BatchNorm2d into the preceding nn.Conv2d. Irreversible -- extends Transformation, not ReversibleTransformation.
class FuseBatchNorm2dInConv2d(Transformation):
    """
    Fold each `nn.BatchNorm2d` into the `nn.Conv2d` that directly precedes it.

    This is irreversible: the class derives from `Transformation` and defines no
    `reverse`, so the removed BatchNorm node cannot be restored by this class. The body
    of `transform` is elided in this excerpt.
    NOTE(review): folding presumably bakes in the BatchNorm statistics as they are at
    transform time, so the model is expected to be in eval mode; TODO confirm.
    """

    # Declared by the class: fusion is meant to leave the model outputs unchanged.
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        # Finds BatchNorm2d nodes preceded by Conv2d nodes
        # Fuses BN parameters into Conv2d weight and bias
        # Removes the BatchNorm2d node from the graph
        ...
FuseBatchNorm1dInLinear (L561-682)
Folds nn.BatchNorm1d into the preceding or following nn.Linear. Handles both Linear -> BN1d and BN1d -> Linear patterns. Irreversible.
class FuseBatchNorm1dInLinear(Transformation):
    """
    Fold each `nn.BatchNorm1d` into an adjacent `nn.Linear`.

    Both orders are handled: the BatchNorm may come right after or right before the
    linear. This is irreversible: the class derives from `Transformation` and defines
    no `reverse`. The body of `transform` is elided in this excerpt.
    NOTE(review): as with other BatchNorm folding, the model is presumably expected to
    be in eval mode; TODO confirm.
    """

    # Declared by the class: fusion is meant to leave the model outputs unchanged.
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        # Handles two patterns:
        # 1. nn.Linear -> nn.BatchNorm1d
        # 2. nn.BatchNorm1d -> nn.Linear
        # Fuses BN parameters into Linear weight and bias
        # Removes the BatchNorm1d node from the graph
        ...
Import
from optimum.fx.optimization import (
MergeLinears,
FuseBiasInLinear,
ChangeTrueDivToMulByInverse,
FuseBatchNorm2dInConv2d,
FuseBatchNorm1dInLinear,
Transformation,
ReversibleTransformation,
)
Transformation Hierarchy Summary
| Class | Line Range | Parent | Reversible | Preserves Computation |
|---|---|---|---|---|
| Transformation | L84-172 | ABC | N/A (abstract) | False (default) |
| ReversibleTransformation | L175-234 | Transformation | N/A (abstract) | Inherited (False) |
| MergeLinears | L237-389 | ReversibleTransformation | Yes | Yes |
| FuseBiasInLinear | L393-443 | ReversibleTransformation | Yes | Yes |
| ChangeTrueDivToMulByInverse | L447-474 | ReversibleTransformation | Yes | Yes |
| FuseBatchNorm2dInConv2d | L478-557 | Transformation | No | Yes |
| FuseBatchNorm1dInLinear | L561-682 | Transformation | No | Yes |
Related
- implements -> Principle:Huggingface_Optimum_FX_Graph_Transformation