Implementation:AUTOMATIC1111 Stable diffusion webui CondFunc
| Knowledge Sources | |
|---|---|
| Domains | Metaprogramming, Conditional_Hijacking |
| Last Updated | 2025-05-15 00:00 GMT |
Overview
Implements a conditional function hijacking mechanism that replaces a target function with a substitute that is only invoked when a runtime condition is met, falling back to the original otherwise.
Description
The CondFunc module provides the CondFunc class, which implements conditional monkey patching. The target function can be given either as a callable or as a dot-separated string path such as "torch.nn.functional.linear". For a string path, CondFunc resolves the target, replaces it on its parent object (the owning module, or an intermediate attribute such as a class) with a conditional dispatcher, and returns a callable wrapper. For a callable, only the wrapper is returned and nothing is patched in place.
At call time, the wrapper evaluates the condition function (cond_func), passing it the original function followed by all call arguments. If the condition returns a truthy value (the default, always_true_func, always returns True), the substitute function (sub_func) is called with the original function as its first argument followed by the same arguments. Otherwise the original function is called directly.
String resolution works backwards from the deepest candidate. CondFunc first tries to import every component of the path except the last as a module, drops one trailing component at a time until an import succeeds, and then follows the remaining components with getattr to reach the target. For "torch.nn.functional.linear", the first attempt (torch.nn.functional) already succeeds, and linear is looked up on that module. The WebUI uses this mechanism extensively for device-specific and platform-specific hijacks.
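The resolution and dispatch steps described above can be sketched without PyTorch. This is a hypothetical re-creation for illustration, not the WebUI source: resolve_owner and cond_patch are invented names, the target is math.sqrt (negative inputs are routed to cmath.sqrt), the prefix search stops before the empty module name, and the returned restore function is a convenience this page does not describe as part of CondFunc.

```python
import cmath
import importlib
import math
from typing import Any, Callable, Tuple


def resolve_owner(path: str) -> Tuple[Any, str]:
    """Hypothetical helper: return (owner, attr_name) for a dotted path."""
    parts = path.split(".")
    # Try the longest module prefix first (all parts but the last), then shorten it.
    for i in range(len(parts) - 1, 0, -1):
        try:
            owner = importlib.import_module(".".join(parts[:i]))
            break
        except ImportError:
            continue
    else:
        raise ImportError(f"no importable module prefix in {path!r}")
    # Walk any remaining non-module components (for example a class) with getattr.
    for name in parts[i:-1]:
        owner = getattr(owner, name)
    return owner, parts[-1]


def cond_patch(path: str, sub_func: Callable, cond_func: Callable) -> Callable[[], None]:
    """Hypothetical helper: install a conditional dispatcher, return a restore function."""
    owner, name = resolve_owner(path)
    orig_func = getattr(owner, name)

    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        # Both callbacks receive the original function first, then the call arguments.
        if cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)

    setattr(owner, name, dispatcher)
    return lambda: setattr(owner, name, orig_func)


# Route negative inputs of math.sqrt to cmath.sqrt; other inputs use the original.
restore = cond_patch(
    "math.sqrt",
    lambda orig_func, x: cmath.sqrt(x),
    lambda orig_func, x: x < 0,
)
print(math.sqrt(9))   # 3.0 (original path)
print(math.sqrt(-4))  # 2j (substitute path)
restore()
```

As with any monkey patch, code that bound the function before patching (for example via `from math import sqrt`) keeps its reference to the original and never sees the dispatcher.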
Usage
Use CondFunc when you need to conditionally replace a function in a third-party library (such as PyTorch) based on runtime conditions like device type, tensor properties, or platform-specific behavior.
Code Reference
Source Location
- Repository: AUTOMATIC1111_Stable_diffusion_webui
- File: modules/sd_hijack_utils.py
- Lines: 1-36
Signature
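```python
from typing import Any, Callable

always_true_func = lambda *args, **kwargs: True  # default condition (always use sub_func)

class CondFunc:
    def __new__(cls, orig_func, sub_func, cond_func=always_true_func) -> Callable: ...
    def __init__(self, orig_func, sub_func, cond_func) -> None: ...
    def __call__(self, *args, **kwargs) -> Any: ...
```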
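The signature pairs a __new__ that returns a callable with a separate __init__. If __new__ returns anything other than an instance of the class, Python does not call __init__ automatically, so a design like this has to invoke it explicitly. The sketch below shows that pattern for the callable form only; CondDispatch is a hypothetical name, and the class performs no string resolution or in-place patching.

```python
from typing import Any, Callable


def always_true_func(*args: Any, **kwargs: Any) -> bool:
    """Default condition: always choose the substitute."""
    return True


class CondDispatch:
    """Hypothetical stand-in for the callable form of CondFunc."""

    def __new__(cls, orig_func: Callable, sub_func: Callable,
                cond_func: Callable = always_true_func) -> Callable:
        self = super().__new__(cls)
        # The value returned below is not a CondDispatch instance, so Python
        # skips the automatic __init__ call; run it by hand instead.
        self.__init__(orig_func, sub_func, cond_func)
        return lambda *args, **kwargs: self(*args, **kwargs)

    def __init__(self, orig_func: Callable, sub_func: Callable,
                 cond_func: Callable) -> None:
        self._orig_func = orig_func
        self._sub_func = sub_func
        self._cond_func = cond_func

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self._cond_func(self._orig_func, *args, **kwargs):
            return self._sub_func(self._orig_func, *args, **kwargs)
        return self._orig_func(*args, **kwargs)


# Divide normally, but return inf instead of raising when the divisor is zero.
safe_div = CondDispatch(
    lambda a, b: a / b,
    lambda orig_func, a, b: float("inf"),
    lambda orig_func, a, b: b == 0,
)
print(type(safe_div).__name__)  # function
print(safe_div(6, 3))           # 2.0
print(safe_div(1, 0))           # inf
```

Returning a plain lambda means callers only ever see an ordinary function, which is what gets installed on the parent object in the string-path case.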
Import
```python
from modules.sd_hijack_utils import CondFunc
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| orig_func | str or callable | Yes | The target function to hijack, either as a callable or as a dot-separated path string such as "torch.nn.functional.linear" |
| sub_func | callable | Yes | The substitute function; receives the original function as its first argument followed by *args and **kwargs |
| cond_func | callable | No | A condition function that receives the original function followed by the call arguments; a truthy result selects sub_func, a falsy result calls orig_func (default: always_true_func, which always returns True) |
Outputs
| Name | Type | Description |
|---|---|---|
| wrapper | callable | A callable that conditionally dispatches to either the substitute or original function at each invocation |
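Because the condition runs on every invocation rather than once when the patch is installed, a single hijack can follow settings that change at runtime. The sketch below illustrates this with a hypothetical make_dispatcher helper and a hypothetical settings dictionary. The substitute converts its first argument to float and then delegates to the original, a torch-free analogue of the upcast-then-call pattern used for device-specific hijacks.

```python
from typing import Any, Callable


def make_dispatcher(orig_func: Callable, sub_func: Callable,
                    cond_func: Callable) -> Callable:
    """Hypothetical minimal dispatcher with CondFunc's calling convention."""
    def dispatcher(*args: Any, **kwargs: Any) -> Any:
        # cond_func is consulted on every call, so its answer can change over time.
        if cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)
    return dispatcher


settings = {"upcast": False}  # hypothetical runtime option


def scale(x, factor=2):
    return x * factor


scaled = make_dispatcher(
    scale,
    # Substitute: convert the first argument to float, then call the original.
    lambda orig_func, x, *args, **kwargs: orig_func(float(x), *args, **kwargs),
    # Condition: read the option at call time.
    lambda orig_func, *args, **kwargs: settings["upcast"],
)

print(scaled(3))            # 6
settings["upcast"] = True
print(scaled(3))            # 6.0
print(scaled(3, factor=3))  # 9.0
```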
Usage Examples
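```python
import torch
from modules.sd_hijack_utils import CondFunc

def custom_linear(input, weight, bias=None):
    # Hypothetical replacement (not part of the WebUI): y = input @ weight^T + bias
    out = input @ weight.t()
    return out if bias is None else out + bias

# Unconditionally replace a function
CondFunc(
    'torch.nn.functional.linear',
    lambda orig_func, input, weight, bias=None: custom_linear(input, weight, bias)
)

# Conditionally replace based on device: on MPS, upcast tensor arguments to float32.
# Non-tensor arguments such as normalized_shape are passed through unchanged.
CondFunc(
    'torch.layer_norm',
    lambda orig_func, *args, **kwargs: orig_func(*[a.float() if isinstance(a, torch.Tensor) else a for a in args], **kwargs),
    lambda orig_func, *args, **kwargs: args[0].device.type == 'mps'
)
```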