
Implementation:Huggingface Optimum Store Input Hook

From Leeroopedia

Overview

A closure-based forward pre-hook that captures the input activations of a transformer block and then halts the forward pass by raising a ValueError. This is a Pattern Doc: the hook is not a standalone API but an inner function defined within GPTQQuantizer.quantize_model().

Source

File: optimum/gptq/quantizer.py Lines: 499-523

Pattern

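from typing import List

import torch

from optimum.gptq.utils import nested_move_to

# Defined inside GPTQQuantizer.quantize_model(); layer_inputs, layer_input_kwargs
# and cur_layer_device are closure variables from that enclosing scope.
def store_input_hook(module, args, kwargs):
    layer_input: List[torch.Tensor] = []
    if kwargs.get("hidden_states") is not None:
        # Hidden states were passed as a keyword argument
        layer_input.append(nested_move_to(kwargs["hidden_states"], device=cur_layer_device))
    else:
        # Otherwise the hidden states are the first positional argument
        layer_input.append(nested_move_to(args[0], device=cur_layer_device))

    layer_inputs.append(layer_input)
    other_kwargs = {}
    for k, v in kwargs.items():  # make sure the other arguments are also captured
        if k not in ["hidden_states"]:
            other_kwargs[k] = nested_move_to(v, cur_layer_device)
    layer_input_kwargs.append(other_kwargs)
    # Abort the forward pass; the calling loop catches this ValueError
    raise ValueError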

Registration and Usage

The hook is registered and used within the quantize_model() method:

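if self.cache_block_outputs:
    # Capture the inputs of the first block once for the whole calibration set
    handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
    for data in dataset:
        for k, v in data.items():
            data[k] = nested_move_to(v, cur_layer_device)
        try:
            model(**data)
        except ValueError:
            # Raised on purpose by store_input_hook once the inputs are stored
            pass
    handle.remove()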
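The registration loop above can be reproduced end to end on a toy model. The following is a minimal sketch, assuming PyTorch 2.0 or later (for with_kwargs=True); ToyBlock, ToyModel, move and capture_first_block_inputs are hypothetical names, and the simplified move() does not recurse into lists the way nested_move_to() does.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Stand-in for a transformer block: hidden states plus an optional mask."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None):
        out = self.proj(hidden_states)
        if attention_mask is not None:
            out = out * attention_mask.unsqueeze(-1)
        return out


class ToyModel(nn.Module):
    """Hypothetical decoder: an embedding followed by a stack of ToyBlocks."""

    def __init__(self, vocab=32, dim=8, n_blocks=3):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(n_blocks))

    def forward(self, input_ids, attention_mask=None):
        h = self.embed(input_ids)
        for block in self.blocks:
            # Hidden states go in positionally, the mask as a keyword argument
            h = block(h, attention_mask=attention_mask)
        return h


def move(v, device):
    # Simplified, non-recursive stand-in for nested_move_to()
    return v.to(device) if isinstance(v, torch.Tensor) else v


def capture_first_block_inputs(model, blocks, dataset, device):
    layer_inputs, layer_input_kwargs = [], []

    def store_input_hook(module, args, kwargs):
        hidden = kwargs.get("hidden_states")
        layer_inputs.append([move(hidden if hidden is not None else args[0], device)])
        layer_input_kwargs.append(
            {k: move(v, device) for k, v in kwargs.items() if k != "hidden_states"}
        )
        raise ValueError  # stop here: later blocks are not needed yet

    handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
    try:
        for data in dataset:
            try:
                model(**data)
            except ValueError:
                pass  # expected: raised by store_input_hook
    finally:
        handle.remove()  # detach the hook even if something else fails
    return layer_inputs, layer_input_kwargs


torch.manual_seed(0)
model = ToyModel()
dataset = [
    {"input_ids": torch.randint(0, 32, (1, 5)), "attention_mask": torch.ones(1, 5)}
    for _ in range(4)
]
inputs, kwargs = capture_first_block_inputs(
    model, model.blocks, dataset, torch.device("cpu")
)
print(len(inputs), tuple(inputs[0][0].shape), sorted(kwargs[0]))
# 4 (1, 5, 8) ['attention_mask']
```

Each calibration sample contributes one entry to both lists, and the try/finally detaches the hook even if an unexpected error escapes. Catching a bare ValueError would also swallow a genuine ValueError raised before blocks[0]; raising a dedicated exception subclass is one alternative, though it is not what the source pattern does.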

Behavior

The hook operates as a closure capturing several variables from the enclosing quantize_model() scope:

Captured variable | Type | Purpose
layer_inputs | list | Accumulates the hidden-state inputs, one single-element list per calibration sample.
layer_input_kwargs | list | Accumulates the remaining keyword arguments (attention masks, position ids, etc.), one dict per calibration sample.
cur_layer_device | torch.device | Target device for captured tensors.
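Because the hook only appends to the enclosing lists and never rebinds them, it needs no nonlocal declaration. A minimal pure-Python sketch of that closure mechanism follows; make_capture_hook is a hypothetical factory, and plain strings stand in for tensors.

```python
def make_capture_hook():
    # Enclosing-scope state, analogous to layer_inputs / layer_input_kwargs
    layer_inputs = []
    layer_input_kwargs = []

    def hook(module, args, kwargs):
        # Mutating the enclosing lists needs no `nonlocal`; only rebinding would
        layer_inputs.append([args[0]])
        layer_input_kwargs.append(dict(kwargs))
        raise ValueError

    return hook, layer_inputs, layer_input_kwargs


hook, inputs, kwargs = make_capture_hook()
for sample in ("a", "b"):
    try:
        hook(None, (sample,), {"attention_mask": None})
    except ValueError:
        pass
print(inputs, kwargs)
# [['a'], ['b']] [{'attention_mask': None}, {'attention_mask': None}]
```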

Step-by-step behavior:

  1. Extract hidden states: if kwargs["hidden_states"] is present and not None, it is used; otherwise the hook falls back to the first positional argument, args[0]. This covers both models that call their blocks with hidden states as a keyword argument and models that pass them positionally.
  2. Move to device: uses nested_move_to() to recursively move all tensors to cur_layer_device.
  3. Store inputs: appends the single-element list holding the hidden states to layer_inputs.
  4. Capture keyword arguments: iterates over all kwargs except "hidden_states" (which is already captured) and stores them in layer_input_kwargs.
  5. Halt forward pass: raises ValueError so that neither the hooked block nor anything after it executes; the calling loop catches the exception and moves on to the next calibration sample.
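Steps 1, 4 and 5 can be checked on a single toy module. This is a minimal sketch, assuming PyTorch 2.0 or later; CountingBlock is a hypothetical module that counts how often its forward() runs, and the device move of step 2 is omitted.

```python
import torch
from torch import nn


class CountingBlock(nn.Module):
    """Toy block that records how many times its forward() actually executes."""

    def __init__(self):
        super().__init__()
        self.calls = 0

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        self.calls += 1
        return hidden_states


layer_inputs, layer_input_kwargs = [], []


def store_input_hook(module, args, kwargs):
    # Step 1: use a non-None hidden_states kwarg, else the first positional arg
    hidden = kwargs.get("hidden_states")
    layer_inputs.append([hidden if hidden is not None else args[0]])
    # Step 4: keep every other keyword argument
    layer_input_kwargs.append({k: v for k, v in kwargs.items() if k != "hidden_states"})
    # Step 5: abort before CountingBlock.forward() runs
    raise ValueError


block = CountingBlock()
handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
x, mask = torch.randn(2, 3, 4), torch.ones(2, 3)
calls = [
    ((x,), {"attention_mask": mask}),                  # hidden states passed positionally
    ((), {"hidden_states": x, "position_ids": None}),  # hidden states passed by keyword
]
for args, kwargs in calls:
    try:
        block(*args, **kwargs)
    except ValueError:
        pass
handle.remove()
print(block.calls, [sorted(kw) for kw in layer_input_kwargs])
# 0 [['attention_mask'], ['position_ids']]
```

Both call styles yield the same captured hidden states, and forward() never executes because the pre-hook raises first.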

Cache Block Outputs Mode

When cache_block_outputs=True (default):

  • The hook is registered only on blocks[0] before the main quantization loop.
  • Inputs are captured once. After each block is quantized, it is run on the cached inputs and its outputs become the inputs of the next block.
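The cached propagation can be sketched with plain linear layers. This is a minimal sketch, assuming PyTorch; embed stands in for everything before blocks[0], the first-block inputs are computed directly rather than through the hook, and quantize_block is a hypothetical no-op placeholder for the GPTQ step.

```python
import torch
from torch import nn

torch.manual_seed(0)
dim, n_blocks = 4, 3
embed = nn.Linear(dim, dim)  # hypothetical stand-in for everything before blocks[0]
blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_blocks))
calibration = [torch.randn(2, dim) for _ in range(3)]


def quantize_block(block, block_inputs):
    """Hypothetical placeholder: a real GPTQ pass would rewrite block's weights here."""


with torch.no_grad():
    # In the real flow these come from store_input_hook on blocks[0]
    layer_inputs = [embed(x) for x in calibration]
    for block in blocks:
        quantize_block(block, layer_inputs)
        # Run the (now quantized) block on the cached inputs; outputs feed the next block
        layer_inputs = [block(h) for h in layer_inputs]
```

Because quantize_block does nothing here, the final cached tensors equal a plain forward pass; with real quantization, each block would instead see the outputs of already-quantized predecessors.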

When cache_block_outputs=False:

  • The hook is registered on the current block at the start of each iteration.
  • For each block, the model is run again on every calibration sample, recomputing all preceding modules until the hook fires.
  • This mode supports non-standard architectures (e.g., ChatGLM) but is slower.
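The uncached mode can be sketched by moving the hook from block to block. This is a minimal sketch, assuming PyTorch 2.0 or later; nn.Sequential of linear layers stands in for the model, capture_inputs is a hypothetical helper, and the quantization step is only marked by a comment.

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 4
model = nn.Sequential(*(nn.Linear(dim, dim) for _ in range(3)))
blocks = list(model)
calibration = [torch.randn(2, dim) for _ in range(2)]


def capture_inputs(block):
    """Hook `block`, rerun the whole model per sample, return the captured inputs."""
    captured = []

    def store_input_hook(module, args, kwargs):
        captured.append(args[0])
        raise ValueError  # stop once this block's input is known

    handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
    try:
        for x in calibration:
            try:
                model(x)
            except ValueError:
                pass
    finally:
        handle.remove()
    return captured


per_block_inputs = []
with torch.no_grad():
    for block in blocks:
        # Each iteration recomputes all earlier blocks from scratch: simple but slower
        per_block_inputs.append(capture_inputs(block))
        # (a real implementation would quantize `block` at this point)
```

The number of forward passes grows with the number of blocks, whereas the cached mode runs the full model only once.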

Helper Function

The nested_move_to() utility (from optimum/gptq/utils.py) recursively moves tensors to the target device, handling nested lists and tuples:

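import torch

# move_to() is a sibling helper in optimum/gptq/utils.py that moves a single tensor
def nested_move_to(v, device):
    if isinstance(v, torch.Tensor):
        return move_to(v, device)
    elif isinstance(v, (list, tuple)):
        # Rebuild a container of the same type with every element moved
        return type(v)([nested_move_to(e, device) for e in v])
    else:
        # Anything else (None, ints, dicts, ...) is returned unchanged
        return v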
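The helper's behavior on a mixed payload can be checked in isolation. This is a self-contained sketch, assuming PyTorch and CPU only; the move_to shown is a hypothetical stand-in that copies only when the device differs, not necessarily optimum's exact implementation.

```python
import torch


def move_to(t, device):
    # Hypothetical stand-in for optimum's move_to(): copy only if the device differs
    return t if t.device == device else t.to(device)


def nested_move_to(v, device):
    if isinstance(v, torch.Tensor):
        return move_to(v, device)
    elif isinstance(v, (list, tuple)):
        return type(v)([nested_move_to(e, device) for e in v])
    else:
        return v


device = torch.device("cpu")
payload = (torch.zeros(2), [torch.ones(3), None], 7, {"mask": torch.ones(1)})
moved = nested_move_to(payload, device)
print(type(moved).__name__, moved[2], moved[3] is payload[3])
# tuple 7 True
```

Container types are preserved and non-tensor leaves pass through, but dicts are returned as-is rather than traversed. In the hook this matters only for nested dicts, since the top-level kwargs dict is iterated explicitly.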
