Principle:Huggingface Optimum Transformation Validation
Overview
The process of verifying that graph transformations produce correct results, by inspecting the transformed nodes and by comparing model outputs before and after the transformation.
Description
After applying transformations, validation ensures correctness. Two complementary mechanisms are provided:
1. Node Inspection via get_transformed_nodes
The get_transformed_nodes method returns the list of graph nodes that were modified by a specific transformation instance. This enables:
- Inspection -- Examine exactly which nodes were changed and how.
- Auditing -- Verify that the transformation only modified the expected nodes.
- Debugging -- Identify unexpected modifications when a transformation produces incorrect results.
2. Output Comparison
For transformations where preserves_computation=True, comparing outputs with torch.allclose verifies that, on the same inputs, the transformed model produces the same outputs as the original up to a numerical tolerance. This provides an end-to-end correctness check that catches errors node-level inspection might miss.
| Validation Method | Scope | When to Use |
|---|---|---|
| get_transformed_nodes | Individual nodes | Inspect what changed; debug unexpected behavior |
| Output comparison | Full model | Verify numerical equivalence after transformation |
| Combined | Both | Production validation: verify both structure and outputs |
Usage
Run validation after applying any graph transformation. It is especially important when:
- Developing new transformations -- Ensure new transformations do not break model behavior.
- Applying transformations to new model architectures -- Confirm that the transformation handles the model's specific graph patterns correctly.
- Chaining multiple transformations -- Verify that the combination of transformations produces correct results, even if each individual transformation is correct.
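When a chain of transformations fails validation, comparing outputs after every step (rather than only at the end) helps locate the step at fault. The following is a minimal, framework-free sketch of that idea. Everything in it is hypothetical: the "model" is a toy list of (operation, constant) pairs, `div_to_mul` and `drop_adds` are illustrative rewrites, and `first_failing_step` is not part of Optimum. With a traced PyTorch module, the comparison would typically be done with torch.allclose instead of math.isclose.

```python
import math


def run(ops, x):
    """Evaluate a toy 'graph': a list of (operation, constant) pairs applied in order."""
    for op, constant in ops:
        if op == "mul":
            x = x * constant
        elif op == "div":
            x = x / constant
        elif op == "add":
            x = x + constant
        else:
            raise ValueError(f"unknown op: {op}")
    return x


def div_to_mul(ops):
    # Computation-preserving rewrite: x / c becomes x * (1 / c).
    return [("mul", 1.0 / c) if op == "div" else (op, c) for op, c in ops]


def drop_adds(ops):
    # Deliberately incorrect rewrite, used here only to trigger a validation failure.
    return [(op, c) for op, c in ops if op != "add"]


def first_failing_step(ops, transformations, inputs, atol=1e-6):
    """Apply the transformations in order and return the index of the first one
    whose result deviates from the original outputs by more than atol, else None."""
    reference = [run(ops, x) for x in inputs]
    current = ops
    for index, transform in enumerate(transformations):
        current = transform(current)
        outputs = [run(current, x) for x in inputs]
        if not all(math.isclose(r, o, rel_tol=0.0, abs_tol=atol) for r, o in zip(reference, outputs)):
            return index
    return None


model = [("mul", 3.0), ("div", 2.0), ("add", 1.0)]
inputs = [-2.0, 0.0, 0.5, 4.0]

print(first_failing_step(model, [div_to_mul], inputs))             # None
print(first_failing_step(model, [div_to_mul, drop_adds], inputs))  # 1
```

Checking against the original outputs at every step, instead of against the previous step, keeps small per-step deviations from accumulating unnoticed across the chain.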
Theoretical Basis
Signature-Based Marking
Each Transformation instance has a unique signature derived from its class name and instance attributes. The signature is computed as:
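    @property
    def signature(self):
        # vars(self) is the instance's attribute dictionary.
        attributes_to_use_for_hashing = vars(self)
        # The class is stored under the empty key so that different classes hash differently.
        attributes_to_use_for_hashing[""] = self.__class__
        # Hash every attribute value, join the results into one string, and hash that string.
        hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes_to_use_for_hashing.items())
        return hash(hash_str)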
When transform() modifies a node, it calls mark_as_transformed(node), which adds the signature to the node's transformations set. The get_transformed_nodes method then filters nodes that contain this signature.
Key properties of the signature system:
- Instance-specific -- Two instances of the same transformation class with different constructor arguments have different signatures.
- Class-specific -- Two instances of different transformation classes (even if they have the same attributes) have different signatures.
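- Deterministic -- The same instance always produces the same signature within a single Python process, as long as its attributes do not change. Because Python randomizes string hashes per process by default, signature values should not be compared across separate runs.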
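The marking scheme and the three properties above can be illustrated with a minimal, framework-free sketch. `Node`, `Transformation`, `ScaleBy`, and `ShiftBy` below are hypothetical stand-ins, not the classes from Optimum or torch.fx. The signature logic mirrors the snippet shown earlier, with one stated difference: it hashes a copy of the instance attributes rather than modifying them in place.

```python
class Node:
    """Hypothetical stand-in for a graph node (not torch.fx.Node)."""

    def __init__(self, name):
        self.name = name
        self.transformations = set()  # signatures of transformations that touched this node


class Transformation:
    """Hypothetical base class that mirrors the signature-based marking scheme."""

    @property
    def signature(self):
        # Assumption of this sketch: copy the attributes so the instance is not mutated.
        attributes = dict(vars(self))
        attributes[""] = self.__class__
        hash_str = "_".join(f"{k}_{hash(v)}" for k, v in attributes.items())
        return hash(hash_str)

    def mark_as_transformed(self, node):
        node.transformations.add(self.signature)

    def get_transformed_nodes(self, nodes):
        return [node for node in nodes if self.signature in node.transformations]


class ScaleBy(Transformation):
    def __init__(self, factor):
        self.factor = factor


class ShiftBy(Transformation):
    def __init__(self, factor):
        self.factor = factor


nodes = [Node("linear_1"), Node("relu"), Node("linear_2")]
scale_by_2, scale_by_3, shift_by_2 = ScaleBy(2), ScaleBy(3), ShiftBy(2)

# Pretend that scale_by_2 rewrote the two linear nodes.
scale_by_2.mark_as_transformed(nodes[0])
scale_by_2.mark_as_transformed(nodes[2])

transformed_names = [node.name for node in scale_by_2.get_transformed_nodes(nodes)]
print(transformed_names)  # ['linear_1', 'linear_2']

# Instance-specific: same class, different constructor arguments.
print(scale_by_2.signature != scale_by_3.signature)
# Class-specific: different classes, identical attributes.
print(scale_by_2.signature != shift_by_2.signature)
# Stable for one instance within a single process.
print(scale_by_2.signature == scale_by_2.signature)
```

Because each node stores a set of signatures, several transformations can mark the same node without overwriting each other's marks, and each instance can still retrieve only the nodes it changed.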
Numerical Validation
Output comparison uses element-wise comparison with tolerance:
torch.allclose(original_output, transformed_output, atol=tolerance)
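This accounts for floating-point precision differences that may arise from reordering operations. For example, fusing a bias into weights changes the order of additions, which can produce slightly different results because floating-point addition is not associative. Note that torch.allclose also applies a relative tolerance (rtol, 1e-05 by default) in addition to atol, so the effective threshold grows with the magnitude of the compared values.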
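Both the effect of non-associativity and the need for a tolerance can be reproduced with plain Python floats. In the sketch below, `is_close` is a hypothetical helper that applies the criterion documented for torch.allclose, `|a - b| <= atol + rtol * |b|`, to a single pair of values, using the same default tolerances.

```python
def is_close(a, b, rtol=1e-05, atol=1e-08):
    """Scalar version of the closeness criterion documented for torch.allclose."""
    return abs(a - b) <= atol + rtol * abs(b)


# Two mathematically equivalent groupings of the same additions.
original_output = (0.1 + 0.2) + 0.3     # additions in the original order
transformed_output = 0.1 + (0.2 + 0.3)  # additions regrouped, as a fusion might do

print(original_output == transformed_output)          # False
print(is_close(original_output, transformed_output))  # True
```

An exact equality check would reject this correct regrouping, which is why output validation relies on a tolerance rather than on bitwise equality.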
Related
- implemented_by -> Implementation:Huggingface_Optimum_Get_Transformed_Nodes