Implementation: Hugging Face Optimum Store Input Hook
Overview
A closure-based forward pre-hook that captures the input activations to a transformer block and then halts the forward pass by raising a `ValueError`. This is a Pattern Doc: the hook is defined as an inner function within `GPTQQuantizer.quantize_model()`.
Source
File: `optimum/gptq/quantizer.py`, lines 499-523
Pattern
```python
def store_input_hook(module, args, kwargs):
    layer_input: List[torch.Tensor] = []
    if kwargs.get("hidden_states") is not None:
        layer_input.append(nested_move_to(kwargs["hidden_states"], device=cur_layer_device))
    else:
        layer_input.append(nested_move_to(args[0], device=cur_layer_device))
    layer_inputs.append(layer_input)
    other_kwargs = {}
    for k, v in kwargs.items():  # make sure the other arguments are also captured
        if k not in ["hidden_states"]:
            other_kwargs[k] = nested_move_to(v, cur_layer_device)
    layer_input_kwargs.append(other_kwargs)
    raise ValueError
```
Registration and Usage
The hook is registered and used within the quantize_model() method:
```python
if self.cache_block_outputs:
    handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
    for data in dataset:
        for k, v in data.items():
            data[k] = nested_move_to(v, cur_layer_device)
        try:
            model(**data)
        except ValueError:
            pass
    handle.remove()
```
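The capture-and-abort mechanism can be exercised end to end on a toy module. The names below (`ToyBlock`, `captured_inputs`, `captured_kwargs`) are illustrative stand-ins for the closure variables in Optimum, not its actual code:

```python
import torch
import torch.nn as nn

captured_inputs = []   # plays the role of layer_inputs
captured_kwargs = []   # plays the role of layer_input_kwargs

class ToyBlock(nn.Module):
    """Hypothetical transformer-block stand-in."""
    def forward(self, hidden_states, attention_mask=None):
        return hidden_states * 2

block = ToyBlock()

def store_input_hook(module, args, kwargs):
    # Capture the hidden states (keyword or positional), then abort the pass.
    hs = kwargs.get("hidden_states", args[0] if args else None)
    captured_inputs.append(hs)
    captured_kwargs.append({k: v for k, v in kwargs.items() if k != "hidden_states"})
    raise ValueError  # halt: only the inputs are needed, not the outputs

handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
try:
    block(torch.ones(2, 4), attention_mask=torch.zeros(2, 4))
except ValueError:
    pass
handle.remove()
```

After the `try`/`except`, `captured_inputs` holds one tensor and `captured_kwargs[0]` holds the `attention_mask`, while the block's own `forward` never ran.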
Behavior
The hook operates as a closure capturing several variables from the enclosing quantize_model() scope:
| Captured Variable | Type | Purpose |
|---|---|---|
| `layer_inputs` | `List` | Accumulates the hidden state inputs for each calibration sample. |
| `layer_input_kwargs` | `List` | Accumulates keyword arguments (attention masks, position ids, etc.). |
| `cur_layer_device` | `torch.device` | Target device for captured tensors. |
Step-by-step behavior:

- Extract hidden states — checks whether `kwargs["hidden_states"]` is present; if not, falls back to `args[0]`. This handles both standard and non-standard model architectures (e.g., models that pass hidden states as a keyword argument).
- Move to device — uses `nested_move_to()` to recursively move all tensors to `cur_layer_device`.
- Store inputs — appends the hidden state list to `layer_inputs`.
- Capture keyword arguments — iterates over all `kwargs` except `"hidden_states"` (which is already captured) and stores them in `layer_input_kwargs`.
- Halt forward pass — raises `ValueError` to prevent computation beyond the first block.
Cache Block Outputs Mode
When `cache_block_outputs=True` (the default):

- The hook is registered only on `blocks[0]`, before the main quantization loop.
- Inputs are captured once, then propagated through the blocks sequentially.

When `cache_block_outputs=False`:

- The hook is registered on the current block at the start of each iteration.
- A full forward pass through the preceding modules is performed for each block.
- This mode supports non-standard architectures (e.g., ChatGLM) but is slower.
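The `cache_block_outputs=False` flow can be sketched with a toy stack of blocks, registering the hook fresh on each block and re-running the full model for every calibration sample. All names here (`blocks`, `dataset`, `per_block_inputs`) are hypothetical, and the quantization step itself is elided:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the model's block list and the calibration set.
blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
model = nn.Sequential(*blocks)
dataset = [torch.randn(2, 4) for _ in range(2)]

per_block_inputs = {i: [] for i in range(len(blocks))}

for i, block in enumerate(blocks):
    def store_input_hook(module, args, kwargs, idx=i):
        per_block_inputs[idx].append(args[0].detach())
        raise ValueError  # abort once block idx's input is captured

    handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
    for sample in dataset:
        try:
            model(sample)  # full forward pass through the preceding blocks
        except ValueError:
            pass
    handle.remove()
    # (real code would quantize block i here before moving on)
```

The repeated full forward passes are what makes this mode slower than caching the first block's outputs once.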
Helper Function
The `nested_move_to()` utility (from `optimum/gptq/utils.py`) recursively moves tensors to the target device, handling nested lists and tuples:
```python
def nested_move_to(v, device):
    if isinstance(v, torch.Tensor):
        return move_to(v, device)
    elif isinstance(v, (list, tuple)):
        return type(v)([nested_move_to(e, device) for e in v])
    else:
        return v
```
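A quick self-contained check of the recursion. The helper is repeated here so the snippet runs standalone, with a stand-in `move_to` that simply calls `.to(device)` (the real helper lives in `optimum/gptq/utils.py`):

```python
import torch

def move_to(v, device):  # stand-in for Optimum's move_to helper
    return v.to(device)

def nested_move_to(v, device):
    if isinstance(v, torch.Tensor):
        return move_to(v, device)
    elif isinstance(v, (list, tuple)):
        return type(v)([nested_move_to(e, device) for e in v])
    else:
        return v

# Mixed nesting: a list holding a tensor and a tuple with a non-tensor leaf.
batch = [torch.ones(2), (torch.zeros(3), "mask_type")]
moved = nested_move_to(batch, torch.device("cpu"))
```

Container types are preserved (the inner tuple stays a tuple) and non-tensor leaves such as the string pass through unchanged.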