Overview
ExLlamaV2Module is the abstract base class for all model submodules in ExLlamaV2. It defines the standard interface for weight loading, device placement, forward passes, and memory footprint estimation. Intervention is a wrapper that can be placed around any module to inject pre- and post-forward hooks.
Description
ExLlamaV2Module establishes the contract that every layer type (attention, MLP, MoE, embeddings, norms, linear projections, etc.) must implement. It stores a reference to the parent ExLlamaV2 model instance, a key string that identifies the module's position in the weight file namespace, an optional alt_key for alternative naming schemes, a device_idx for device placement, and a list of submodules.
The abstract interface includes:
- load() / unload(): Load weights from safetensors files onto the target device, or release them.
- forward(hidden_states, cache, attn_params, past_len, intermediates, loras): Run the module's computation.
- numel(): Return the total number of weight elements.
- scratch_space_fixed() / scratch_space() / scratch_space_tp(): Report scratch memory requirements.
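To make the contract concrete, the sketch below expresses the same abstract interface with Python's abc module. It is an illustration only: ModuleContract, ToyIdentity and Incomplete are hypothetical names, and the real ExLlamaV2Module is not necessarily declared this way. The benefit of the abc form is that a subclass that forgets a method fails when it is constructed rather than partway through inference.

```python
from abc import ABC, abstractmethod


class ModuleContract(ABC):
    """Hypothetical stand-in for the ExLlamaV2Module abstract interface."""

    def __init__(self, key: str):
        self.key = key            # position in the weight namespace
        self.device_idx = None    # target device, unset until placement
        self.submodules = []

    @abstractmethod
    def load(self, device_context: bool = True): ...

    @abstractmethod
    def unload(self): ...

    @abstractmethod
    def numel(self) -> int: ...

    @abstractmethod
    def scratch_space_fixed(self) -> int: ...

    @abstractmethod
    def scratch_space(self) -> int: ...

    @abstractmethod
    def scratch_space_tp(self) -> int: ...

    @abstractmethod
    def forward(self, hidden_states, cache=None, attn_params=None,
                past_len=None, intermediates=None, loras=None): ...


class ToyIdentity(ModuleContract):
    """Implements every abstract method; forward() passes its input through."""

    def load(self, device_context: bool = True): pass
    def unload(self): pass
    def numel(self) -> int: return 0
    def scratch_space_fixed(self) -> int: return 0
    def scratch_space(self) -> int: return 0
    def scratch_space_tp(self) -> int: return 0

    def forward(self, hidden_states, cache=None, attn_params=None,
                past_len=None, intermediates=None, loras=None):
        return hidden_states


class Incomplete(ModuleContract):
    """Omits forward() and the rest, so it cannot be instantiated."""

    def load(self, device_context: bool = True): pass


layer = ToyIdentity("model.layers.0.mlp")
print(layer.forward([0.5, -1.0]))  # [0.5, -1.0]
try:
    Incomplete("model.layers.0.self_attn")
except TypeError as err:
    print("rejected:", err)
```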
Concrete implementations are provided for weight loading:
- load_multi(key, keys, measure, cpu): Loads multiple sub-tensors from safetensors files (e.g. "q_weight", "q_scale", etc.) in a single pass, grouping by source file for efficiency.
- load_weight(override_key, cpu): Detects the quantization format (EXL2, GPTQ, or plain torch) and loads the appropriate set of tensors. For EXL2 it loads q_weight/q_invperm/q_scale/q_scale_max/q_groups/q_perm; for GPTQ it loads qweight/qzeros/scales/g_idx; for torch it loads weight (and optionally bias).
- load_weight_fused(f_key, f_beg, f_end, in_feat, out_feat, altpack_qkv): Loads a slice of a fused weight tensor (e.g. when QKV projections are stored as a single matrix), handling transposition and alternative QKV packing layouts.
- weight_footprint(): Measures the total byte size of the module's weights without actually loading them (uses the measure=True path in load_multi).
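The file-grouping idea behind load_multi() and the measure=True path used by weight_footprint() can be sketched in plain Python. Everything here is hypothetical: TENSOR_FILE_MAP stands in for the model's tensor-to-file mapping, TENSOR_META replaces real safetensors headers with made-up element sizes and shapes, and the file names are invented. A real implementation would open each file once and read the listed tensors from it.

```python
from collections import defaultdict
from math import prod

# Hypothetical: full tensor name -> safetensors file that contains it.
TENSOR_FILE_MAP = {
    "model.layers.0.self_attn.q_proj.q_weight": "model-00001.safetensors",
    "model.layers.0.self_attn.q_proj.q_scale": "model-00001.safetensors",
    "model.layers.0.self_attn.q_proj.q_groups": "model-00002.safetensors",
}

# Hypothetical header metadata: full tensor name -> (bytes per element, shape).
TENSOR_META = {
    "model.layers.0.self_attn.q_proj.q_weight": (4, (512, 4096)),
    "model.layers.0.self_attn.q_proj.q_scale": (4, (32, 512)),
    "model.layers.0.self_attn.q_proj.q_groups": (2, (64,)),
}


def plan_loads(key, keys):
    """Group requested sub-tensors by source file so each file is visited once."""
    by_file = defaultdict(list)
    for k in keys:
        name = f"{key}.{k}"
        if name in TENSOR_FILE_MAP:  # skip sub-tensors that don't exist
            by_file[TENSOR_FILE_MAP[name]].append(k)
    return dict(by_file)


def measure_bytes(key, keys):
    """Analogue of measure=True: sum sizes from metadata without loading data."""
    total = 0
    for k in keys:
        name = f"{key}.{k}"
        if name in TENSOR_META:
            elem_bytes, shape = TENSOR_META[name]
            total += elem_bytes * prod(shape)
    return total


key = "model.layers.0.self_attn.q_proj"
print(plan_loads(key, ["q_weight", "q_scale", "q_groups", "q_perm"]))
# {'model-00001.safetensors': ['q_weight', 'q_scale'], 'model-00002.safetensors': ['q_groups']}
print(measure_bytes(key, ["q_weight", "q_scale", "q_groups"]))  # 8454272
```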
Intervention is a wrapper subclass that delegates all operations to an inner module but intercepts forward() to call optional pre_forward and post_forward callbacks. These hooks receive the hidden states and all forward arguments, allowing external code to inspect or modify activations at any module boundary. The hooks' output is automatically moved back to the original device.
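The wrapper pattern itself is small. The sketch below reproduces it with plain Python objects so it runs without ExLlamaV2; Wrapper, Doubler and the __getattr__ delegation are illustrative choices, not the library's code, and the device move described above is omitted because there are no tensors here.

```python
class Doubler:
    """Toy inner module: forward() doubles every element."""
    device_idx = 0

    def forward(self, hidden_states, *args, **kwargs):
        return [2 * x for x in hidden_states]

    def numel(self):
        return 0


class Wrapper:
    """Hypothetical Intervention-style wrapper with optional pre/post hooks."""

    def __init__(self, inner, pre_forward=None, post_forward=None):
        self.inner = inner
        self.pre_forward = pre_forward
        self.post_forward = post_forward

    def forward(self, hidden_states, *args, **kwargs):
        # Each hook gets the hidden states plus every forward argument
        # and must return the (possibly modified) hidden states.
        if self.pre_forward is not None:
            hidden_states = self.pre_forward(hidden_states, *args, **kwargs)
        hidden_states = self.inner.forward(hidden_states, *args, **kwargs)
        if self.post_forward is not None:
            hidden_states = self.post_forward(hidden_states, *args, **kwargs)
        return hidden_states

    def __getattr__(self, name):
        # Only called when normal lookup fails, so anything not defined on
        # the wrapper (numel, device_idx, ...) comes from the inner module.
        return getattr(self.inner, name)


seen = []

def record(hidden_states, *args, **kwargs):
    seen.append(hidden_states)  # inspect only; return unchanged
    return hidden_states


wrapped = Wrapper(
    Doubler(),
    pre_forward=lambda h, *a, **kw: [x + 1 for x in h],  # add 1 before
    post_forward=record,                                # record after
)
print(wrapped.forward([1, 2, 3]))  # [4, 6, 8]
print(seen)                        # [[4, 6, 8]]
print(wrapped.device_idx)          # 0 (delegated to the inner module)
```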
Usage
ExLlamaV2Module is never instantiated directly. Subclasses such as ExLlamaV2Attention, ExLlamaV2MLP, ExLlamaV2MoEMLP, ExLlamaV2RMSNorm, ExLlamaV2Linear, and ExLlamaV2Embedding implement the abstract methods. Use Intervention to wrap any module when you need to add activation probing, steering, or logging hooks.
Code Reference
Source Location
Signature
class ExLlamaV2Module:
    config: ExLlamaV2Config
    key: str
    alt_key: str | None
    device_idx: int | list
    footprint: int
    submodules: list[ExLlamaV2Module]
    assumed_footprint: int

    def __init__(
        self,
        model: ExLlamaV2,
        key: str,
        archparams=None,
    ): ...

    # Abstract interface
    def numel(self) -> int: ...
    def load(self, device_context: bool): ...
    def unload(self): ...
    def scratch_space_fixed(self) -> int: ...
    def scratch_space_tp(self) -> int: ...
    def scratch_space(self) -> int: ...
    def forward(self, hidden_states, cache=None, attn_params=None,
                past_len=None, intermediates=None, loras=None): ...

    # Concrete methods
    def device(self) -> str: ...
    def load_multi(self, key, keys, measure=False, cpu=False) -> int | dict: ...
    def load_weight(self, override_key=None, cpu=False): ...
    def load_weight_fused(self, f_key, f_beg, f_end, in_feat, out_feat, altpack_qkv): ...
    def weight_footprint(self) -> int: ...
    def set_device_idx(self, idx: int | None): ...
    def is_quant(self) -> bool: ...
    def reload(self): ...

class Intervention(ExLlamaV2Module):
    inner: ExLlamaV2Module

    def __init__(
        self,
        inner: ExLlamaV2Module,
        pre_forward=None,
        post_forward=None,
    ): ...

    def forward(self, hidden_states, *args, **kwargs): ...
Import
from exllamav2.module import ExLlamaV2Module, Intervention
I/O Contract
Inputs (ExLlamaV2Module.__init__)
| Name | Type | Required | Description |
|---|---|---|---|
| model | ExLlamaV2 | Yes | Parent model instance providing config, tensor_file_map, and device contexts |
| key | str | Yes | Weight namespace key (e.g. "model.layers.0.self_attn.q_proj") |
| archparams | object or None | No | Architecture parameters; defaults to model.config.arch.lm |
Inputs (load_weight)
| Name | Type | Required | Description |
|---|---|---|---|
| override_key | str or None | No | Alternative key to use instead of self.key for weight lookup |
| cpu | bool | No (default False) | If True, load weights to CPU instead of the module's device |
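The format detection described for load_weight() can be illustrated by checking which sub-tensor names exist under a key. The suffix lists come from the description above; the function name detect_format, the order of the checks, and the use of a single marker tensor per format are assumptions for illustration, not the library's exact logic.

```python
# Sub-tensor names load_weight() fetches for each format (per the description).
EXL2_KEYS = ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm"]
GPTQ_KEYS = ["qweight", "qzeros", "scales", "g_idx"]
TORCH_KEYS = ["weight", "bias"]


def detect_format(key, available):
    """Hypothetical: classify a module's weights by which tensor names exist.

    `available` is a set of full tensor names, e.g. from a safetensors index.
    Returns (format_name, sub_keys_to_load), or (None, []) if nothing matches.
    """
    if f"{key}.q_weight" in available:   # EXL2 marker tensor
        return "exl2", EXL2_KEYS
    if f"{key}.qweight" in available:    # GPTQ marker tensor
        return "gptq", GPTQ_KEYS
    if f"{key}.weight" in available:     # plain torch weight; bias is optional
        return "torch", [k for k in TORCH_KEYS if f"{key}.{k}" in available]
    return None, []


names = {
    "model.layers.0.mlp.down_proj.qweight",
    "model.layers.0.mlp.down_proj.qzeros",
    "lm_head.weight",
}
print(detect_format("model.layers.0.mlp.down_proj", names))
# ('gptq', ['qweight', 'qzeros', 'scales', 'g_idx'])
print(detect_format("lm_head", names))  # ('torch', ['weight'])
print(detect_format("missing", names))  # (None, [])
```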
Inputs (load_weight_fused)
| Name | Type | Required | Description |
|---|---|---|---|
| f_key | str | Yes | Key of the fused (combined) weight tensor in the safetensors file |
| f_beg | int | Yes | Start row index for slicing the fused tensor |
| f_end | int | Yes | End row index for slicing the fused tensor |
| in_feat | int | Yes | Expected input feature dimension (for transposition detection) |
| out_feat | int | Yes | Expected output feature dimension (for transposition detection) |
| altpack_qkv | bool | Yes | If True, applies alternative QKV head interleaving before/after slicing |
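Taking a row slice from a fused matrix, with a transposition check, can be sketched on nested Python lists. Assumptions: the fused tensor is stored either as (rows, in_feat) or transposed as (in_feat, rows); a transpose is applied when the first dimension equals in_feat and the second does not; and the slice from f_beg to f_end should yield out_feat rows. The altpack_qkv head re-interleaving is not modeled, and slice_fused is a hypothetical name.

```python
def transpose(m):
    """Transpose a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def slice_fused(fused, f_beg, f_end, in_feat, out_feat):
    """Hypothetical: return rows f_beg..f_end-1 of a fused (rows x in_feat) matrix."""
    rows, cols = len(fused), len(fused[0])
    # Stored as (in_feat, rows)? Flip to (rows, in_feat) before slicing rows.
    if rows == in_feat and cols != in_feat:
        fused = transpose(fused)
    part = fused[f_beg:f_end]
    if len(part) != out_feat or len(part[0]) != in_feat:
        raise ValueError("slice does not match (out_feat, in_feat)")
    return part


# Toy fused QKV: in_feat=2, Q/K/V each 2 rows -> 6 x 2; row i holds [i, i].
qkv = [[i, i] for i in range(6)]
print(slice_fused(qkv, 2, 4, in_feat=2, out_feat=2))  # [[2, 2], [3, 3]]

# The same data stored transposed (2 x 6) yields the same K slice.
print(slice_fused(transpose(qkv), 2, 4, in_feat=2, out_feat=2))  # [[2, 2], [3, 3]]
```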
Inputs (Intervention.__init__)
| Name | Type | Required | Description |
|---|---|---|---|
| inner | ExLlamaV2Module | Yes | The module to wrap |
| pre_forward | callable or None | No | Hook called before inner.forward(); receives (hidden_states, *args, **kwargs), returns modified hidden_states |
| post_forward | callable or None | No | Hook called after inner.forward(); receives (hidden_states, *args, **kwargs), returns modified hidden_states |
Outputs
| Name | Type | Description |
|---|---|---|
| forward() | torch.Tensor | Transformed hidden states after the module's computation |
| load_weight() | dict, tuple, nn.Parameter, or None | Loaded weight tensors; the format depends on the quantization type: a dict for EXL2 or GPTQ, an nn.Parameter for plain torch weights (a (weight, bias) tuple when a bias is also present), or None if the key is not found |
| load_weight_fused() | nn.Parameter, tuple[nn.Parameter, nn.Parameter], or None | Sliced weight (and optionally bias) from a fused tensor |
| weight_footprint() | int | Total byte size of the module's weights on disk |
| device() | str | Device string (e.g. "cuda:0" or "cpu") |
| is_quant() | bool | Always returns False for the base class; subclasses override |
Usage Examples
Subclassing ExLlamaV2Module
from exllamav2.module import ExLlamaV2Module
import torch

class MyCustomLayer(ExLlamaV2Module):

    def __init__(self, model, key):
        super().__init__(model, key)
        # Initialize layer-specific attributes
        self.weight = None

    def load(self, device_context: bool = True):
        # Load weights using the base class helper; the return type
        # depends on the quantization format (EXL2, GPTQ or torch)
        self.weight = self.load_weight()
        # ... set up parameters

    def unload(self):
        # Release weights
        self.weight = None

    def forward(self, hidden_states, cache=None, attn_params=None,
                past_len=None, intermediates=None, loras=None):
        # Implement forward computation (this stub is an identity)
        return hidden_states

    def numel(self):
        return 0

    def scratch_space_fixed(self):
        return 0

    def scratch_space(self):
        return 0

    def scratch_space_tp(self):
        return 0
Using Intervention for Activation Probing
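from exllamav2.module import Intervention

# Assumes `model` (a loaded ExLlamaV2), `input_ids` and `cache` already exist.
# model.modules is a flat list of submodules (embedding, attention, MLP, ...),
# so index 5 is an entry in that list, not necessarily decoder layer 5.
captured_activations = {}

def capture_post(hidden_states, *args, **kwargs):
    # Copy the module's output to CPU and pass it on unchanged
    captured_activations["module_5"] = hidden_states.clone().cpu()
    return hidden_states

# Wrap module 5 with a post-forward hook
original_module = model.modules[5]
model.modules[5] = Intervention(
    inner=original_module,
    post_forward=capture_post
)

# Run a forward pass -- captured_activations["module_5"] will be populated
output = model.forward(input_ids, cache=cache)
print(captured_activations["module_5"].shape)

# Restore the unwrapped module when done
model.modules[5] = original_module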
Using Intervention for Activation Steering
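import torch
from exllamav2.module import Intervention

# Assumes a loaded ExLlamaV2 `model`; size and place the vector to match it
target = model.modules[10]
steering_vector = torch.randn(
    1, 1, model.config.hidden_size,
    device=target.device(), dtype=torch.half,
) * 0.1

def steer_pre(hidden_states, *args, **kwargs):
    # (1, 1, hidden_size) broadcasts over the batch and sequence dimensions
    return hidden_states + steering_vector

# Add a steering vector before module 10
model.modules[10] = Intervention(
    inner=target,
    pre_forward=steer_pre
)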
Key Methods
ExLlamaV2Module
| Method | Description |
|---|---|
| __init__(model, key, archparams) | Stores the model reference and key; initializes footprint to -1 and submodules to an empty list |
| device() | Returns "cuda:{idx}" or "cpu" based on device_idx |
| load_multi(key, keys, measure, cpu) | Loads multiple named sub-tensors from safetensors files, grouped by file for I/O efficiency |
| load_weight(override_key, cpu) | Auto-detects the quantization format (EXL2/GPTQ/torch) and loads the corresponding tensors |
| load_weight_fused(f_key, f_beg, f_end, ...) | Loads a row slice from a fused weight tensor, handling transposition and QKV repacking |
| weight_footprint() | Returns the total byte size of the weights; caches the result after the first computation |
| set_device_idx(idx) | Sets the target device index for weight loading |
| is_quant() | Returns False (base implementation); overridden by quantized subclasses |
| reload() | Convenience method that calls unload() then load() |
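Three of the small helpers in this table (device(), the cached weight_footprint(), and reload()) follow simple patterns that a toy class can show. ToyModule and its _measure() counter are hypothetical: _measure() stands in for the load_multi(measure=True) pass, footprint = -1 mirrors the "not yet computed" initial value listed for __init__, and the list form of device_idx is ignored.

```python
class ToyModule:
    """Hypothetical module illustrating device(), cached footprint, and reload()."""

    def __init__(self, key, device_idx=None):
        self.key = key
        self.device_idx = device_idx
        self.footprint = -1      # -1 means "not measured yet"
        self.measure_calls = 0   # counts how often the expensive path runs
        self.loaded = False

    def device(self):
        # "cuda:{idx}" when an index is set, otherwise "cpu"
        return "cpu" if self.device_idx is None else f"cuda:{self.device_idx}"

    def set_device_idx(self, idx):
        self.device_idx = idx

    def _measure(self):
        # Stand-in for summing tensor sizes via a measure-only load pass
        self.measure_calls += 1
        return 8_454_272

    def weight_footprint(self):
        if self.footprint == -1:  # compute once, then reuse the cached value
            self.footprint = self._measure()
        return self.footprint

    def load(self, device_context=True):
        self.loaded = True

    def unload(self):
        self.loaded = False

    def reload(self):
        # Unload, then load again
        self.unload()
        self.load()


m = ToyModule("model.layers.0.self_attn")
print(m.device())       # cpu
m.set_device_idx(1)
print(m.device())       # cuda:1
m.weight_footprint()
m.weight_footprint()
print(m.measure_calls)  # 1
```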
Intervention
| Method | Description |
|---|---|
| __init__(inner, pre_forward, post_forward) | Wraps the inner module, copies device_idx and padding attributes |
| forward(hidden_states, *args, **kwargs) | Calls the pre_forward hook (if set), then inner.forward(), then the post_forward hook (if set); ensures device consistency |
| All other methods | Delegated directly to the inner module (load, unload, numel, weight_footprint, device, is_quant, etc.) |