Implementation:Huggingface Optimum MergeLinears

From Leeroopedia
Revision as of 13:04, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Huggingface_Optimum_MergeLinears.md)

Overview

This page documents the FX graph transformation classes in the Optimum library. The module provides both the base transformation framework (Transformation and ReversibleTransformation) and the concrete transformations that optimize traced PyTorch models.

Source

File: optimum/fx/optimization/transformations.py

Base Classes

Transformation (L84-172)

Abstract base class for all graph transformations.

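# Imports the excerpts on this page rely on (not part of the original excerpt).
import operator
from abc import ABC, abstractmethod
from typing import List

import torch
from torch.fx import GraphModule, Node


class Transformation(ABC):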
    preserves_computation: bool = False

    @abstractmethod
    def transform(self, graph_module: GraphModule) -> GraphModule:
        raise NotImplementedError("The transform method needs to be implemented.")

    def __call__(self, graph_module: GraphModule, lint_and_recompile: bool = True) -> GraphModule:
        graph_module = self.transform(graph_module)
        if lint_and_recompile:
            graph_module.graph.lint()
            graph_module.recompile()
        return graph_module

    @property
    def signature(self):
        attributes_to_use_for_hashing = vars(self)
        attributes_to_use_for_hashing[""] = self.__class__
        hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes_to_use_for_hashing.items())
        return hash(hash_str)

    def mark_as_transformed(self, node: Node): ...
    def transformed(self, node: Node) -> bool: ...
    def get_transformed_nodes(self, graph_module: GraphModule) -> List[Node]: ...

ReversibleTransformation (L175-234)

Extension of Transformation that supports undoing applied changes.

class ReversibleTransformation(Transformation):
    @abstractmethod
    def reverse(self, graph_module: GraphModule) -> GraphModule:
        raise NotImplementedError("The reverse transform method needs to be implemented.")

    def __call__(self, graph_module: GraphModule, lint_and_recompile: bool = True, reverse: bool = False) -> GraphModule:
        func = self.transform if not reverse else self.reverse
        graph_module = func(graph_module)
        if lint_and_recompile:
            graph_module.graph.lint()
            graph_module.recompile()
        return graph_module

    def mark_as_restored(self, node: Node): ...

Concrete Transformations

MergeLinears (L237-389)

Merges parallel linear layers (e.g., Q/K/V projections in attention) that share the same input into one larger nn.Linear.

class MergeLinears(ReversibleTransformation):
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        # Finds all call_module nodes that are nn.Linear
        # Groups them by shared input node
        # Merges groups with >1 linear into a single larger linear
        # Replaces individual outputs with slicing from the merged output
        ...

    def reverse(self, graph_module: GraphModule) -> GraphModule:
        for node in self.get_transformed_nodes(graph_module):
            self._unmerge_linears(graph_module, node, graph_module.get_submodule(node.target))
        return graph_module

How it works:

  1. Scans the graph for call_module nodes that wrap nn.Linear.
  2. Groups linear nodes by their shared input node (e.g., all Q/K/V projections fed by the same hidden state).
  3. For each group with more than one linear, creates a new merged nn.Linear with concatenated weights and biases.
  4. Replaces the original linear nodes with slice operations (operator.getitem) on the merged output.
  5. Stores the original split sizes for reversibility.
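
Steps 3 and 4 can be checked numerically without PyTorch. The following is a minimal sketch in plain Python, not the library's code: the layer values, the `linear` helper, and the variable names are hypothetical and chosen only for illustration. It concatenates the weights and biases of three projections along the output dimension and confirms that slicing the merged output reproduces each original output.

```python
from itertools import accumulate


def linear(x, weight, bias):
    """Compute y = W x + b for one input vector, like nn.Linear on a single row."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# Three hypothetical projections (think Q/K/V) that share one 2-feature input.
# Weights use the nn.Linear layout: (out_features, in_features).
q_w, q_b = [[1.0, 2.0], [3.0, 4.0]], [0.5, -0.5]
k_w, k_b = [[0.0, 1.0]], [1.0]
v_w, v_b = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]], [0.0, 0.0, 1.0]
layers = [(q_w, q_b), (k_w, k_b), (v_w, v_b)]

# Step 3: concatenate weights and biases along the output dimension.
merged_w = [row for w, _ in layers for row in w]
merged_b = [b for _, bias in layers for b in bias]

# Step 4: cumulative output sizes give the slice boundaries of each original layer.
out_features = [len(w) for w, _ in layers]
bounds = list(accumulate([0] + out_features))

x = [1.0, -1.0]
merged_out = linear(x, merged_w, merged_b)
sliced = [merged_out[bounds[i]:bounds[i + 1]] for i in range(len(layers))]
separate = [linear(x, w, b) for w, b in layers]

assert sliced == separate
```

The merged layer performs one matrix product instead of three, which is where the potential speed-up comes from; the slices themselves are cheap views in a tensor library.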

FuseBiasInLinear (L393-443)

Fuses bias into the weight matrix of nn.Linear layers by appending a column of ones to the input and a column of bias values to the weight.

class FuseBiasInLinear(ReversibleTransformation):
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        # For each nn.Linear with bias:
        #   Inserts concat([input, ones], dim=-1) before the linear
        #   Extends weight matrix: new_weight = cat([weight, bias[:, None]], dim=1)
        #   Sets bias to None
        ...

    def reverse(self, graph_module: GraphModule) -> GraphModule:
        # Restores original input connections
        # Splits weight back into weight and bias
        ...
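
The identity behind this transformation is W x + b = [W | b] [x; 1]. The sketch below checks it in plain Python, assuming small hypothetical values; `matvec` and the variable names are illustrative and are not part of Optimum. It also shows the reverse direction, where the last column is split back off as the bias.

```python
def matvec(weight, x):
    """Multiply a (out_features, in_features) matrix by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # out_features=3, in_features=2
bias = [0.5, -1.0, 2.0]
x = [2.0, -1.0]

# Original layer: y = W x + b
expected = [y + b for y, b in zip(matvec(weight, x), bias)]

# Mirrors cat([weight, bias[:, None]], dim=1): the bias becomes an extra column.
fused_weight = [row + [b] for row, b in zip(weight, bias)]
# Mirrors concat([input, ones], dim=-1): the input gains a trailing 1.
x_aug = x + [1.0]
fused = matvec(fused_weight, x_aug)

# Reverse: split the last column back out into a separate bias.
restored_weight = [row[:-1] for row in fused_weight]
restored_bias = [row[-1] for row in fused_weight]

assert fused == expected
assert (restored_weight, restored_bias) == (weight, bias)
```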

ChangeTrueDivToMulByInverse (L447-474)

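Replaces each operator.truediv call whose denominator is a constant (that is, not a graph node) with a multiplication by the reciprocal of that constant. The reverse pass restores the division by inverting the stored constant again.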

class ChangeTrueDivToMulByInverse(ReversibleTransformation):
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op == "call_function" and node.target == operator.truediv:
                x, y = node.args
                if not isinstance(y, torch.fx.Node):
                    node.target = operator.mul
                    node.args = (x, 1 / y)
                    self.mark_as_transformed(node)
        return graph_module

    def reverse(self, graph_module: GraphModule) -> GraphModule:
        for node in self.get_transformed_nodes(graph_module):
            node.target = operator.truediv
            x, y = node.args
            node.args = (x, 1 / y)
            self.mark_as_restored(node)
        return graph_module
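
The same rewrite can be illustrated outside of torch.fx. This is a simplified sketch that models a node as a (target, args) pair; `to_mul` and `to_div` are hypothetical helpers, not library functions. Note that with floating-point arithmetic x * (1 / y) may differ from x / y in the last bits, so the comparison for a non-power-of-two denominator uses a tolerance.

```python
import math
import operator


def to_mul(target, args):
    """Rewrite (truediv, (x, c)) as (mul, (x, 1 / c)) when c is a plain number."""
    x, y = args
    if target is operator.truediv and isinstance(y, (int, float)):
        return operator.mul, (x, 1 / y)
    return target, args


def to_div(target, args):
    """Undo to_mul by inverting the stored constant again."""
    x, y = args
    return operator.truediv, (x, 1 / y)


# Power-of-two denominator: the reciprocal is exact, so results match exactly.
target, args = to_mul(operator.truediv, (10.0, 8))
result = target(*args)

# Round trip back to a division with the original denominator.
restored_target, restored_args = to_div(target, args)

# General denominator: results agree up to floating-point rounding.
t3, a3 = to_mul(operator.truediv, (10.0, 3))
close = math.isclose(t3(*a3), 10.0 / 3, rel_tol=1e-12)
```

A denominator that is itself a computed value is left untouched, which corresponds to the isinstance check against torch.fx.Node in the excerpt above.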

FuseBatchNorm2dInConv2d (L478-557)

Folds nn.BatchNorm2d into the preceding nn.Conv2d. This transformation is irreversible: it extends Transformation rather than ReversibleTransformation.

class FuseBatchNorm2dInConv2d(Transformation):
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        # Finds BatchNorm2d nodes preceded by Conv2d nodes
        # Fuses BN parameters into Conv2d weight and bias
        # Removes the BatchNorm2d node from the graph
        ...
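
The arithmetic of folding a batch norm that follows an affine layer can be sketched as below. This is an illustration of the standard inference-time fold, not a copy of the Optimum code: it assumes the batch norm uses fixed running statistics, models the convolution at a single pixel as a small matrix product (as a 1x1 kernel would), and uses hypothetical helper names and values.

```python
import math

EPS = 1e-5  # BatchNorm epsilon; PyTorch's default value, assumed here


def affine(weight, bias, x):
    """y = W x + b, standing in for a 1x1 convolution at one pixel."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(y, gamma, beta, mean, var):
    """Inference-mode batch norm applied per output channel."""
    return [g * (yi - m) / math.sqrt(v + EPS) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]


def fold_bn_after(weight, bias, gamma, beta, mean, var):
    """Fold a following batch norm into the layer by scaling each output row."""
    scale = [g / math.sqrt(v + EPS) for g, v in zip(gamma, var)]
    new_weight = [[w * s for w in row] for row, s in zip(weight, scale)]
    new_bias = [(b - m) * s + bt for b, m, s, bt in zip(bias, mean, scale, beta)]
    return new_weight, new_bias


weight, bias = [[1.0, -2.0], [0.5, 3.0]], [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var = [0.2, -0.1], [4.0, 0.25]
x = [1.0, 2.0]

reference = batch_norm(affine(weight, bias, x), gamma, beta, mean, var)
fused_weight, fused_bias = fold_bn_after(weight, bias, gamma, beta, mean, var)
fused_out = affine(fused_weight, fused_bias, x)

assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(fused_out, reference))
```

Because the original weights are overwritten by the scaled ones, the unfused parameters cannot be recovered without extra bookkeeping, which is consistent with the class not offering a reverse method.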

FuseBatchNorm1dInLinear (L561-682)

Folds nn.BatchNorm1d into an adjacent nn.Linear, handling both the Linear -> BN1d and BN1d -> Linear patterns. Like FuseBatchNorm2dInConv2d, it extends Transformation and is irreversible.

class FuseBatchNorm1dInLinear(Transformation):
    preserves_computation = True

    def transform(self, graph_module: GraphModule) -> GraphModule:
        # Handles two patterns:
        #   1. nn.Linear -> nn.BatchNorm1d
        #   2. nn.BatchNorm1d -> nn.Linear
        # Fuses BN parameters into Linear weight and bias
        # Removes the BatchNorm1d node from the graph
        ...
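
The second pattern (BN1d -> Linear) differs from the first in that the batch norm acts on the input features, so the fold scales weight columns and absorbs the shift into the bias. A minimal sketch of that algebra in plain Python follows, assuming inference-mode statistics; the helper names and numbers are hypothetical and this is not the library's implementation.

```python
import math

EPS = 1e-5  # BatchNorm epsilon; PyTorch's default value, assumed here


def linear(weight, bias, x):
    """y = W x + b for a single input vector."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(x, gamma, beta, mean, var):
    """Inference-mode batch norm applied per feature."""
    return [g * (xi - m) / math.sqrt(v + EPS) + bt
            for xi, g, bt, m, v in zip(x, gamma, beta, mean, var)]


def fold_bn_before(weight, bias, gamma, beta, mean, var):
    """Fold a preceding batch norm: BN(x) = scale * x + shift, so
    W (scale * x + shift) + b = (W * scale) x + (b + W shift)."""
    scale = [g / math.sqrt(v + EPS) for g, v in zip(gamma, var)]
    shift = [bt - m * s for bt, m, s in zip(beta, mean, scale)]
    new_weight = [[w * s for w, s in zip(row, scale)] for row in weight]
    new_bias = [b + sum(w * sh for w, sh in zip(row, shift))
                for row, b in zip(weight, bias)]
    return new_weight, new_bias


weight, bias = [[1.0, -2.0], [0.5, 3.0], [2.0, 0.0]], [0.1, -0.2, 0.0]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var = [0.2, -0.1], [4.0, 0.25]
x = [1.0, 2.0]

reference = linear(weight, bias, batch_norm(x, gamma, beta, mean, var))
fused_weight, fused_bias = fold_bn_before(weight, bias, gamma, beta, mean, var)
fused_out = linear(fused_weight, fused_bias, x)

assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(fused_out, reference))
```

For the first pattern (Linear -> BN1d) the scale applies to weight rows instead of columns, since the batch norm then acts on the output features.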

Import

from optimum.fx.optimization import (
    MergeLinears,
    FuseBiasInLinear,
    ChangeTrueDivToMulByInverse,
    FuseBatchNorm2dInConv2d,
    FuseBatchNorm1dInLinear,
    Transformation,
    ReversibleTransformation,
)

Transformation Hierarchy Summary

Class | Line Range | Parent | Reversible | Preserves Computation
Transformation | L84-172 | ABC | N/A (abstract) | False (default)
ReversibleTransformation | L175-234 | Transformation | N/A (abstract) | Inherited
MergeLinears | L237-389 | ReversibleTransformation | Yes | Yes
FuseBiasInLinear | L393-443 | ReversibleTransformation | Yes | Yes
ChangeTrueDivToMulByInverse | L447-474 | ReversibleTransformation | Yes | Yes
FuseBatchNorm2dInConv2d | L478-557 | Transformation | No | Yes
FuseBatchNorm1dInLinear | L561-682 | Transformation | No | Yes
